侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

tensorflow入门:TFRecordDataset变长数据的batch读取详解

linux  /  管理员 发布于 7年前   240

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch( batch_size, padded_shapes, padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tffrom tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets mnist = read_data_sets("MNIST_data/", one_hot=True)  def get_tfrecords_example(feature, label): tfrecords_features = {} feat_shape = feature.shape tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature)) tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape))) tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label)) return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))  def make_tfrecord(data, outf_nm='mnist-train'): feats, labels = data outf_nm += '.tfrecord' tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm) ndatas = len(labels) print(feats[0].dtype, feats[0].shape, ndatas) assert len(labels[0]) > 1 for inx in range(ndatas): ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx]) exmp_serial = exmp.SerializeToString() tfrecord_wrt.write(exmp_serial) tfrecord_wrt.close() import randomnDatas = len(mnist.train.labels)inx_lst = range(nDatas)random.shuffle(inx_lst)random.shuffle(inx_lst)ntrains = int(0.85*nDatas) # make training setdata = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \ [mnist.train.labels[i] for i in inx_lst[:ntrains]])make_tfrecord(data, outf_nm='mnist-train') # make validation setdata = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \ [mnist.train.labels[i] for i in inx_lst[ntrains:]])make_tfrecord(data, outf_nm='mnist-val') # make test setdata = (mnist.test.images, mnist.test.labels)make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']] def parse_exmp(serial_exmp): feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\ 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)}) image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding shape = tf.cast(feats['shape'], tf.int32) return image, label, shape def get_dataset(fname): dataset = tf.data.TFRecordDataset(fname) return dataset.map(parse_exmp) # use padded_batch method if padding needed epochs = 16batch_size = 50 padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字# training datasetdataset_train = get_dataset(train_f)dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


  • 上一条:
    pytorch方法测试详解――归一化(BatchNorm2d)
    下一条:
    在tensorflow中实现去除不足一个batch的数据
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • 在Linux系统中使用Iptables实现流量转发功能流程步骤(0个评论)
    • vim学习笔记-入门级需要了解的一些快捷键(0个评论)
    • 在centos7系统中实现分区并格式化挂载一块硬盘到/data目录流程步骤(0个评论)
    • 在Linux系统种查看某一个进程所占用的内存命令(0个评论)
    • Linux中grep命令中的10种高级用法浅析(0个评论)
    • 近期文章
    • 在windows10中升级go版本至1.24后LiteIDE的Ctrl+左击无法跳转问题解决方案(0个评论)
    • 智能合约Solidity学习CryptoZombie第四课:僵尸作战系统(0个评论)
    • 智能合约Solidity学习CryptoZombie第三课:组建僵尸军队(高级Solidity理论)(0个评论)
    • 智能合约Solidity学习CryptoZombie第二课:让你的僵尸猎食(0个评论)
    • 智能合约Solidity学习CryptoZombie第一课:生成一只你的僵尸(0个评论)
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-11
    • 2017-07
    • 2017-10
    • 2017-11
    • 2018-01
    • 2018-02
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2021-02
    • 2021-03
    • 2021-04
    • 2021-06
    • 2021-07
    • 2021-08
    • 2021-09
    • 2021-10
    • 2021-11
    • 2021-12
    • 2022-01
    • 2022-03
    • 2022-04
    • 2022-08
    • 2022-11
    • 2022-12
    • 2023-01
    • 2023-02
    • 2023-03
    • 2023-06
    • 2023-07
    • 2023-10
    • 2023-12
    • 2024-01
    • 2024-04
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客