侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

Pytorch 中retain_graph的用法详解

Python  /  管理员 发布于 5年前   616

用法分析

在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么?

############################    # (1) Update D network: maximize D(x)-1-D(G(z))    ###########################    real_img = Variable(target)    if torch.cuda.is_available():      real_img = real_img.cuda()    z = Variable(data)    if torch.cuda.is_available():      z = z.cuda()    fake_img = netG(z)    netD.zero_grad()    real_out = netD(real_img).mean()    fake_out = netD(fake_img).mean()    d_loss = 1 - real_out + fake_out    d_loss.backward(retain_graph=True) #####    optimizerD.step()    ############################    # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss    ###########################    netG.zero_grad()    g_loss = generator_criterion(fake_out, fake_img, real_img)    g_loss.backward()    optimizerG.step()    fake_img = netG(z)    fake_out = netD(fake_img).mean()    g_loss = generator_criterion(fake_out, fake_img, real_img)    running_results['g_loss'] += g_loss.data[0] * batch_size    d_loss = 1 - real_out + fake_out    running_results['d_loss'] += d_loss.data[0] * batch_size    running_results['d_score'] += real_out.data[0] * batch_size    running_results['g_score'] += fake_out.data[0] * batch_size

在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;

其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它,

如下代码:

import torchy=x**2z=y*4output1=z.mean()output2=z.sum()output1.backward()output2.backward()

输出如下错误信息:

---------------------------------------------------------------------------RuntimeError   Traceback (most recent call last)<ipython-input-19-8ad6b0658906> in <module>()----> 1 output1.backward()   2 output2.backward()D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)   91         products. Defaults to ``False``.   92     """---> 93     torch.autograd.backward(self, gradient, retain_graph, create_graph)   94    95   def register_hook(self, hook):D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)   88   Variable._execution_engine.run_backward(   89     tensors, grad_tensors, retain_graph, create_graph,---> 90     allow_unreachable=True) # allow_unreachable flag   91    92 RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

修改成如下正确:

import torchy=x**2z=y*4output1=z.mean()output2=z.sum()output1.backward(retain_graph=True)output2.backward()
# 假如你有两个Loss,先执行第一个的backward,再执行第二个backwardloss1.backward(retain_graph=True)loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环optimizer.step() # 更新参数

Variable 类源代码

class Variable(_C._VariableBase):   """  Attributes:    data: 任意类型的封装好的张量。    grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。    requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。    volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。    is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值.    grad_fn: Gradient function graph trace.   Parameters:    data (any tensor class): 要包装的张量.    requires_grad (bool): bool型的标记值. **Keyword only.**    volatile (bool): bool型的标记值. **Keyword only.**  """   def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):    """计算关于当前图叶子变量的梯度,图使用链式法则导致分化    如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数    如果变量不是标量(包含多个元素数据的矢量)且需要梯度,函数需要额外的梯度;    需要指定一个和tensor的形状匹配的grad_output参数(y在指定方向投影对x的导数);    可以是一个类型和位置相匹配且包含与自身相关的不同函数梯度的张量。    函数在叶子上累积梯度,调用前需要对该叶子进行清零。     Arguments:      grad_variables (Tensor, Variable or None):  变量的梯度,如果是一个张量,除非“create_graph”是True,否则会自动转换成volatile型的变量。  可以为标量变量或不需要grad的值指定None值。如果None值可接受,则此参数可选。      retain_graph (bool, optional): 如果为False,用来计算梯度的图将被释放。          在几乎所有情况下,将此选项设置为True不是必需的,通常可以以更有效的方式解决。          默认值为create_graph的值。      create_graph (bool, optional): 为True时,会构造一个导数的图,用来计算出更高阶导数结果。          默认为False,除非``gradient``是一个volatile变量。    """    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)    def register_hook(self, hook):    """Registers a backward hook.     每当与variable相关的梯度被计算时调用hook,hook的申明:hook(grad)->Variable or None    不能对hook的参数进行修改,但可以选择性地返回一个新的梯度以用在`grad`的相应位置。     函数返回一个handle,其``handle.remove()``方法用于将hook从模块中移除。     Example:      >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)      >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient      >>> v.backward(torch.Tensor([1, 1, 1]))      >>> v.grad.data       2       2       2      [torch.FloatTensor of size 3]      >>> h.remove() # removes the hook    """    if self.volatile:      raise RuntimeError("cannot register a hook on a volatile variable")    if not self.requires_grad:      raise RuntimeError("cannot register a hook on a variable that "    "doesn't require gradient")    if self._backward_hooks is None:      self._backward_hooks = OrderedDict()      if self.grad_fn is not None:        self.grad_fn._register_hook_dict(self)    handle = hooks.RemovableHandle(self._backward_hooks)    self._backward_hooks[handle.id] = hook    return handle   def reinforce(self, reward):    """Registers a reward obtained as a result of a stochastic process.    区分随机节点需要为他们提供reward值。如果图表中包含任何的随机操作,都应该在其输出上调用此函数,否则会出现错误。    Parameters:      reward(Tensor): 带有每个元素奖赏的张量,必须与Variable数据的设备位置和形状相匹配。    """    if not isinstance(self.grad_fn, StochasticFunction):      raise RuntimeError("reinforce() can be only called on outputs "    "of stochastic functions")    self.grad_fn._reinforce(reward)   def detach(self):    """返回一个从当前图分离出来的心变量。    结果不需要梯度,如果输入是volatile,则输出也是volatile。     .. 注意::     返回变量使用与原始变量相同的数据张量,并且可以看到其中任何一个的就地修改,并且可能会触发正确性检查中的错误。    """    result = NoGrad()(self) # this is needed, because it merges version counters    result._grad_fn = None    return result   def detach_(self):    """从创建它的图中分离出变量并作为该图的一个叶子"""    self._grad_fn = None    self.requires_grad = False   def retain_grad(self):    """Enables .grad attribute for non-leaf Variables."""    if self.grad_fn is None: # no-op for leaves      return    if not self.requires_grad:      raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")    if hasattr(self, 'retains_grad'):      return    weak_self = weakref.ref(self)     def retain_grad_hook(grad):      var = weak_self()      if var is None:        return      if var._grad is None:        var._grad = grad.clone()      else:        var._grad = var._grad + grad     self.register_hook(retain_grad_hook)    self.retains_grad = True

以上这篇Pytorch 中retain_graph的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


  • 上一条:
    Pytorch中膨胀卷积的用法详解
    下一条:
    PyTorch中的Variable变量详解
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • 在python语言中Flask框架的学习及简单功能示例(0个评论)
    • 在Python语言中实现GUI全屏倒计时代码示例(0个评论)
    • Python + zipfile库实现zip文件解压自动化脚本示例(0个评论)
    • python爬虫BeautifulSoup快速抓取网站图片(1个评论)
    • vscode 配置 python3开发环境的方法(0个评论)
    • 近期文章
    • 智能合约Solidity学习CryptoZombie第三课:组建僵尸军队(高级Solidity理论)(0个评论)
    • 智能合约Solidity学习CryptoZombie第二课:让你的僵尸猎食(0个评论)
    • 智能合约Solidity学习CryptoZombie第一课:生成一只你的僵尸(0个评论)
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2018-04
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2022-01
    • 2023-07
    • 2023-10
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客