侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

Pytorch 数据加载与数据预处理方式

Python  /  管理员 发布于 5年前   773

数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况。

torchvision.datasets中的数据集

torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法。

Dataset源码如上,可以看到其中包含了两个没有实现的子方法,之后所有的Dataet类都继承该类,并根据数据情况定制这两个子方法的具体实现。

因此当我们需要加载自己的数据集的时候也可以借鉴这种方法,只需要继承torch.utils.data.Dataset类并重写 init ,len,以及getitem这三个方法即可。这样组着的类可以直接作为参数传入到torch.util.data.DataLoader中去。

以CIFAR10为例 源码:

class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
root (string) C Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.train (bool, optional) C If True, creates dataset from training set, otherwise creates from test set.transform (callable, optional) C A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCroptarget_transform (callable, optional) C A function/transform that takes in the target and transforms it.download (bool, optional) C If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

加载自己的数据集

对于torchvision.datasets中有两个不同的类,分别为DatasetFolder和ImageFolder,ImageFolder是继承自DatasetFolder。

下面我们通过源码来看一看folder文件中DatasetFolder和ImageFolder分别做了些什么

import torch.utils.data as datafrom PIL import Imageimport osimport os.pathdef has_file_allowed_extension(filename, extensions): //检查输入是否是规定的扩展名  """Checks if a file is an allowed extension.  Args:    filename (string): path to a file  Returns:    bool: True if the filename ends with a known image extension  """  filename_lower = filename.lower()  return any(filename_lower.endswith(ext) for ext in extensions)def find_classes(dir):  classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] //获取root目录下所有的文件夹名称  classes.sort()  class_to_idx = {classes[i]: i for i in range(len(classes))} //生成类别名称与类别id的对应Dictionary  return classes, class_to_idxdef make_dataset(dir, class_to_idx, extensions):  images = []  dir = os.path.expanduser(dir)// 将~和~user转化为用户目录,对参数中出现~进行处理  for target in sorted(os.listdir(dir)):    d = os.path.join(dir, target)    if not os.path.isdir(d):      continue    for root, _, fnames in sorted(os.walk(d)): //os.work包含三个部分,root代表该目录路径 _代表该路径下的文件夹名称集合,fnames代表该路径下的文件名称集合      for fname in sorted(fnames):        if has_file_allowed_extension(fname, extensions):          path = os.path.join(root, fname)          item = (path, class_to_idx[target])          images.append(item)  //生成(训练样本图像目录,训练样本所属类别)的元组  return images  //返回上述元组的列表class DatasetFolder(data.Dataset):  """A generic data loader where the samples are arranged in this way: ::    root/class_x/xxx.ext    root/class_x/xxy.ext    root/class_x/xxz.ext    root/class_y/123.ext    root/class_y/nsdf3.ext    root/class_y/asd932_.ext  Args:    root (string): Root directory path.    loader (callable): A function to load a sample given its path.    extensions (list[string]): A list of allowed extensions.    transform (callable, optional): A function/transform that takes in      a sample and returns a transformed version.      E.g, ``transforms.RandomCrop`` for images.    target_transform (callable, optional): A function/transform that takes      in the target and transforms it.   Attributes:    classes (list): List of the class names.    class_to_idx (dict): Dict with items (class_name, class_index).    samples (list): List of (sample path, class_index) tuples  """  def __init__(self, root, loader, extensions, transform=None, target_transform=None):    classes, class_to_idx = find_classes(root)    samples = make_dataset(root, class_to_idx, extensions)    if len(samples) == 0:      raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"    "Supported extensions are: " + ",".join(extensions)))    self.root = root    self.loader = loader    self.extensions = extensions    self.classes = classes    self.class_to_idx = class_to_idx    self.samples = samples    self.transform = transform    self.target_transform = target_transform  def __getitem__(self, index):    """    根据index获取sample 返回值为(sample,target)元组,同时如果该类输入参数中有transform和target_transform,torchvision.transforms类型的参数时,将获取的元组分别执行transform和target_transform中的数据转换方法。       Args:      index (int): Index    Returns:      tuple: (sample, target) where target is class_index of the target class.    """    path, target = self.samples[index]    sample = self.loader(path)    if self.transform is not None:      sample = self.transform(sample)    if self.target_transform is not None:      target = self.target_transform(target)    return sample, target  def __len__(self):    return len(self.samples)  def __repr__(self): //定义输出对象格式 其中和__str__的区别是__repr__无论是print输出还是直接输出对象自身 都是以定义的格式进行输出,而__str__ 只有在print输出的时候会是以定义的格式进行输出    fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'    fmt_str += '  Number of datapoints: {}\n'.format(self.__len__())    fmt_str += '  Root Location: {}\n'.format(self.root)    tmp = '  Transforms (if any): '    fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))    tmp = '  Target Transforms (if any): '    fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))    return fmt_strIMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']def pil_loader(path):  # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)  with open(path, 'rb') as f:    img = Image.open(f)    return img.convert('RGB')def accimage_loader(path):  import accimage  try:    return accimage.Image(path)  except IOError:    # Potentially a decoding problem, fall back to PIL.Image    return pil_loader(path)def default_loader(path):  from torchvision import get_image_backend  if get_image_backend() == 'accimage':    return accimage_loader(path)  else:    return pil_loader(path)class ImageFolder(DatasetFolder):   """A generic data loader where the images are arranged in this way: ::    root/dog/xxx.png    root/dog/xxy.png    root/dog/xxz.png    root/cat/123.png    root/cat/nsdf3.png    root/cat/asd932_.png  Args:    root (string): Root directory path.    transform (callable, optional): A function/transform that takes in an PIL image      and returns a transformed version. E.g, ``transforms.RandomCrop``    target_transform (callable, optional): A function/transform that takes in the      target and transforms it.    loader (callable, optional): A function to load an image given its path.   Attributes:    classes (list): List of the class names.    class_to_idx (dict): Dict with items (class_name, class_index).    imgs (list): List of (image path, class_index) tuples  """  def __init__(self, root, transform=None, target_transform=None,         loader=default_loader):    super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,         transform=transform,         target_transform=target_transform)    self.imgs = self.samples

如果自己所要加载的数据组织形式如下

root/dog/xxx.pngroot/dog/xxy.pngroot/dog/xxz.pngroot/cat/123.pngroot/cat/nsdf3.pngroot/cat/asd932_.png

即不同类别的训练数据分别存储在不同的文件夹中,这些文件夹都在root(即形如 D:/animals 或者 /usr/animals )路径下

class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)

参数如下:

root (string) C Root directory path.transform (callable, optional) C A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCroptarget_transform (callable, optional) C A function/transform that takes in the target and transforms it.loader C A function to load an image given its path. 就是上述源码中__getitem__(index)Parameters: index (int) C IndexReturns:  (sample, target) where target is class_index of the target class.Return type:  tuple

可以通过torchvision.datasets.ImageFolder进行加载

img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',          transform=transforms.Compose([transforms.Scale(256),transforms.CenterCrop(224),transforms.ToTensor()])          )print(len(img_data))data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)print(len(data_loader))

对于所有的训练样本都在一个文件夹中 同时有一个对应的txt文件每一行分别是对应图像的路径以及其所属的类别,可以参照上述class写出对应的加载类

def default_loader(path):  return Image.open(path).convert('RGB')class MyDataset(Dataset):  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):    fh = open(txt, 'r')    imgs = []    for line in fh:      line = line.strip('\n')      line = line.rstrip()      words = line.split()      imgs.append((words[0],int(words[1])))    self.imgs = imgs    self.transform = transform    self.target_transform = target_transform    self.loader = loader  def __getitem__(self, index):    fn, label = self.imgs[index]    img = self.loader(fn)    if self.transform is not None:      img = self.transform(img)    return img,label  def __len__(self):    return len(self.imgs)train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())data_loader = DataLoader(train_data, batch_size=100,shuffle=True)print(len(data_loader))

DataLoader解析

位于torch.util.data.DataLoader中 源代码

该接口的主要目的是将pytorch中已有的数据接口如torchvision.datasets.ImageFolder,或者自定义的数据读取接口转化按照

batch_size的大小封装为Tensor,即相当于在内置数据接口或者自定义数据接口的基础上增加一维,大小为batch_size的大小,

得到的数据在之后可以通过封装为Variable,作为模型的输出

_ _ init _ _中所需的参数如下

1. dataset torch.utils.data.Dataset类的子类,可以是torchvision.datasets.ImageFolder等内置类,也可是继承了torch.utils.data.Dataset的自定义类2. batch_size 每一个batch中包含的样本个数,默认是1 3. shuffle 一般在训练集中采用,默认是false,设置为true则每一个epoch都会将训练样本打乱4. sampler 训练样本选取策略,和shuffle是互斥的 如果 shuffle为true,该参数一定要为None5. batch_sampler BatchSampler 一次产生一个 batch 的 indices,和sampler以及shuffle互斥,一般使用默认的即可  上述Sampler的源代码地址如下[源代码](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py)6. num_workers 用于数据加载的线程数量 默认为0 即只有主线程用来加载数据7. collate_fn 用来聚合数据生成mini_batch

使用的时候一般为如下使用方法:

train_data=torch.utils.data.DataLoader(...) for i, (input, target) in enumerate(train_data): ...

循环取DataLoader中的数据会触发类中_ _ iter __方法,查看源代码可知 其中调用的方法为 return _DataLoaderIter(self),因此需要查看 DataLoaderIter 这一内部类

class DataLoaderIter(object):  "Iterates once over the DataLoader's dataset, as specified by the sampler"  def __init__(self, loader):    self.dataset = loader.dataset    self.collate_fn = loader.collate_fn    self.batch_sampler = loader.batch_sampler    self.num_workers = loader.num_workers    self.pin_memory = loader.pin_memory and torch.cuda.is_available()    self.timeout = loader.timeout    self.done_event = threading.Event()    self.sample_iter = iter(self.batch_sampler)    if self.num_workers > 0:      self.worker_init_fn = loader.worker_init_fn      self.index_queue = multiprocessing.SimpleQueue()      self.worker_result_queue = multiprocessing.SimpleQueue()      self.batches_outstanding = 0      self.worker_pids_set = False      self.shutdown = False      self.send_idx = 0      self.rcvd_idx = 0      self.reorder_dict = {}      base_seed = torch.LongTensor(1).random_()[0]      self.workers = [        multiprocessing.Process(          target=_worker_loop,          args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn, base_seed + i, self.worker_init_fn, i))        for i in range(self.num_workers)]      if self.pin_memory or self.timeout > 0:        self.data_queue = queue.Queue()        self.worker_manager_thread = threading.Thread(          target=_worker_manager_loop,          args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, torch.cuda.current_device()))        self.worker_manager_thread.daemon = True        self.worker_manager_thread.start()      else:        self.data_queue = self.worker_result_queue      for w in self.workers:        w.daemon = True # ensure that the worker exits on process exit        w.start()      _update_worker_pids(id(self), tuple(w.pid for w in self.workers))      _set_SIGCHLD_handler()      self.worker_pids_set = True      # prime the prefetch loop      for _ in range(2 * self.num_workers):        self._put_indices()

以上这篇Pytorch 数据加载与数据预处理方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


  • 上一条:
    pytorch 归一化与反归一化实例
    下一条:
    pytorch 数据处理:定义自己的数据集合实例
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • 在python语言中Flask框架的学习及简单功能示例(0个评论)
    • 在Python语言中实现GUI全屏倒计时代码示例(0个评论)
    • Python + zipfile库实现zip文件解压自动化脚本示例(0个评论)
    • python爬虫BeautifulSoup快速抓取网站图片(1个评论)
    • vscode 配置 python3开发环境的方法(0个评论)
    • 近期文章
    • 智能合约Solidity学习CryptoZombie第三课:组建僵尸军队(高级Solidity理论)(0个评论)
    • 智能合约Solidity学习CryptoZombie第二课:让你的僵尸猎食(0个评论)
    • 智能合约Solidity学习CryptoZombie第一课:生成一只你的僵尸(0个评论)
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2018-04
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2022-01
    • 2023-07
    • 2023-10
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客