某某潮流网,最新潮鞋资讯分享!

微信号:weixin888

loss.backward() 和optimizer.step()的关系及灵活运用

时间:2024-05-26 09:19人气:编辑:佚名
loss.backward()
optimizer.step()

那么,这两个函数到底是怎么联系在一起的呢?

我们都知道,loss.backward()函数的作用是根据loss来计算网络参数的梯度,其对应的输入默认为网络的叶子节点,即数据集内的数据,叶子节点如下图所示:

在这里插入图片描述

同样的,该梯度信息也可以用函数torch.autograd.grad()计算得到

x = torch.tensor(2., requires_grad=True)
y = torch.tensor(3., requires_grad=True)

z = x * x * y
z.backward()
print(x.grad)
>>> tensor(12.)
x = torch.tensor(2., requires_grad=True)
y = torch.tensor(3., requires_grad=True)

z = x * x * y
x_grad = torch.autograd.grad(outputs=z, inputs=x)
print(x_grad[0])
>>> tensor(12.)

以上内容引自https://zhuanlan.zhihu.com/p/279758736

优化器的作用就是针对计算得到的参数梯度对网络参数进行更新,所以要想使得优化器起作用,主要需要两个东西:

  • 优化器需要知道当前的网络模型的参数空间
  • 优化器需要知道反向传播的梯度信息(即backward计算得到的信息)

观察一下SGD方法中step()方法的源码

def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
            
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = d_p.clone()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf

		p.data.add_(-group['lr'], d_p)

        return loss

我们可以看到里面有如下的代码

for p in group['params']:
    if p.grad is None:
        continue
        d_p = p.grad.data

说明,step()函数确实是利用了计算得到的梯度信息,且该信息是与网络的参数绑定在一起的,所以optimizer函数在读入是先导入了网络参数模型’params’,然后通过一个.grad()函数就可以轻松的获取他的梯度信息。

我们想通过改变梯度信息来验证该关系的正确性,即是否可以通过一次梯度下降后,再通过一次梯度上升来得到初始化的参数

import torch
import torch.nn as nn

#  Check if we have a CUDA-capable device; if so, use it
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Will train on {}'.format(device))

#  为了让参数恢复成初始化状态,使用最简单的SGD优化器
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

#  定义模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        self.linear = nn.Linear(3,1)
        
    def forward(self, x):
        y = self.linear(x)
        return y
    
#  载入模型与输入,并打印此时的模型参数
x = (torch.rand(3)).to(device)
net = CNN().to(device)
print('the first output!')
for name, parameters in net.named_parameters():
    print(name, ':', parameters)
    
print('-------------------------------------------------------------------------------')    
#  做梯度下降
optimizer.zero_grad()
y = net(x)
loss = (1-y)**2

loss.backward()
optimizer.step()
#  打印梯度信息
for name, parameters in net.named_parameters():
    print(name, ':', parameters.grad)
#  经过第一次更新以后,打印网络参数
for name, parameters in net.named_parameters():
    print(name, ':', parameters)
    
print('-------------------------------------------------------------------------------')
#  我们直接将网络参数的梯度信息改为相反数来进行梯度上升
for name, parameters in net.named_parameters():
    parameters.grad *= -1
#  打印
for name, parameters in net.named_parameters():
    print('the second output!')
    print(name, ':', parameters.grad)

经过对比,我们发现最后的结果与我们的设想一样,网络参数恢复成初始化状态,因此可以证明optimizer.step()与loss.backward()之间的关系。

?探索loss.backward() 和optimizer.step()的关系并灵活运用_pytorch_Mr Sorry-DevPress官方社区 (csdn.net)

标签: me   in   梯度  
相关资讯
热门频道

热门标签

官方微信官方微博百家号

网站简介 | 意见反馈 | 联系我们 | 法律声明 | 广告服务

Copyright © 2002-2022 天富平台-全球注册登录站 版权所有 备案号:粤ICP备xxxxxxx号

平台注册入口