侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

详解PyTorch手写数字识别(MNIST数据集)

Python  /  管理员 发布于 5年前   331

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom torchvision import datasets, transformsimport torchvisionfrom torch.autograd import Variablefrom torch.utils.data import DataLoaderimport cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集train_dataset = datasets.MNIST(root='./num/',    train=True,    transform=transforms.ToTensor(),    download=True)# 下载测试集test_dataset = datasets.MNIST(root='./num/',   train=False,   transform=transforms.ToTensor(),   download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称# batch_size参数设置了每个包中的图片数据个数# 在装载的过程会将数据随机打乱顺序并进打包# 装载训练集train_loader = torch.utils.data.DataLoader(dataset=train_dataset,          batch_size=batch_size,          shuffle=True)# 装载测试集test_loader = torch.utils.data.DataLoader(dataset=test_dataset,         batch_size=batch_size,         shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))img = torchvision.utils.make_grid(images)img = img.numpy().transpose(1, 2, 0)std = [0.5, 0.5, 0.5]mean = [0.5, 0.5, 0.5]img = img * std + meanprint(labels)cv2.imshow('win', img)key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:


并且打印出了图片相对应的数字:

搭建神经网络

# 卷积层使用 torch.nn.Conv2d# 激活层使用 torch.nn.ReLU# 池化层使用 torch.nn.MaxPool2d# 全连接层使用 torch.nn.Linearclass LeNet(nn.Module):  def __init__(self):    super(LeNet, self).__init__()    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),      nn.MaxPool2d(2, 2))    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),      nn.MaxPool2d(2, 2))    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),     nn.BatchNorm1d(120), nn.ReLU())    self.fc2 = nn.Sequential(      nn.Linear(120, 84),      nn.BatchNorm1d(84),      nn.ReLU(),      nn.Linear(84, 10))    # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9  def forward(self, x):    x = self.conv1(x)    x = self.conv2(x)    x = x.view(x.size()[0], -1)    x = self.fc1(x)    x = self.fc2(x)    x = self.fc3(x)    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')batch_size = 64LR = 0.001net = LeNet().to(device)# 损失函数使用交叉熵criterion = nn.CrossEntropyLoss()# 优化函数使用 Adam 自适应优化算法optimizer = optim.Adam(  net.parameters(),  lr=LR,)epoch = 1if __name__ == '__main__':  for epoch in range(epoch):    sum_loss = 0.0    for i, data in enumerate(train_loader):      inputs, labels = data      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()      optimizer.zero_grad() #将梯度归零      outputs = net(inputs) #将数据传入网络进行前向运算      loss = criterion(outputs, labels) #得到损失函数      loss.backward() #反向传播      optimizer.step() #通过梯度做一步参数更新      # print(loss)      sum_loss += loss.item()      if i % 100 == 99:        print('[%d,%d] loss:%.03f' %           (epoch + 1, i + 1, sum_loss / 100))        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式  correct = 0  total = 0  for data_test in test_loader:    images, labels = data_test    images, labels = Variable(images).cuda(), Variable(labels).cuda()    output_test = net(images)    _, predicted = torch.max(output_test, 1)    total += labels.size(0)    correct += (predicted == labels).sum()  print("correct1: ", correct)  print("Test acc: {0}".format(correct.item() /     len(test_dataset)))

训练及测试的情况:


98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。


  • 上一条:
    关于PyTorch源码解读之torchvision.models
    下一条:
    pytorch 预训练层的使用方法
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • 在python语言中Flask框架的学习及简单功能示例(0个评论)
    • 在Python语言中实现GUI全屏倒计时代码示例(0个评论)
    • Python + zipfile库实现zip文件解压自动化脚本示例(0个评论)
    • python爬虫BeautifulSoup快速抓取网站图片(1个评论)
    • vscode 配置 python3开发环境的方法(0个评论)
    • 近期文章
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • Laravel从Accel获得5700万美元A轮融资(0个评论)
    • 在go + gin中gorm实现指定搜索/区间搜索分页列表功能接口实例(0个评论)
    • 在go语言中实现IP/CIDR的ip和netmask互转及IP段形式互转及ip是否存在IP/CIDR(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2018-04
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2022-01
    • 2023-07
    • 2023-10
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客