Getting started with TensorFlow: using TFRecord and tf.data.TFRecordDataset


1. Creating a TFRecord file

A TFRecord file supports three data types: string, int64 and float32. Each value is written into a tf.train.Feature as a list, via tf.train.BytesList, tf.train.Int64List and tf.train.FloatList respectively, as shown below:

tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))  # feature is usually a multi-dimensional array; call tostring() and wrap the result in a list
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))   # tostring() loses the shape information, so write the shape as well
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))

With the above, gather the data to be written into a dict, build a tf.train.Features from it, and then build a tf.train.Example, as follows:

def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

Serialize the created tf.train.Example and it can then be written to a TFRecord file with tf.python_io.TFRecordWriter, as follows:

tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord')  # create the TFRecord writer; the output file is xxx.tfrecord
exmp = get_tfrecords_example(feats[inx], labels[inx])       # pack one sample into an Example
exmp_serial = exmp.SerializeToString()                      # serialize the Example
tfrecord_wrt.write(exmp_serial)                             # write it into the TFRecord file
tfrecord_wrt.close()                                        # close the writer when done

Putting the code together:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

mnist = read_data_sets("MNIST_data/", one_hot=True)

# pack one sample into an Example
def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

# write all samples into one TFRecord file
def make_tfrecord(data, outf_nm='mnist-train'):
    feats, labels = data
    outf_nm += '.tfrecord'
    tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
    ndatas = len(labels)
    for inx in range(ndatas):
        exmp = get_tfrecords_example(feats[inx], labels[inx])
        exmp_serial = exmp.SerializeToString()
        tfrecord_wrt.write(exmp_serial)
    tfrecord_wrt.close()

import random
nDatas = len(mnist.train.labels)
inx_lst = list(range(nDatas))
random.shuffle(inx_lst)
ntrains = int(0.85 * nDatas)

# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]],
        [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')

# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]],
        [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')

# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')
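After the three files are written, a quick sanity check is to iterate over one of them and count its records. This is a minimal sketch, assuming the files above were produced as expected; tf.python_io.tf_record_iterator yields the serialized records:

import tensorflow as tf

cnt = 0
for serialized in tf.python_io.tf_record_iterator('mnist-train.tfrecord'):
    if cnt == 0:
        exmp = tf.train.Example()
        exmp.ParseFromString(serialized)       # recover the Example from its serialized form
        print(exmp.features.feature['shape'])  # the int64_list holding the stored shape
    cnt += 1
print('number of records: %d' % cnt)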

2. Using the TFRecord file: tf.data.TFRecordDataset

Create a TFRecordDataset from the TFRecord file:

dataset = tf.data.TFRecordDataset('xxx.tfrecord')
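tf.data.TFRecordDataset also accepts a list of filenames, so several TFRecord files can be read as a single dataset (the file names below are only placeholders):

dataset = tf.data.TFRecordDataset(['xxx-part0.tfrecord', 'xxx-part1.tfrecord'])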

Each record in the TFRecord file, i.e. each serialized tf.train.Example, is parsed with tf.parse_single_example:

feats = tf.parse_single_example(serial_exmp, features=data_dict)

Here data_dict is a dict whose keys are the same keys used when writing the TFRecord file, and whose values are tf.FixedLenFeature([], tf.string), tf.FixedLenFeature([], tf.int64) or tf.FixedLenFeature([], tf.float32), matching the corresponding data types. Putting it together:

def parse_exmp(serial_exmp):
    # label uses [10] because each label is a list of 10 elements;
    # [x] for shape stands for the length of the shape vector
    feats = tf.parse_single_example(serial_exmp, features={
        'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32),
        'shape': tf.FixedLenFeature([x], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape
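Since tf.decode_raw returns a flat 1-D tensor, the stored shape can be used to restore the original layout. A small sketch of the idea (the code in this article keeps the image flat and discards the shape):

image = tf.decode_raw(feats['feature'], tf.float32)
shape = tf.cast(feats['shape'], tf.int32)
image = tf.reshape(image, shape)  # undo the flattening caused by tostring()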

To parse every record in the TFRecord file, use the dataset's map method, as follows:

dataset = dataset.map(parse_exmp)

map accepts any function for processing the data in the dataset. In addition, the repeat, shuffle and batch methods repeat, shuffle and batch the dataset; repeat duplicates the dataset so it can be iterated over multiple epochs:

dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
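If input parsing becomes a bottleneck, the map call can be parallelized and a prefetch step can overlap preprocessing with training. A hedged sketch (num_parallel_calls and prefetch are available in recent TensorFlow 1.x releases; the buffer sizes here are arbitrary):

dataset = tf.data.TFRecordDataset('xxx.tfrecord')
dataset = dataset.map(parse_exmp, num_parallel_calls=4)  # parse several records in parallel
dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
dataset = dataset.prefetch(1)  # keep one batch ready while the current one is consumed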

With parsing in place, the data can be consumed by creating an iterator, as follows:

iterator = dataset.make_one_shot_iterator()
batch_image, batch_label, batch_shape = iterator.get_next()
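A one-shot iterator raises tf.errors.OutOfRangeError once the dataset is exhausted, so a consumption loop typically looks like this sketch:

with tf.Session() as sess:
    try:
        while True:
            img, lbl, shp = sess.run([batch_image, batch_label, batch_shape])
            # use the batch here
    except tf.errors.OutOfRangeError:
        print('dataset exhausted')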

To feed data from different datasets into the same model, first create an iterator handle, i.e. an iterator placeholder, as follows:

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
    dataset_train.output_types, dataset_train.output_shapes)
image, label, shape = iterator.get_next()

Then create a handle for each dataset and pass it to the placeholder through feed_dict, as follows:

with tf.Session() as sess:
    handle_train, handle_val, handle_test = sess.run(
        [it.string_handle() for it in [iter_train, iter_val, iter_test]])
    sess.run([loss, train_op], feed_dict={handle: handle_train})

Putting it together:

import tensorflow as tf

train_f, val_f, test_f = ['mnist-%s.tfrecord' % i for i in ['train', 'val', 'test']]

def parse_exmp(serial_exmp):
    feats = tf.parse_single_example(serial_exmp, features={
        'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32),
        'shape': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape

def get_dataset(fname):
    dataset = tf.data.TFRecordDataset(fname)
    return dataset.map(parse_exmp)  # use the padded_batch method if padding is needed

epochs = 16
batch_size = 50  # when batch_size does not divide nDatas evenly (e.g. 56),
                 # one batch will contain fewer samples than batch_size

# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)
# make sure repeat comes before batch; this differs from
# dataset.shuffle(1000).batch(batch_size).repeat(epochs), where every epoch would
# end with a smaller batch if batch_size does not divide nDatas evenly
nBatchs = nDatasTrain * epochs // batch_size

# validation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs // 100 * 2)

# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)

# make dataset iterators
iter_train = dataset_train.make_one_shot_iterator()
iter_val = dataset_val.make_one_shot_iterator()
iter_test = dataset_test.make_one_shot_iterator()

# make a feedable iterator
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
    dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()
train_op, loss, eval_op = model(x, y_)  # model() builds the network and keep_prob; see the full example in section 3
init = tf.initialize_all_variables()

# summary
logdir = './logs/m4d2a'
def summary_op(datapart='train'):
    tf.summary.scalar(datapart + '-loss', loss)
    tf.summary.scalar(datapart + '-eval', eval_op)
    return tf.summary.merge_all()
summary_op_train = summary_op()
summary_op_test = summary_op('val')

with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [it.string_handle() for it in [iter_train, iter_val, iter_test]])
    _, cur_loss, cur_train_eval, summary = sess.run(
        [train_op, loss, eval_op, summary_op_train],
        feed_dict={handle: handle_train, keep_prob: 0.5})
    cur_val_loss, cur_val_eval, summary = sess.run(
        [loss, eval_op, summary_op_test],
        feed_dict={handle: handle_val, keep_prob: 1.0})

3. MNIST experiment

import tensorflow as tf

train_f, val_f, test_f = ['mnist-%s.tfrecord' % i for i in ['train', 'val', 'test']]

def parse_exmp(serial_exmp):
    feats = tf.parse_single_example(serial_exmp, features={
        'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32),
        'shape': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape

def get_dataset(fname):
    dataset = tf.data.TFRecordDataset(fname)
    return dataset.map(parse_exmp)  # use the padded_batch method if padding is needed

epochs = 16
batch_size = 50  # when batch_size does not divide nDatas evenly (e.g. 56),
                 # one batch will contain fewer samples than batch_size

# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)
# make sure repeat comes before batch; this differs from
# dataset.shuffle(1000).batch(batch_size).repeat(epochs), where every epoch would
# end with a smaller batch if batch_size does not divide nDatas evenly
nBatchs = nDatasTrain * epochs // batch_size

# validation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs // 100 * 2)

# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)

# make dataset iterators
iter_train = dataset_train.make_one_shot_iterator()
iter_val = dataset_val.make_one_shot_iterator()
iter_test = dataset_test.make_one_shot_iterator()

# make feedable iterator, i.e. iterator placeholder
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
    dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()

# cnn
x_image = tf.reshape(x, [-1, 28, 28, 1])
w_init = tf.truncated_normal_initializer(stddev=0.1, seed=9)
b_init = tf.constant_initializer(0.1)
cnn1 = tf.layers.conv2d(x_image, 32, (5, 5), padding='same', activation=tf.nn.relu,
    kernel_initializer=w_init, bias_initializer=b_init)
mxpl1 = tf.layers.max_pooling2d(cnn1, 2, strides=2, padding='same')
cnn2 = tf.layers.conv2d(mxpl1, 64, (5, 5), padding='same', activation=tf.nn.relu,
    kernel_initializer=w_init, bias_initializer=b_init)
mxpl2 = tf.layers.max_pooling2d(cnn2, 2, strides=2, padding='same')
mxpl2_flat = tf.reshape(mxpl2, [-1, 7*7*64])
fc1 = tf.layers.dense(mxpl2_flat, 1024, activation=tf.nn.relu,
    kernel_initializer=w_init, bias_initializer=b_init)
keep_prob = tf.placeholder('float')
fc1_drop = tf.nn.dropout(fc1, keep_prob)
logits = tf.layers.dense(fc1_drop, 10, kernel_initializer=w_init, bias_initializer=b_init)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
optmz = tf.train.AdamOptimizer(1e-4)
train_op = optmz.minimize(loss)

def get_eval_op(logits, labels):
    corr_prd = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    return tf.reduce_mean(tf.cast(corr_prd, 'float'))
eval_op = get_eval_op(logits, y_)

init = tf.initialize_all_variables()

# summary
logdir = './logs/m4d2a'
def summary_op(datapart='train'):
    tf.summary.scalar(datapart + '-loss', loss)
    tf.summary.scalar(datapart + '-eval', eval_op)
    return tf.summary.merge_all()
summary_op_train = summary_op()
summary_op_val = summary_op('val')

# whether to restore from a checkpoint or not
ckpts_dir = 'ckpts/'
ckpt_nm = 'cnn-ckpt'
saver = tf.train.Saver(max_to_keep=50)  # defaults to saving all variables; use a dict {'x': x, ...} to save specific ones
restore_step = ''
start_step = 0
train_steps = nBatchs
best_loss = 1e6
best_step = 0

# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# config = tf.ConfigProto()
# config.gpu_options.per_process_gpu_memory_fraction = 0.9
# config.gpu_options.allow_growth = True  # allocate GPU memory when needed
# with tf.Session(config=config) as sess:
with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [it.string_handle() for it in [iter_train, iter_val, iter_test]])
    if restore_step:
        ckpt = tf.train.get_checkpoint_state(ckpts_dir)
        if ckpt and ckpt.model_checkpoint_path:  # ckpt.model_checkpoint_path is the latest ckpt
            if restore_step == 'latest':
                ckpt_f = tf.train.latest_checkpoint(ckpts_dir)
                start_step = int(ckpt_f.split('-')[-1]) + 1
            else:
                ckpt_f = ckpts_dir + ckpt_nm + '-' + restore_step
            print('loading wgt file: ' + ckpt_f)
            saver.restore(sess, ckpt_f)
    summary_wrt = tf.summary.FileWriter(logdir, sess.graph)
    if restore_step in ['', 'latest']:
        for i in range(start_step, train_steps):
            _, cur_loss, cur_train_eval, summary = sess.run(
                [train_op, loss, eval_op, summary_op_train],
                feed_dict={handle: handle_train, keep_prob: 0.5})
            # log to stdout and evaluate the validation set
            if i % 100 == 0 or i == train_steps - 1:
                saver.save(sess, ckpts_dir + ckpt_nm, global_step=i)  # save variables
                summary_wrt.add_summary(summary, global_step=i)
                cur_val_loss, cur_val_eval, summary = sess.run(
                    [loss, eval_op, summary_op_val],
                    feed_dict={handle: handle_val, keep_prob: 1.0})
                if cur_val_loss < best_loss:
                    best_loss = cur_val_loss
                    best_step = i
                summary_wrt.add_summary(summary, global_step=i)
                print('step %5d: loss %.5f, acc %.5f --- loss val %0.5f, acc val %.5f' % (
                    i, cur_loss, cur_train_eval, cur_val_loss, cur_val_eval))
        with open(ckpts_dir + 'best.step', 'w') as f:
            f.write('best step is %d\n' % best_step)
        print('best step is %d' % best_step)
    # evaluate the test set
    test_loss, test_eval = sess.run([loss, eval_op],
        feed_dict={handle: handle_test, keep_prob: 1.0})
    print('eval test: loss %.5f, acc %.5f' % (test_loss, test_eval))

Experimental results:

That is all there is to this introduction to TFRecord and tf.data.TFRecordDataset; I hope it serves as a useful reference.

