某某潮流网,最新潮鞋资讯分享!

微信号:weixin888

Optimizer in PyTorch

时间:2024-06-18 20:58人气:编辑:佚名

torch.optim中实现了多种优化算法,目前已支持一些通用优化算法,同时提供接口来支持更加复杂的优化算法。

我们需要创建optimizer对象来保存当前状态并根据计算的梯度更新参数。创建optimizer时,需要提供需优化的模型参数,模型参数必须是iterable类型且顺序不变。optimizer中还可以指定一些优化选项,如:learning rate和weight decay等,优化选项可以用于全部模型参数,也可以为不同的模型参数提供不同的配置。

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

#最外层的配置作为default选项用于没设置该配置的参数组。
optim.SGD([
                {'params': model.base.parameters()},
                {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

所有的optimizer实现step()函数来更新参数,通常在梯度计算之后调用step()函数来更新参数。有些优化算法,如:LBFGS,需要多次执行step()函数,因此需要传入closure来允许重新计算模型,closure中应该清除梯度,计算损失并返回损失。

#在梯度计算之后来用梯度更新参数
for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

#传入closure
for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

torch.optim.Optimizer是所有优化算法的父类。

字段:

  • param_groups:存放模型参数;
  • defaults:存放默认优化参数;
  • state:存放优化器的当前优化状态,不同优化器的内容不同;

函数:

  • state_dict():返回优化器的状态,包括packed_state和param_groups;
  • load_state_dict():从参数中加载state_dict;
  • zero_grad():将每个参数的梯度设置为None或0;
  • step():为空函数,实现单步优化;
  • add_param_group():向param_groups中添加模型参数组;
class Optimizer(object):

    def __init__(self, params, defaults):
        self.defaults = defaults
        self.state = defaultdict(dict)
        self.param_groups = [] #dict的列表,[{'params':[Tensor|Dict]}]

        param_groups = list(params) #Tensor或dict的列表
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

    def __getstate__(self):
        return {
            'defaults': self.defaults,
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def state_dict(self):
        ...
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        ...

    def zero_grad(self, set_to_none: bool = False):
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        if set_to_none:
                            p.grad = None
                        else:
                            p.grad.zero_() #简化写法

    def step(self, closure):
        raise NotImplementedError

    def add_param_group(self, param_group):
        ...


标签: 参数   优化  
相关资讯
热门频道

热门标签

官方微信官方微博百家号

网站简介 | 意见反馈 | 联系我们 | 法律声明 | 广告服务

Copyright © 2002-2022 天富平台-全球注册登录站 版权所有 备案号:粤ICP备xxxxxxx号

平台注册入口