侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

TensorFlow实现Softmax回归模型

技术  /  管理员 发布于 7年前   140

一、概述及完整代码

对MNIST(MixedNational Institute of Standard and Technology database)这个非常简单的机器视觉数据集,Tensorflow为我们进行了方便的封装,可以直接加载MNIST数据成我们期望的格式.本程序使用Softmax Regression训练手写数字识别的分类模型.

先看完整代码:

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) print(mnist.train.images.shape, mnist.train.labels.shape) print(mnist.test.images.shape, mnist.test.labels.shape) print(mnist.validation.images.shape, mnist.validation.labels.shape)  #构建计算图 x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x, W) + b) y_ = tf.placeholder(tf.float32, [None, 10]) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)  #在会话sess中启动图 sess = tf.InteractiveSession() #创建InteractiveSession对象 tf.global_variables_initializer().run() #全局参数初始化器 for i in range(1000):  batch_xs, batch_ys = mnist.train.next_batch(100)  train_step.run({x: batch_xs, y_: batch_ys})  #测试验证阶段 #沿着第1条轴方向取y和y_的最大值的索引并判断是否相等 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) #转换bool型tensor为float32型tensor并求平均即得到正确率 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})) 

二、详细解读

首先看一下使用TensorFlow进行算法设计训练的核心步骤

1.定义算法公式,也就是神经网络forward时的计算;

2.定义loss,选定优化器,并制定优化器优化loss;

3.在训练集上迭代训练算法模型;

4.在测试集或验证集上对训练得到的模型进行准确率评测.

首先创建一个Placeholder,即输入张量数据的地方,第一个参数是数据类型dtype,第二个参数是tensor的形状shape.接下来创建SoftmaxRegression模型中的weights(W)和biases(b)的Variable对象,不同于存储数据的tensor一旦使用掉就会消失,Variable在模型训练迭代中是持久存在的,并且在每轮迭代中被更新Variable初始化可以是常量或随机值.接下来实现模型算法y = softmax(Wx + b),TensorFlow语言只需要一行代码,tf.nn包含了大量神经网络的组件,头tf.matmul是矩阵乘法函数.TensorFlow将模型中的forward和backward的内容都自动实现,只要定义好loss,训练的时候会自动求导并进行梯度下降,完成对模型参数的自动学习.定义损失函数lossfunction来描述分类精度,对于多分类问题通常使用cross-entropy交叉熵.先定义一个placeholder输入真实的label,tf.reduce_sum和tf.reduce_mean的功能分别是求和和求平均.构造完损失函数cross-entropy后,再定义一个优化算法即可开始训练.我们采用随机梯度下降SGD,定义好后TensorFlow会自动添加许多运算操作来实现反向传播和梯度下降,而给我们提供的是一个封装好的优化器,只需要每轮迭代时feed数据给它就好.设置好学习率.

构造阶段完成后, 才能启动图. 启动图的第一步是创建一个 Session 对象或InteractiveSession对象, 如果无任何创建参数, 会话构造器将启动默认图.创建InteractiveSession对象会这个Session注册为默认的Session,之后的运算也默认跑在这个Session里面,不同Session之间的数据和运算应该是相互独立的.下一步使用TensorFlow的全局参数初始化器tf.global_variables_initializer病直接执行它的run方法(这个全局参数初始化器应该是1.0.0版本中的新特性,在之前0.10.0版本测试不通过).

至此,以上定义的所有公式其实只是Computation Graph,代码执行到这时,计算还没有实际发生,只有等调用run方法并feed数据时计算才真正执行.

随后一步,就可以开始迭代地执行训练操作train_step.这里每次都从训练集中随机抽取100条样本构成一个mini-batch,并feed给placeholder.

完成迭代训练后,就可以对模型的准确率进行验证.比较y和y_在各个测试样本中最大值所在的索引,然后转换为float32型tensor后求平均即可得到正确率.多次测试后得到在测试集上的正确率为92%左右.还是比较理想的结果.

三、其他补充

1.Sesssion类和InteractiveSession类

对于product =tf.matmul(matrix1, matrix2),调用 sess 的 'run()' 方法来执行矩阵乘法 op, 传入 'product' 作为该方法的参数.上面提到, 'product' 代表了矩阵乘法 op 的输出, 传入它是向方法表明, 我们希望取回矩阵乘法 op 的输出.整个执行过程是自动化的, 会话负责传递op 所需的全部输入. op 通常是并发执行的.函数调用 'run(product)' 触发了图中三个 op (两个常量 op 和一个矩阵乘法 op)的执行.返回值 'result' 是一个 numpy的`ndarray`对象.

Session 对象在使用完后需要关闭以释放资源sess.close(). 除了显式调用 close 外, 也可以使用"with" 代码块 来自动完成关闭动作.

with tf.Session() as sess:  result = sess.run([product])  print result 

为了便于使用诸如 IPython 之类的 Python 交互环境, 可以使用InteractiveSession代替 Session 类, 使用 Tensor.eval()和 Operation.run()方法代替 Session.run(). 这样可以避免使用一个变量来持有会话.

# 进入一个交互式 TensorFlow 会话. import tensorflow as tf sess = tf.InteractiveSession() x = tf.Variable([1.0, 2.0]) a = tf.constant([3.0, 3.0]) # 使用初始化器 initializer op 的 run() 方法初始化 'x' x.initializer.run() # 增加一个减法 sub op, 从 'x' 减去 'a'. 运行减法 op, 输出结果 sub = tf.sub(x, a) print sub.eval() # ==> [-2. -1.] 

2.tf.reduce_sum

首先,tf.reduce_X一系列运算操作(operation)是实现对一个tensor各种减少维度的数学计算.

tf.reduce_sum(input_tensor, reduction_indices=None,keep_dims=False, name=None)

运算功能:沿着给定维度reduction_indices的方向降低input_tensor的维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.

演示代码:

# 'x' is [[1, 1, 1] #   [1, 1, 1]] tf.reduce_sum(x) ==> 6 tf.reduce_sum(x, 0) ==> [2, 2, 2] tf.reduce_sum(x, 1) ==> [3, 3] tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]] tf.reduce_sum(x, [0, 1]) ==> 6 

3.tf.reduce_mean

tf.reduce_mean(input_tensor, reduction_indices=None,keep_dims=False, name=None)

运算功能:将input_tensor沿着给定维度reduction_indices减少维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.

演示代码:

# 'x' is [[1., 1. ] #   [2., 2.]] tf.reduce_mean(x) ==> 1.5 tf.reduce_mean(x, 0) ==> [1.5, 1.5] tf.reduce_mean(x, 1) ==> [1., 2.] 

4.tf.argmax

tf.argmax(input, dimension, name=None)

运算功能:返回input在指定维度下的最大值的索引.返回类型为int64.

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。


  • 上一条:
    TensorFlow实现MLP多层感知机模型
    下一条:
    TensorFlow深度学习之卷积神经网络CNN
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 2024.07.09日OpenAI将终止对中国等国家和地区API服务(0个评论)
    • 2024/6/9最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 国外服务器实现api.openai.com反代nginx配置(0个评论)
    • 2024/4/28最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 近期文章
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • Laravel从Accel获得5700万美元A轮融资(0个评论)
    • 在go + gin中gorm实现指定搜索/区间搜索分页列表功能接口实例(0个评论)
    • 在go语言中实现IP/CIDR的ip和netmask互转及IP段形式互转及ip是否存在IP/CIDR(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2017-07
    • 2017-08
    • 2017-09
    • 2018-01
    • 2018-07
    • 2018-08
    • 2018-09
    • 2018-12
    • 2019-01
    • 2019-02
    • 2019-03
    • 2019-04
    • 2019-05
    • 2019-06
    • 2019-07
    • 2019-08
    • 2019-09
    • 2019-10
    • 2019-11
    • 2019-12
    • 2020-01
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2020-07
    • 2020-08
    • 2020-09
    • 2020-10
    • 2020-11
    • 2021-04
    • 2021-05
    • 2021-06
    • 2021-07
    • 2021-08
    • 2021-09
    • 2021-10
    • 2021-12
    • 2022-01
    • 2022-02
    • 2022-03
    • 2022-04
    • 2022-05
    • 2022-06
    • 2022-07
    • 2022-08
    • 2022-09
    • 2022-10
    • 2022-11
    • 2022-12
    • 2023-01
    • 2023-02
    • 2023-03
    • 2023-04
    • 2023-05
    • 2023-06
    • 2023-07
    • 2023-08
    • 2023-09
    • 2023-10
    • 2023-12
    • 2024-02
    • 2024-04
    • 2024-05
    • 2024-06
    • 2025-02
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客