pytorch cnn 识别手写的字实现自建图片数据
Python  /  管理员 发布于 5年前   465
本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下:
# library# standard libraryimport os # third-party libraryimport torchimport torch.nn as nnfrom torch.autograd import Variablefrom torch.utils.data import Dataset, DataLoaderimport torchvisionimport matplotlib.pyplot as pltfrom PIL import Imageimport numpy as np# torch.manual_seed(1) # reproducible # Hyper ParametersEPOCH = 1 # train the training data n times, to save time, we just train 1 epochBATCH_SIZE = 50LR = 0.001 # learning rate root = "./mnist/raw/" def default_loader(path): # return Image.open(path).convert('RGB') return Image.open(path) class MyDataset(Dataset): def __init__(self, txt, transform=None, target_transform=None, loader=default_loader): fh = open(txt, 'r') imgs = [] for line in fh: line = line.strip('\n') line = line.rstrip() words = line.split() imgs.append((words[0], int(words[1]))) self.imgs = imgs self.transform = transform self.target_transform = target_transform self.loader = loader fh.close() def __getitem__(self, index): fn, label = self.imgs[index] img = self.loader(fn) img = Image.fromarray(np.array(img), mode='L') if self.transform is not None: img = self.transform(img) return img,label def __len__(self): return len(self.imgs) train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True) test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE) class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.conv1 = nn.Sequential( # input shape (1, 28, 28) nn.Conv2d( in_channels=1, # input height out_channels=16, # n_filters kernel_size=5, # filter size stride=1, # filter movement/step padding=2, # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1 ), # output shape (16, 28, 28) nn.ReLU(), # activation nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14) ) self.conv2 = nn.Sequential( # input shape (16, 14, 14) nn.Conv2d(16, 32, 5, 1, 2), # output shape (32, 14, 14) nn.ReLU(), # activation nn.MaxPool2d(2), # output shape (32, 7, 7) ) self.out = nn.Linear(32 * 7 * 7, 10) # fully connected layer, output 10 classes def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32 * 7 * 7) output = self.out(x) return output, x # return x for visualization cnn = CNN()print(cnn) # net architecture optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parametersloss_func = nn.CrossEntropyLoss()# the target label is not one-hotted # training and testingfor epoch in range(EPOCH): for step, (x, y) in enumerate(train_loader): # gives batch data, normalize x when iterate train_loader b_x = Variable(x) # batch x b_y = Variable(y) # batch y output = cnn(b_x)[0] # cnn output loss = loss_func(output, b_y) # cross entropy loss optimizer.zero_grad() # clear gradients for this training step loss.backward() # backpropagation, compute gradients optimizer.step() # apply gradients if step % 50 == 0: cnn.eval() eval_loss = 0. eval_acc = 0. for i, (tx, ty) in enumerate(test_loader): t_x = Variable(tx) t_y = Variable(ty) output = cnn(t_x)[0] loss = loss_func(output, t_y) eval_loss += loss.data[0] pred = torch.max(output, 1)[1] num_correct = (pred == t_y).sum() eval_acc += float(num_correct.data[0]) acc_rate = eval_acc / float(len(test_data)) print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))
图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》
结果如下:
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。
122 在
学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..123 在
Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..原梓番博客 在
在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..博主 在
佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..1111 在
佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
Copyright·© 2019 侯体宗版权所有·
粤ICP备20027696号