pytorch 利用lstm做mnist手写数字识别分类的实例
Python  /  管理员 发布于 5年前   565
代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。
# -*- coding: utf-8 -*-"""Created on Tue Oct 9 08:53:25 2018@author: www""" import syssys.path.append('..') import torchimport datetimefrom torch.autograd import Variablefrom torch import nnfrom torch.utils.data import DataLoader from torchvision import transforms as tfsfrom torchvision.datasets import MNIST #定义数据data_tf = tfs.Compose([ tfs.ToTensor(), tfs.Normalize([0.5], [0.5])])train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)test_set = MNIST('E:/data', train=False, transform=data_tf, download=True) train_data = DataLoader(train_set, 64, True, num_workers=4)test_data = DataLoader(test_set, 128, False, num_workers=4) #定义模型class rnn_classify(nn.Module): def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2): super(rnn_classify, self).__init__() self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果 def forward(self, x): #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28) x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28) x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28) out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature) out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature) out = self.classifier(out) #得到分类结果 return out net = rnn_classify()criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adadelta(net.parameters(), 1e-1) #定义训练过程def get_acc(output, label): total = output.shape[0] _, pred_label = output.max(1) num_correct = (pred_label == label).sum().item() return num_correct / total def train(net, train_data, valid_data, num_epochs, optimizer, criterion): if torch.cuda.is_available(): net = net.cuda() prev_time = datetime.datetime.now() for epoch in range(num_epochs): train_loss = 0 train_acc = 0 net = net.train() for im, label in train_data: if torch.cuda.is_available(): im = Variable(im.cuda()) # (bs, 3, h, w) label = Variable(label.cuda()) # (bs, h, w) else: im = Variable(im) label = Variable(label) # forward output = net(im) loss = criterion(output, label) # backward optimizer.zero_grad() loss.backward() optimizer.step() train_loss += loss.item() train_acc += get_acc(output, label) cur_time = datetime.datetime.now() h, remainder = divmod((cur_time - prev_time).seconds, 3600) m, s = divmod(remainder, 60) time_str = "Time %02d:%02d:%02d" % (h, m, s) if valid_data is not None: valid_loss = 0 valid_acc = 0 net = net.eval() for im, label in valid_data: if torch.cuda.is_available(): im = Variable(im.cuda()) label = Variable(label.cuda()) else: im = Variable(im) label = Variable(label) output = net(im) loss = criterion(output, label) valid_loss += loss.item() valid_acc += get_acc(output, label) epoch_str = ( "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, " % (epoch, train_loss / len(train_data), train_acc / len(train_data), valid_loss / len(valid_data), valid_acc / len(valid_data))) else: epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " % (epoch, train_loss / len(train_data), train_acc / len(train_data))) prev_time = cur_time print(epoch_str + time_str) train(net, train_data, test_data, 10, optimizer, criterion)
以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
122 在
学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..123 在
Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..原梓番博客 在
在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..博主 在
佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..1111 在
佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
Copyright·© 2019 侯体宗版权所有·
粤ICP备20027696号