侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

pytorch实现建立自己的数据集(以mnist为例)

Python  /  管理员 发布于 5年前   404

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import osimport matplotlibfrom keras.datasets import mnistimport numpy as npfrom torch.utils.data.dataset import Datasetfrom PIL import Imageimport scipy.miscroot_path = 'E:/coding_ex/pytorch/Alexnet/data/'base_path = 'baseset/'training_path = 'trainingset/'test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):  (x_train, y_train), (x_test, y_test) = mnist.load_data()  x_baseset = np.concatenate((x_train, x_test))  y_baseset = np.concatenate((y_train, y_test))  train_num = len(x_train)  test_num = len(x_test)    #baseset  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')  for i in range(train_num + test_num):    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name    file_label.write(str(y_baseset[i])+'\n') #label#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])  file_img.close()  file_label.close()    #trainingset  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')  for i in range(train_num):    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name    file_label.write(str(y_train[i])+'\n') #label#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])  file_img.close()  file_label.close()    #testset  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')  for i in range(test_num):    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name    file_label.write(str(y_test[i])+'\n') #label#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])  file_img.close()  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):    self.root_path = root_path    self.transform = transform    self.imagedata_path = imgdata_path    img_file = open((root_path + imgfile_path),'r')    self.image_name = [x.strip() for x in img_file]    img_file.close()    label_file = open((root_path + labelfile_path), 'r')    label = [int(x.strip()) for x in label_file]    label_file.close()    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的      def __getitem__(self, idx):    image = Image.open(str(self.image_name[idx]))    image = image.convert('RGB')    if self.transform is not None:      image = self.transform(image)    label = self.label[idx]    return image, label  def __len__(self):    return len(self.image_name)

定义完自己的类之后可以测试一下。

  LoadData(root_path, base_path, training_path, test_path)  training_imgfile = training_path + 'trainingset_img.txt'  training_labelfile = training_path + 'trainingset_label.txt'  training_imgdata = training_path + 'img/'  #实例化一个类  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


  • 上一条:
    使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
    下一条:
    解决Pycharm的项目目录突然消失的问题
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • 在python语言中Flask框架的学习及简单功能示例(0个评论)
    • 在Python语言中实现GUI全屏倒计时代码示例(0个评论)
    • Python + zipfile库实现zip文件解压自动化脚本示例(0个评论)
    • python爬虫BeautifulSoup快速抓取网站图片(1个评论)
    • vscode 配置 python3开发环境的方法(0个评论)
    • 近期文章
    • 智能合约Solidity学习CryptoZombie第二课:让你的僵尸猎食(0个评论)
    • 智能合约Solidity学习CryptoZombie第一课:生成一只你的僵尸(0个评论)
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • Laravel从Accel获得5700万美元A轮融资(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2018-04
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2022-01
    • 2023-07
    • 2023-10
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客