微信号:weixin888
torch.optim中实现了多种优化算法,目前已支持一些通用优化算法,同时提供接口来支持更加复杂的优化算法。
我们需要创建optimizer对象来保存当前状态并根据计算的梯度更新参数。创建optimizer时,需要提供需优化的模型参数,模型参数必须是iterable类型且顺序不变。optimizer中还可以指定一些优化选项,如:learning rate和weight decay等,优化选项可以用于全部模型参数,也可以为不同的模型参数提供不同的配置。
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
#最外层的配置作为default选项用于没设置该配置的参数组。
optim.SGD([
{'params': model.base.parameters()},
{'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
所有的optimizer实现step()函数来更新参数,通常在梯度计算之后调用step()函数来更新参数。有些优化算法,如:LBFGS,需要多次执行step()函数,因此需要传入closure来允许重新计算模型,closure中应该清除梯度,计算损失并返回损失。
#在梯度计算之后来用梯度更新参数
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
#传入closure
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
return loss
optimizer.step(closure)
torch.optim.Optimizer是所有优化算法的父类。
字段:
函数:
class Optimizer(object):
def __init__(self, params, defaults):
self.defaults = defaults
self.state = defaultdict(dict)
self.param_groups = [] #dict的列表,[{'params':[Tensor|Dict]}]
param_groups = list(params) #Tensor或dict的列表
if not isinstance(param_groups[0], dict):
param_groups = [{'params': param_groups}]
for param_group in param_groups:
self.add_param_group(param_group)
def __getstate__(self):
return {
'defaults': self.defaults,
'state': self.state,
'param_groups': self.param_groups,
}
def state_dict(self):
...
return {
'state': packed_state,
'param_groups': param_groups,
}
def load_state_dict(self, state_dict):
...
def zero_grad(self, set_to_none: bool = False):
for group in self.param_groups:
for p in group['params']:
if p.grad is not None:
if set_to_none:
p.grad = None
else:
p.grad.zero_() #简化写法
def step(self, closure):
raise NotImplementedError
def add_param_group(self, param_group):
...