侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

TensorFlow 滑动平均的示例代码

技术  /  管理员 发布于 7年前   161

滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量。

1、滑动平均求解对象初始化

ema = tf.train.ExponentialMovingAverage(decay,num_updates) 

参数decay

`shadow_variable = decay * shadow_variable + (1 - decay) * variable`

参数num_updates

`min(decay, (1 + num_updates) / (10 + num_updates))`

2、添加/更新变量

添加目标变量,为之维护影子变量

注意维护不是自动的,需要每轮训练中运行此句,所以一般都会使用tf.control_dependencies使之和train_op绑定,以至于每次train_op都会更新影子变量

ema.apply([var0, var1]) 

3、获取影子变量值

这一步不需要定义图中,从影子变量集合中提取目标值

sess.run(ema.average([var0, var1])) 

4、保存&载入影子变量

我们知道,在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

保存影子变量

建立tf.train.ExponentialMovingAverage对象后,Saver正常保存就会存入影子变量,命名规则是"v/ExponentialMovingAverage"对应变量”v“

import tensorflow as tf if __name__ == "__main__":   v = tf.Variable(0.,name="v")   #设置滑动平均模型的系数   ema = tf.train.ExponentialMovingAverage(0.99)   #设置变量v使用滑动平均模型,tf.all_variables()设置所有变量   op = ema.apply([v])   #获取变量v的名字   print(v.name)   #v:0   #创建一个保存模型的对象   save = tf.train.Saver()   sess = tf.Session()   #初始化所有变量   init = tf.initialize_all_variables()   sess.run(init)   #给变量v重新赋值   sess.run(tf.assign(v,10))   #应用平均滑动设置   sess.run(op)   #保存模型文件   save.save(sess,"./model.ckpt")   #输出变量v之前的值和使用滑动平均模型之后的值   print(sess.run([v,ema.average(v)]))   #[10.0, 0.099999905]  

载入影子变量并映射到变量

利用了Saver载入模型的变量名映射功能,实际上对所有的变量都可以如此操作『TensorFlow』模型载入方法汇总

v = tf.Variable(1.,name="v") #定义模型对象 saver = tf.train.Saver({"v/ExponentialMovingAverage":v}) sess = tf.Session() saver.restore(sess,"./model.ckpt") print(sess.run(v)) #0.0999999  

这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。

使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

variables_to_restore函数的使用

v = tf.Variable(1.,name="v") #滑动模型的参数的大小并不会影响v的值 ema = tf.train.ExponentialMovingAverage(0.99) print(ema.variables_to_restore()) #{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} sess = tf.Session() saver = tf.train.Saver(ema.variables_to_restore()) saver.restore(sess,"./model.ckpt") print(sess.run(v)) #0.0999999 

variables_to_restore会识别网络中的变量,并自动生成影子变量名。

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

5、官方文档例子

官方文档中将每次apply更新就会自动训练一边模型,实际上可以反过来两者关系,《tf实战google》P128中有示例

 | Example usage when creating a training model: |  | ```python | # Create variables. | var0 = tf.Variable(...) | var1 = tf.Variable(...) | # ... use the variables to build a training model... | ... | # Create an op that applies the optimizer. This is what we usually | # would use as a training op. | opt_op = opt.minimize(my_loss, [var0, var1]) |  | # Create an ExponentialMovingAverage object | ema = tf.train.ExponentialMovingAverage(decay=0.9999) |  | with tf.control_dependencies([opt_op]): |   # Create the shadow variables, and add ops to maintain moving averages |   # of var0 and var1. This also creates an op that will update the moving |   # averages after each training step. This is what we will use in place |   # of the usual training op. |   training_op = ema.apply([var0, var1]) |  | ...train the model by running training_op... | ```

6、batch_normal的例子

和上面不太一样的是,batch_normal中不太容易绑定到train_op(位于函数体外面),则强行将两个variable的输出过程化为节点,绑定给参数更新步骤

def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):  with tf.variable_scope(scope):    # beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)    # gamma = tf.get_variable(name='gamma', shape=[n_out],    # initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)    batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')    ema = tf.train.ExponentialMovingAverage(decay=decay)     def mean_var_with_update():      ema_apply_op = ema.apply([batch_mean,batch_var])      with tf.control_dependencies([ema_apply_op]):        return tf.identity(batch_mean),tf.identity(batch_var)        # identity之后会把Variable转换为Tensor并入图中,        # 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制     mean,var = tf.cond(phase_train,  mean_var_with_update,  lambda: (ema.average(batch_mean),ema.average(batch_var)))     normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)  return normed 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。


  • 上一条:
    详解TensorFlow查看ckpt中变量的几种方法
    下一条:
    TensorFlow 模型载入方法汇总(小结)
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 2024.07.09日OpenAI将终止对中国等国家和地区API服务(0个评论)
    • 2024/6/9最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 国外服务器实现api.openai.com反代nginx配置(0个评论)
    • 2024/4/28最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 近期文章
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • Laravel从Accel获得5700万美元A轮融资(0个评论)
    • 在go + gin中gorm实现指定搜索/区间搜索分页列表功能接口实例(0个评论)
    • 在go语言中实现IP/CIDR的ip和netmask互转及IP段形式互转及ip是否存在IP/CIDR(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2017-07
    • 2017-08
    • 2017-09
    • 2018-01
    • 2018-07
    • 2018-08
    • 2018-09
    • 2018-12
    • 2019-01
    • 2019-02
    • 2019-03
    • 2019-04
    • 2019-05
    • 2019-06
    • 2019-07
    • 2019-08
    • 2019-09
    • 2019-10
    • 2019-11
    • 2019-12
    • 2020-01
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2020-07
    • 2020-08
    • 2020-09
    • 2020-10
    • 2020-11
    • 2021-04
    • 2021-05
    • 2021-06
    • 2021-07
    • 2021-08
    • 2021-09
    • 2021-10
    • 2021-12
    • 2022-01
    • 2022-02
    • 2022-03
    • 2022-04
    • 2022-05
    • 2022-06
    • 2022-07
    • 2022-08
    • 2022-09
    • 2022-10
    • 2022-11
    • 2022-12
    • 2023-01
    • 2023-02
    • 2023-03
    • 2023-04
    • 2023-05
    • 2023-06
    • 2023-07
    • 2023-08
    • 2023-09
    • 2023-10
    • 2023-12
    • 2024-02
    • 2024-04
    • 2024-05
    • 2024-06
    • 2025-02
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客