侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

TensorFlow搭建神经网络最佳实践

技术  /  管理员 发布于 7年前   156

一、TensorFLow完整样例

在MNIST数据集上,搭建一个简单神经网络结构,一个包含ReLU单元的非线性化处理的两层神经网络。在训练神经网络的时候,使用带指数衰减的学习率设置、使用正则化来避免过拟合、使用滑动平均模型来使得最终的模型更加健壮。

程序将计算神经网络前向传播的部分单独定义一个函数inference,训练部分定义一个train函数,再定义一个主函数main。

完整程序:

#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Thu May 25 08:56:30 2017  @author: marsjhao """  import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data  INPUT_NODE = 784 # 输入节点数 OUTPUT_NODE = 10 # 输出节点数 LAYER1_NODE = 500 # 隐含层节点数 BATCH_SIZE = 100 LEARNING_RETE_BASE = 0.8 # 基学习率 LEARNING_RETE_DECAY = 0.99 # 学习率的衰减率 REGULARIZATION_RATE = 0.0001 # 正则化项的权重系数 TRAINING_STEPS = 10000 # 迭代训练次数 MOVING_AVERAGE_DECAY = 0.99 # 滑动平均的衰减系数  # 传入神经网络的权重和偏置,计算神经网络前向传播的结果 def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):   # 判断是否传入ExponentialMovingAverage类对象   if avg_class == None:     layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)     return tf.matmul(layer1, weights2) + biases2   else:     layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1))        + avg_class.average(biases1))     return tf.matmul(layer1, avg_class.average(weights2))\  + avg_class.average(biases2)  # 神经网络模型的训练过程 def train(mnist):   x = tf.placeholder(tf.float32, [None,INPUT_NODE], name='x-input')   y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')    # 定义神经网络结构的参数   weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))   biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))   weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))   biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))    # 计算非滑动平均模型下的参数的前向传播的结果   y = inference(x, None, weights1, biases1, weights2, biases2)      global_step = tf.Variable(0, trainable=False) # 定义存储当前迭代训练轮数的变量    # 定义ExponentialMovingAverage类对象   variable_averages = tf.train.ExponentialMovingAverage( MOVING_AVERAGE_DECAY, global_step) # 传入当前迭代轮数参数   # 定义对所有可训练变量trainable_variables进行更新滑动平均值的操作op   variables_averages_op = variable_averages.apply(tf.trainable_variables())    # 计算滑动模型下的参数的前向传播的结果   average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)    # 定义交叉熵损失值   cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(           logits=y, labels=tf.argmax(y_, 1))   cross_entropy_mean = tf.reduce_mean(cross_entropy)   # 定义L2正则化器并对weights1和weights2正则化   regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)   regularization = regularizer(weights1) + regularizer(weights2)   loss = cross_entropy_mean + regularization # 总损失值    # 定义指数衰减学习率   learning_rate = tf.train.exponential_decay(LEARNING_RETE_BASE, global_step,           mnist.train.num_examples / BATCH_SIZE, LEARNING_RETE_DECAY)   # 定义梯度下降操作op,global_step参数可实现自加1运算   train_step = tf.train.GradientDescentOptimizer(learning_rate)\  .minimize(loss, global_step=global_step)   # 组合两个操作op   train_op = tf.group(train_step, variables_averages_op)   '''''   # 与tf.group()等价的语句   with tf.control_dependencies([train_step, variables_averages_op]):     train_op = tf.no_op(name='train')   '''   # 定义准确率   # 在最终预测的时候,神经网络的输出采用的是经过滑动平均的前向传播计算结果   correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))    # 初始化回话sess并开始迭代训练   with tf.Session() as sess:     sess.run(tf.global_variables_initializer())     # 验证集待喂入数据     validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}     # 测试集待喂入数据     test_feed = {x: mnist.test.images, y_: mnist.test.labels}     for i in range(TRAINING_STEPS):       if i % 1000 == 0:         validate_acc = sess.run(accuracy, feed_dict=validate_feed)         print('After %d training steps, validation accuracy'' using average model is %f' % (i, validate_acc))       xs, ys = mnist.train.next_batch(BATCH_SIZE)       sess.run(train_op, feed_dict={x: xs, y_:ys})      test_acc = sess.run(accuracy, feed_dict=test_feed)     print('After %d training steps, test accuracy'        ' using average model is %f' % (TRAINING_STEPS, test_acc))  # 主函数 def main(argv=None):   mnist = input_data.read_data_sets("MNIST_data", one_hot=True)   train(mnist)  # 当前的python文件是shell文件执行的入口文件,而非当做import的python module。 if __name__ == '__main__': # 在模块内部执行   tf.app.run() # 调用main函数并传入所需的参数list 

二、分析与改进设计

1. 程序分析改进

第一,计算前向传播的函数inference中需要将所有的变量以参数的形式传入函数,当神经网络结构变得更加复杂、参数更多的时候,程序的可读性将变得非常差。

第二,在程序退出时,训练好的模型就无法再利用,且大型神经网络的训练时间都比较长,在训练过程中需要每隔一段时间保存一次模型训练的中间结果,这样如果在训练过程中程序死机,死机前的最新的模型参数仍能保留,杜绝了时间和资源的浪费。

第三,将训练和测试分成两个独立的程序,将训练和测试都会用到的前向传播的过程抽象成单独的库函数。这样就保证了在训练和预测两个过程中所调用的前向传播计算程序是一致的。

2. 改进后程序设计

mnist_inference.py

该文件中定义了神经网络的前向传播过程,其中的多次用到的weights定义过程又单独定义成函数。

通过tf.get_variable函数来获取变量,在神经网络训练时创建这些变量,在测试时会通过保存的模型加载这些变量的取值,而且可以在变量加载时将滑动平均值重命名。所以可以直接通过同样的名字在训练时使用变量自身,在测试时使用变量的滑动平均值。

mnist_train.py

该程序给出了神经网络的完整训练过程。

mnist_eval.py

在滑动平均模型上做测试。

通过tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)获取最新模型的文件名,实际是获取checkpoint文件的所有内容。

三、TensorFlow最佳实践样例

mnist_inference.py

import tensorflow as tf  INPUT_NODE = 784 OUTPUT_NODE = 10 LAYER1_NODE = 500  def get_weight_variable(shape, regularizer):   weights = tf.get_variable("weights", shape,          initializer=tf.truncated_normal_initializer(stddev=0.1))   if regularizer != None:     # 将权重参数的正则化项加入至损失集合     tf.add_to_collection('losses', regularizer(weights))   return weights  def inference(input_tensor, regularizer):   with tf.variable_scope('layer1'):     weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)     biases = tf.get_variable("biases", [LAYER1_NODE],      initializer=tf.constant_initializer(0.0))     layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)    with tf.variable_scope('layer2'):     weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)     biases = tf.get_variable("biases", [OUTPUT_NODE],      initializer=tf.constant_initializer(0.0))     layer2 = tf.matmul(layer1, weights) + biases    return layer2 

mnist_train.py

import os import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import mnist_inference  BATCH_SIZE = 100 LEARNING_RATE_BASE = 0.8 LEARNING_RATE_DECAY = 0.99 REGULARIZATION_RATE = 0.0001 TRAINING_STEPS = 10000 MOVING_AVERAGE_DECAY = 0.99  MODEL_SAVE_PATH = "Model_Folder/" MODEL_NAME = "model.ckpt"  def train(mnist):   # 定义输入placeholder   x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')   y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')   # 定义正则化器及计算前向过程输出   regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)   y = mnist_inference.inference(x, regularizer)   # 定义当前训练轮数及滑动平均模型   global_step = tf.Variable(0, trainable=False)   variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,      global_step)   variables_averages_op = variable_averages.apply(tf.trainable_variables())   # 定义损失函数   cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,   labels=tf.argmax(y_, 1))   cross_entropy_mean = tf.reduce_mean(cross_entropy)   loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))   # 定义指数衰减学习率   learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step,           mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY)   # 定义训练操作,包括模型训练及滑动模型操作   train_step = tf.train.GradientDescentOptimizer(learning_rate)\           .minimize(loss, global_step=global_step)   train_op = tf.group(train_step, variables_averages_op)   # 定义Saver类对象,保存模型,TensorFlow持久化类   saver = tf.train.Saver()    # 定义会话,启动训练过程   with tf.Session() as sess:     tf.global_variables_initializer().run()      for i in range(TRAINING_STEPS):       xs, ys = mnist.train.next_batch(BATCH_SIZE)       _, loss_value, step = sess.run([train_op, loss, global_step],           feed_dict={x: xs, y_: ys})       if i % 1000 == 0:         print("After %d training step(s), loss on training batch is %g."\ % (step, loss_value))         # save方法的global_step参数可以让每个被保存的模型的文件名末尾加上当前训练轮数         saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),   global_step=global_step)  def main(argv=None):   mnist = input_data.read_data_sets("MNIST_data", one_hot=True)   train(mnist)  if __name__ == '__main__':   tf.app.run() 

mnist_eval.py

import time import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train  EVAL_INTERVAL_SECS = 10  def evaluate(mnist):   with tf.Graph().as_default() as g:     # 定义输入placeholder     x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE],   name='x-input')     y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE],   name='y-input')     # 定义feed字典     validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}     # 测试时不加参数正则化损失     y = mnist_inference.inference(x, None)     # 计算正确率     correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))     # 加载滑动平均模型下的参数值     variable_averages = tf.train.ExponentialMovingAverage(        mnist_train.MOVING_AVERAGE_DECAY)     saver = tf.train.Saver(variable_averages.variables_to_restore())      # 每隔EVAL_INTERVAL_SECS秒启动一次会话     while True:       with tf.Session() as sess:         ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)         if ckpt and ckpt.model_checkpoint_path:           saver.restore(sess, ckpt.model_checkpoint_path)           # 取checkpoint文件中的当前迭代轮数global_step           global_step = ckpt.model_checkpoint_path\        .split('/')[-1].split('-')[-1]           accuracy_score = sess.run(accuracy, feed_dict=validate_feed)           print("After %s training step(s), validation accuracy = %g"\  % (global_step, accuracy_score))          else:           print('No checkpoint file found')           return       time.sleep(EVAL_INTERVAL_SECS)  def main(argv=None):   mnist = input_data.read_data_sets("MNIST_data", one_hot=True)   evaluate(mnist)  if __name__ == '__main__':   tf.app.run() 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。


  • 上一条:
    TensorFlow实现卷积神经网络CNN
    下一条:
    特征脸(Eigenface)理论基础之PCA主成分分析法
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 2024.07.09日OpenAI将终止对中国等国家和地区API服务(0个评论)
    • 2024/6/9最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 国外服务器实现api.openai.com反代nginx配置(0个评论)
    • 2024/4/28最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 近期文章
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • Laravel从Accel获得5700万美元A轮融资(0个评论)
    • 在go + gin中gorm实现指定搜索/区间搜索分页列表功能接口实例(0个评论)
    • 在go语言中实现IP/CIDR的ip和netmask互转及IP段形式互转及ip是否存在IP/CIDR(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2017-07
    • 2017-08
    • 2017-09
    • 2018-01
    • 2018-07
    • 2018-08
    • 2018-09
    • 2018-12
    • 2019-01
    • 2019-02
    • 2019-03
    • 2019-04
    • 2019-05
    • 2019-06
    • 2019-07
    • 2019-08
    • 2019-09
    • 2019-10
    • 2019-11
    • 2019-12
    • 2020-01
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2020-07
    • 2020-08
    • 2020-09
    • 2020-10
    • 2020-11
    • 2021-04
    • 2021-05
    • 2021-06
    • 2021-07
    • 2021-08
    • 2021-09
    • 2021-10
    • 2021-12
    • 2022-01
    • 2022-02
    • 2022-03
    • 2022-04
    • 2022-05
    • 2022-06
    • 2022-07
    • 2022-08
    • 2022-09
    • 2022-10
    • 2022-11
    • 2022-12
    • 2023-01
    • 2023-02
    • 2023-03
    • 2023-04
    • 2023-05
    • 2023-06
    • 2023-07
    • 2023-08
    • 2023-09
    • 2023-10
    • 2023-12
    • 2024-02
    • 2024-04
    • 2024-05
    • 2024-06
    • 2025-02
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客