侯体宗的博客
  • 首页
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

python使用tensorflow保存、加载和使用模型的方法

Python  /  管理员 发布于 7年前   207

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python #-*- coding:utf-8 -*- ############################ #File Name: tut1_save.py #Author: Wang  #Mail: [email protected] #Created Time:2017-08-30 11:04:25 ############################  import tensorflow as tf  # prepare to feed input, i.e. feed_dict and placeholders w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') b1 = tf.Variable(2.0, name = 'bias1') feed_dict = {w1:[10,3], w2:[5,5]}  # define a test operation that will be restored w3 = tf.add(w1, w2) # without name, w3 will not be stored w4 = tf.multiply(w3, b1, name = "op_to_restore")  #saver = tf.train.Saver() saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) sess = tf.Session() sess.run(tf.global_variables_initializer()) print sess.run(w4, feed_dict) #saver.save(sess, 'my_test_model', global_step = 100) saver.save(sess, 'my_test_model') #saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False) 

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python #-*- coding:utf-8 -*- ############################ #File Name: tut2_import.py #Author: Wang  #Mail: [email protected] #Created Time:2017-08-30 14:16:38 ############################  import tensorflow as tf sess = tf.Session() new_saver = tf.train.import_meta_graph('my_test_model.meta') new_saver.restore(sess, tf.train.latest_checkpoint('./')) print sess.run('w1:0') 

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python #-*- coding:utf-8 -*- ############################ #File Name: tut3_reuse.py #Author: Wang #Mail: [email protected] #Created Time:2017-08-30 14:33:35 ############################  import tensorflow as tf  sess = tf.Session()  # First, load meta graph and restore weights saver = tf.train.import_meta_graph('my_test_model.meta') saver.restore(sess, tf.train.latest_checkpoint('./'))  # Second, access and create placeholders variables and create feed_dict to feed new data graph = tf.get_default_graph() w1 = graph.get_tensor_by_name('w1:0') w2 = graph.get_tensor_by_name('w2:0') feed_dict = {w1:[-1,1], w2:[4,6]}  # Access the op that want to run op_to_restore = graph.get_tensor_by_name('op_to_restore:0')  print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.] 

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf sess=tf.Session()   #First let's load meta graph and restore weights saver = tf.train.import_meta_graph('my_test_model-1000.meta') saver.restore(sess,tf.train.latest_checkpoint('./')) # Now, let's access and create placeholders variables and # create feed-dict to feed new data  graph = tf.get_default_graph() w1 = graph.get_tensor_by_name("w1:0") w2 = graph.get_tensor_by_name("w2:0") feed_dict ={w1:13.0,w2:17.0}  #Now, access the op that you want to run.  op_to_restore = graph.get_tensor_by_name("op_to_restore:0")  #Add more to the current graph add_on_op = tf.multiply(op_to_restore,2)  print sess.run(add_on_op,feed_dict) #This will print 120. 

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... ...... saver = tf.train.import_meta_graph('vgg.meta') # Access the graph graph = tf.get_default_graph() ## Prepare the feed_dict for feeding data for fine-tuning   #Access the appropriate output for fine-tuning fc7= graph.get_tensor_by_name('fc7:0')  #use this if you only want to change gradients of the last layer fc7 = tf.stop_gradient(fc7) # It's an identity function fc7_shape= fc7.get_shape().as_list()  new_outputs=2 weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) output = tf.matmul(fc7, weights) + biases pred = tf.nn.softmax(output)  # Now, you run this with fine-tuning data in sess.run() 

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。


  • 上一条:
    python中使用xlrd读excel使用xlwt写excel的实例代码
    下一条:
    python通过elixir包操作mysql数据库实例代码
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • 在python语言中Flask框架的学习及简单功能示例(0个评论)
    • 在Python语言中实现GUI全屏倒计时代码示例(0个评论)
    • Python + zipfile库实现zip文件解压自动化脚本示例(0个评论)
    • python爬虫BeautifulSoup快速抓取网站图片(1个评论)
    • vscode 配置 python3开发环境的方法(0个评论)
    • 近期文章
    • 在windows10中升级go版本至1.24后LiteIDE的Ctrl+左击无法跳转问题解决方案(0个评论)
    • 智能合约Solidity学习CryptoZombie第四课:僵尸作战系统(0个评论)
    • 智能合约Solidity学习CryptoZombie第三课:组建僵尸军队(高级Solidity理论)(0个评论)
    • 智能合约Solidity学习CryptoZombie第二课:让你的僵尸猎食(0个评论)
    • 智能合约Solidity学习CryptoZombie第一课:生成一只你的僵尸(0个评论)
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(95个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2018-04
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2022-01
    • 2023-07
    • 2023-10
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客