pytorch制作自己的LMDB数据操作示例
数据库  /  管理员 发布于 5年前   536
本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:
记录下pytorch里如何使用lmdb的code,自用
code就是ASTER里数据制作部分的代码改了点,aster_train.txt里面就算图片的完整路径每行一个,图片同目录下有同名的txt,里面记着jpg的标签
import osimport lmdb # install lmdb by "pip install lmdb"import cv2import numpy as npfrom tqdm import tqdmimport sixfrom PIL import Imageimport scipy.io as siofrom tqdm import tqdmimport redef checkImageIsValid(imageBin): if imageBin is None: return False imageBuf = np.fromstring(imageBin, dtype=np.uint8) img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) imgH, imgW = img.shape[0], img.shape[1] if imgH * imgW == 0: return False return Truedef writeCache(env, cache): with env.begin(write=True) as txn: for k, v in cache.items(): txn.put(k.encode(), v)def _is_difficult(word): assert isinstance(word, str) return not re.match('^[\w]+$', word)def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True): """ Create LMDB dataset for CRNN training. ARGS: outputPath : LMDB output path imagePathList : list of image path labelList : list of corresponding groundtruth texts lexiconList : (optional) list of lexicon lists checkValid : if true, check the validity of every image """ assert(len(imagePathList) == len(labelList)) nSamples = len(imagePathList) env = lmdb.open(outputPath, map_size=1099511627776)#最大空间1048576GB cache = {} cnt = 1 for i in range(nSamples): imagePath = imagePathList[i] label = labelList[i] if len(label) == 0: continue if not os.path.exists(imagePath): print('%s does not exist' % imagePath) continue with open(imagePath, 'rb') as f: imageBin = f.read() if checkValid: if not checkImageIsValid(imageBin): print('%s is not a valid image' % imagePath) continue #数据库中都是二进制数据 imageKey = 'image-%09d' % cnt#9位数不足填零 labelKey = 'label-%09d' % cnt cache[imageKey] = imageBin cache[labelKey] = label.encode() if lexiconList: lexiconKey = 'lexicon-%09d' % cnt cache[lexiconKey] = ' '.join(lexiconList[i]) if cnt % 1000 == 0: writeCache(env, cache) cache = {} print('Written %d / %d' % (cnt, nSamples)) cnt += 1 nSamples = cnt-1 cache['num-samples'] = str(nSamples).encode() writeCache(env, cache) print('Created dataset with %d samples' % nSamples)def get_sample_list(txt_path:str): with open(txt_path,'r') as fr: jpg_list=[x.strip() for x in fr.readlines() if os.path.exists(x.replace('.jpg','.txt').strip())] txt_content_list=[] for jpg in jpg_list: label_path=jpg.replace('.jpg','.txt') with open(label_path,'r') as fr: try: str_tmp=fr.readline() except UnicodeDecodeError as e: print(label_path) raise(e) txt_content_list.append(str_tmp.strip()) return jpg_list,txt_content_listif __name__ == "__main__": txt_path='/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt' lmdb_output_path = '/home/gpu-server/project/aster/dataset/train' imagePathList,labelList=get_sample_list(txt_path) createDataset(lmdb_output_path, imagePathList, labelList)
这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__
from __future__ import absolute_import# import sys# sys.path.append('./')import os# import moxing as moximport picklefrom tqdm import tqdmfrom PIL import Image, ImageFileimport numpy as npimport randomimport cv2import lmdbimport sysimport siximport torchfrom torch.utils import datafrom torch.utils.data import samplerfrom torchvision import transformsfrom lib.utils.labelmaps import get_vocabulary, labels2strsfrom lib.utils import to_numpyImageFile.LOAD_TRUNCATED_IMAGES = Truefrom config import get_argsglobal_args = get_args(sys.argv[1:])if global_args.run_on_remote: import moxing as mox #moxing是一个分布式的框架 跳过class LmdbDataset(data.Dataset): def __init__(self, root, voc_type, max_len, num_samples, transform=None): super(LmdbDataset, self).__init__() if global_args.run_on_remote: dataset_name = os.path.basename(root) data_cache_url = "/cache/%s" % dataset_name if not os.path.exists(data_cache_url): os.makedirs(data_cache_url) if mox.file.exists(root): mox.file.copy_parallel(root, data_cache_url) else: raise ValueError("%s not exists!" % root) self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True) else: self.env = lmdb.open(root, max_readers=32, readonly=True) assert self.env is not None, "cannot create lmdb from %s" % root self.txn = self.env.begin() self.voc_type = voc_type self.transform = transform self.max_len = max_len self.nSamples = int(self.txn.get(b"num-samples")) self.nSamples = min(self.nSamples, num_samples) assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS','DIGITS'] self.EOS = 'EOS' self.PADDING = 'PADDING' self.UNKNOWN = 'UNKNOWN' self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN) self.char2id = dict(zip(self.voc, range(len(self.voc)))) self.id2char = dict(zip(range(len(self.voc)), self.voc)) self.rec_num_classes = len(self.voc) self.lowercase = (voc_type == 'LOWERCASE') def __len__(self): return self.nSamples def __getitem__(self, index): assert index <= len(self), 'index range error' index += 1 img_key = b'image-%09d' % index imgbuf = self.txn.get(img_key) #由于Image.open需要一个类文件对象 所以这里需要把二进制转为一个类文件对象 buf = six.BytesIO() buf.write(imgbuf) buf.seek(0) try: img = Image.open(buf).convert('RGB') # img = Image.open(buf).convert('L') # img = img.convert('RGB') except IOError: print('Corrupted image for %d' % index) return self[index + 1] # reconition labels label_key = b'label-%09d' % index word = self.txn.get(label_key).decode() if self.lowercase: word = word.lower() ## fill with the padding token label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=np.int) label_list = [] for char in word: if char in self.char2id: label_list.append(self.char2id[char]) else: ## add the unknown token print('{0} is out of vocabulary.'.format(char)) label_list.append(self.char2id[self.UNKNOWN]) ## add a stop token label_list = label_list + [self.char2id[self.EOS]] assert len(label_list) <= self.max_len label[:len(label_list)] = np.array(label_list) if len(label) <= 0: return self[index + 1] # label length label_len = len(label_list) if self.transform is not None: img = self.transform(img) return img, label, label_len
更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》
希望本文所述对大家Python程序设计有所帮助。
122 在
学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..123 在
Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..原梓番博客 在
在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..博主 在
佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..1111 在
佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
Copyright·© 2019 侯体宗版权所有·
粤ICP备20027696号