侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

TensorFlow实现Batch Normalization

linux  /  管理员 发布于 7年前   200

一、BN(Batch Normalization)算法

1. 对数据进行归一化处理的重要性

神经网络学习过程的本质就是学习数据分布,在训练数据与测试数据分布不同情况下,模型的泛化能力就大大降低;另一方面,若训练过程中每批batch的数据分布也各不相同,那么网络每批迭代学习过程也会出现较大波动,使之更难趋于收敛,降低训练收敛速度。对于深层网络,网络前几层的微小变化都会被网络累积放大,则训练数据的分布变化问题会被放大,更加影响训练速度。

2. BN算法的强大之处

1)为了加速梯度下降算法的训练,我们可以采取指数衰减学习率等方法在初期快速学习,后期缓慢进入全局最优区域。使用BN算法后,就可以直接选择比较大的学习率,且设置很大的学习率衰减速度,大大提高训练速度。即使选择了较小的学习率,也会比以前不使用BN情况下的收敛速度快。总结就是BN算法具有快速收敛的特性。

2)BN具有提高网络泛化能力的特性。采用BN算法后,就可以移除针对过拟合问题而设置的dropout和L2正则化项,或者采用更小的L2正则化参数。

3)BN本身是一个归一化网络层,则局部响应归一化层(Local Response Normalization,LRN层)则可不需要了(Alexnet网络中使用到)。

3. BN算法概述

BN算法提出了变换重构,引入了可学习参数γ、β,这就是算法的关键之处:

引入这两个参数后,我们的网络便可以学习恢复出原是网络所要学习的特征分布,BN层的钱箱传到过程如下:

其中m为batchsize。BatchNormalization中所有的操作都是平滑可导,这使得back propagation可以有效运行并学到相应的参数γ,β。需要注意的一点是Batch Normalization在training和testing时行为有所差别。Training时μβ和σβ由当前batch计算得出;在Testing时μβ和σβ应使用Training时保存的均值或类似的经过处理的值,而不是由当前batch计算。

二、TensorFlow相关函数

1.tf.nn.moments(x, axes, shift=None, name=None, keep_dims=False)

x是输入张量,axes是在哪个维度上求解, 即想要 normalize的维度, [0] 代表 batch 维度,如果是图像数据,可以传入 [0, 1, 2],相当于求[batch, height, width] 的均值/方差,注意不要加入channel 维度。该函数返回两个张量,均值mean和方差variance。

2.tf.identity(input, name=None)

返回与输入张量input形状和内容一致的张量。

3.tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon,name=None)

计算公式为scale(x - mean)/ variance + offset。

这些参数中,tf.nn.moments可得到均值mean和方差variance,offset和scale是可训练的,offset一般初始化为0,scale初始化为1,offset和scale的shape与mean相同,variance_epsilon参数设为一个很小的值如0.001。

三、TensorFlow代码实现

1. 完整代码

import tensorflow as tf import numpy as np import matplotlib.pyplot as plt  ACTIVITION = tf.nn.relu N_LAYERS = 7 # 总共7层隐藏层 N_HIDDEN_UNITS = 30 # 每层包含30个神经元  def fix_seed(seed=1): # 设置随机数种子   np.random.seed(seed)   tf.set_random_seed(seed)  def plot_his(inputs, inputs_norm): # 绘制直方图函数   for j, all_inputs in enumerate([inputs, inputs_norm]):     for i, input in enumerate(all_inputs):       plt.subplot(2, len(all_inputs), j*len(all_inputs)+(i+1))       plt.cla()       if i == 0:         the_range = (-7, 10)       else:         the_range = (-1, 1)       plt.hist(input.ravel(), bins=15, range=the_range, color='#FF5733')       plt.yticks(())       if j == 1:         plt.xticks(the_range)       else:         plt.xticks(())       ax = plt.gca()       ax.spines['right'].set_color('none')       ax.spines['top'].set_color('none')     plt.title("%s normalizing" % ("Without" if j == 0 else "With"))   plt.draw()   plt.pause(0.01)  def built_net(xs, ys, norm): # 搭建网络函数   # 添加层   def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):     Weights = tf.Variable(tf.random_normal([in_size, out_size], mean=0.0, stddev=1.0))     biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)     Wx_plus_b = tf.matmul(inputs, Weights) + biases      if norm: # 判断是否是Batch Normalization层       # 计算均值和方差,axes参数0表示batch维度       fc_mean, fc_var = tf.nn.moments(Wx_plus_b, axes=[0])       scale = tf.Variable(tf.ones([out_size]))       shift = tf.Variable(tf.zeros([out_size]))       epsilon = 0.001        # 定义滑动平均模型对象       ema = tf.train.ExponentialMovingAverage(decay=0.5)        def mean_var_with_update():         ema_apply_op = ema.apply([fc_mean, fc_var])         with tf.control_dependencies([ema_apply_op]):           return tf.identity(fc_mean), tf.identity(fc_var)        mean, var = mean_var_with_update()        Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, mean, var,  shift, scale, epsilon)      if activation_function is None:       outputs = Wx_plus_b     else:       outputs = activation_function(Wx_plus_b)     return outputs    fix_seed(1)    if norm: # 为第一层进行BN     fc_mean, fc_var = tf.nn.moments(xs, axes=[0])     scale = tf.Variable(tf.ones([1]))     shift = tf.Variable(tf.zeros([1]))     epsilon = 0.001      ema = tf.train.ExponentialMovingAverage(decay=0.5)      def mean_var_with_update():       ema_apply_op = ema.apply([fc_mean, fc_var])       with tf.control_dependencies([ema_apply_op]):         return tf.identity(fc_mean), tf.identity(fc_var)      mean, var = mean_var_with_update()     xs = tf.nn.batch_normalization(xs, mean, var, shift, scale, epsilon)    layers_inputs = [xs] # 记录每一层的输入    for l_n in range(N_LAYERS): # 依次添加7层     layer_input = layers_inputs[l_n]     in_size = layers_inputs[l_n].get_shape()[1].value      output = add_layer(layer_input, in_size, N_HIDDEN_UNITS, ACTIVITION, norm)     layers_inputs.append(output)    prediction = add_layer(layers_inputs[-1], 30, 1, activation_function=None)   cost = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),         reduction_indices=[1]))    train_op = tf.train.GradientDescentOptimizer(0.001).minimize(cost)   return [train_op, cost, layers_inputs]  fix_seed(1) x_data = np.linspace(-7, 10, 2500)[:, np.newaxis] np.random.shuffle(x_data) noise =np.random.normal(0, 8, x_data.shape) y_data = np.square(x_data) - 5 + noise  plt.scatter(x_data, y_data) plt.show()  xs = tf.placeholder(tf.float32, [None, 1]) ys = tf.placeholder(tf.float32, [None, 1])  train_op, cost, layers_inputs = built_net(xs, ys, norm=False) train_op_norm, cost_norm, layers_inputs_norm = built_net(xs, ys, norm=True)  with tf.Session() as sess:   sess.run(tf.global_variables_initializer())    cost_his = []   cost_his_norm = []   record_step = 5    plt.ion()   plt.figure(figsize=(7, 3))   for i in range(250):     if i % 50 == 0:       all_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm],   feed_dict={xs: x_data, ys: y_data})       plot_his(all_inputs, all_inputs_norm)      sess.run([train_op, train_op_norm],          feed_dict={xs: x_data[i*10:i*10+10], ys: y_data[i*10:i*10+10]})      if i % record_step == 0:       cost_his.append(sess.run(cost, feed_dict={xs: x_data, ys: y_data}))       cost_his_norm.append(sess.run(cost_norm,          feed_dict={xs: x_data, ys: y_data}))    plt.ioff()   plt.figure()   plt.plot(np.arange(len(cost_his))*record_step,        np.array(cost_his), label='Without BN')   # no norm   plt.plot(np.arange(len(cost_his))*record_step,        np.array(cost_his_norm), label='With BN')  # norm   plt.legend()   plt.show() 

2. 实验结果

输入数据分布:

批标准化BN效果对比:

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。


  • 上一条:
    使用apidoc管理RESTful风格Flask项目接口文档方法
    下一条:
    关于Tensorflow中的tf.train.batch函数的使用
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • 在Linux系统中使用Iptables实现流量转发功能流程步骤(0个评论)
    • vim学习笔记-入门级需要了解的一些快捷键(0个评论)
    • 在centos7系统中实现分区并格式化挂载一块硬盘到/data目录流程步骤(0个评论)
    • 在Linux系统种查看某一个进程所占用的内存命令(0个评论)
    • Linux中grep命令中的10种高级用法浅析(0个评论)
    • 近期文章
    • 在windows10中升级go版本至1.24后LiteIDE的Ctrl+左击无法跳转问题解决方案(0个评论)
    • 智能合约Solidity学习CryptoZombie第四课:僵尸作战系统(0个评论)
    • 智能合约Solidity学习CryptoZombie第三课:组建僵尸军队(高级Solidity理论)(0个评论)
    • 智能合约Solidity学习CryptoZombie第二课:让你的僵尸猎食(0个评论)
    • 智能合约Solidity学习CryptoZombie第一课:生成一只你的僵尸(0个评论)
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-11
    • 2017-07
    • 2017-10
    • 2017-11
    • 2018-01
    • 2018-02
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2021-02
    • 2021-03
    • 2021-04
    • 2021-06
    • 2021-07
    • 2021-08
    • 2021-09
    • 2021-10
    • 2021-11
    • 2021-12
    • 2022-01
    • 2022-03
    • 2022-04
    • 2022-08
    • 2022-11
    • 2022-12
    • 2023-01
    • 2023-02
    • 2023-03
    • 2023-06
    • 2023-07
    • 2023-10
    • 2023-12
    • 2024-01
    • 2024-04
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客