Hou Tizong's Blog

MNIST Handwritten Digit Recognition with PyTorch

Python  /  Posted by the administrator 5 years ago   652

Experiment Environment

Windows 10 + Anaconda + Jupyter Notebook

PyTorch 1.1.0

Python 3.7

GPU environment (optional)

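About the MNIST Dataset

MNIST contains 60,000 training samples of 28x28 pixels and 10,000 test samples. It is often called the "Hello World" of computer vision. The CNN used in this article reaches a recognition accuracy of 99% on MNIST. Let's get started.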

Importing Packages

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
torch.__version__

Defining Hyperparameters

BATCH_SIZE = 512
EPOCHS = 20
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
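To get a feel for what these two numbers imply, here is a small sketch in plain Python (no PyTorch needed). It assumes the DataLoader keeps the final partial batch, which is its default behavior; the helper name `batches_per_epoch` is made up for this illustration.

```python
import math

BATCH_SIZE = 512         # from the article
EPOCHS = 20              # from the article
TRAIN_SAMPLES = 60_000   # size of the MNIST training set
TEST_SAMPLES = 10_000    # size of the MNIST test set


def batches_per_epoch(num_samples, batch_size):
    """Batches yielded per pass when the last partial batch is kept."""
    return math.ceil(num_samples / batch_size)


train_batches = batches_per_epoch(TRAIN_SAMPLES, BATCH_SIZE)
test_batches = batches_per_epoch(TEST_SAMPLES, BATCH_SIZE)

# The last training batch is smaller than BATCH_SIZE.
last_train_batch = TRAIN_SAMPLES - (train_batches - 1) * BATCH_SIZE

# One optimizer step per training batch, repeated for every epoch.
total_steps = train_batches * EPOCHS

print(train_batches, test_batches, last_train_batch, total_steps)
```

So one epoch is 118 optimizer steps, and the whole run is 2360 steps.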

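Dataset

We use the MNIST dataset bundled with PyTorch (via torchvision) and read the training data and the test data with separate DataLoader instances. If the dataset has already been downloaded, download can be set to False.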

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)
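The two constants passed to Normalize are the commonly quoted mean and standard deviation of the MNIST training pixels after ToTensor has scaled them to [0, 1]. The sketch below reproduces the per-pixel arithmetic, (pixel - mean) / std, in plain Python. The function name `normalize` is only for illustration and is not a torchvision API.

```python
MEAN, STD = 0.1307, 0.3081  # the constants passed to transforms.Normalize above


def normalize(pixel):
    """Apply (pixel - mean) / std to one pixel already scaled to [0, 1]."""
    return (pixel - MEAN) / STD


black = normalize(0.0)  # background pixels end up slightly below zero
white = normalize(1.0)  # the brightest stroke pixels end up near 2.8

print(round(black, 4), round(white, 4))
```

After this step the inputs are roughly zero-centered with unit spread, which usually makes optimization better behaved.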

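Defining the Network

The network consists of two convolutional layers and two linear layers. The final output has 10 dimensions, one for each of the digits 0-9.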

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)   # input:(1,28,28) output:(10,24,24)
        self.conv2 = nn.Conv2d(10, 20, 3)  # input:(10,12,12) output:(20,10,10)
        self.fc1 = nn.Linear(20*10*10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        in_size = x.size(0)
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)
        out = self.conv2(out)
        out = F.relu(out)
        out = out.view(in_size, -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out, dim=1)
        return out
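The shapes in the comments, and the 20*10*10 input size of fc1, follow from the standard output-size formula for convolution and pooling without padding or dilation. The sketch below walks through that arithmetic in plain Python. The helper `conv2d_out` is a name made up for this example.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                          # MNIST images are 28x28
after_conv1 = conv2d_out(side, kernel=5)           # conv1, 5x5 kernel: 28 -> 24
after_pool = conv2d_out(after_conv1, kernel=2, stride=2)  # max_pool2d(2, 2): 24 -> 12
after_conv2 = conv2d_out(after_pool, kernel=3)     # conv2, 3x3 kernel: 12 -> 10

# conv2 has 20 output channels, so the flattened vector has 20 * 10 * 10 entries.
flat_features = 20 * after_conv2 * after_conv2

print(after_conv1, after_pool, after_conv2, flat_features)
```

If a kernel size or the pooling step is changed, the first argument of fc1 has to be recomputed the same way.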

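Instantiating the Network

model = ConvNet().to(DEVICE)  # move the network to DEVICE (the GPU if one is available)
optimizer = optim.Adam(model.parameters())  # use the Adam optimizer with its default settings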
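For a sense of how many values model.parameters() hands to the optimizer, the count can be worked out by hand from the layer definitions. This is a plain-Python sketch that assumes every layer keeps its bias term (the default for nn.Conv2d and nn.Linear). The helper names are illustrative only.

```python
def conv_params(in_channels, out_channels, kernel):
    """Weights plus one bias per output channel for a square kernel."""
    return in_channels * out_channels * kernel * kernel + out_channels


def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output feature."""
    return in_features * out_features + out_features


layers = {
    "conv1": conv_params(1, 10, 5),
    "conv2": conv_params(10, 20, 3),
    "fc1": linear_params(20 * 10 * 10, 500),
    "fc2": linear_params(500, 10),
}
total = sum(layers.values())

for name, count in layers.items():
    print(name, count)
print("total", total)
```

Almost all of the roughly one million parameters sit in fc1, so the flattened feature size is what dominates the model size.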

Defining the Training Function

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
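The network ends in log_softmax and the training function uses nll_loss. Together the two amount to the usual cross-entropy loss. The sketch below checks that for a single sample in plain Python. The functions are simplified stand-ins written for this illustration, not the PyTorch implementations, and the logits are made-up values.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class for one sample."""
    return -log_probs[target]


def cross_entropy(logits, target):
    """Reference value: -log(softmax(logits)[target]) computed directly."""
    exps = [math.exp(x) for x in logits]
    return -math.log(exps[target] / sum(exps))


logits = [2.0, 0.5, -1.0]  # hypothetical scores for three classes
target = 0                 # hypothetical true class

loss = nll_loss(log_softmax(logits), target)
print(loss, cross_entropy(logits, target))
```

Both printed values agree up to floating-point rounding.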

Defining the Test Function

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum the losses within the batch
            pred = output.max(1, keepdim=True)[1]  # index of the highest log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
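The test function sums the loss inside each batch (reduction='sum') and divides by the dataset size at the end. That matters because the last test batch is smaller than the others (10000 = 19 * 512 + 272), so averaging the per-batch means would weight its samples too heavily. A toy illustration in plain Python, using made-up loss values and a made-up batch size:

```python
def mean(values):
    return sum(values) / len(values)


def split_batches(values, batch_size):
    """Chop a list into consecutive batches; the last one may be shorter."""
    return [values[i:i + batch_size] for i in range(0, len(values), batch_size)]


losses = [1.0, 1.0, 1.0, 1.0, 5.0]  # hypothetical per-sample losses
chunks = split_batches(losses, 4)   # one full batch of 4 and a partial batch of 1

# Averaging the batch means gives the lone sample in the last batch too much weight.
mean_of_batch_means = mean([mean(chunk) for chunk in chunks])

# Summing per batch and dividing once by the dataset size gives the true mean.
sum_then_divide = sum(sum(chunk) for chunk in chunks) / len(losses)

print(mean_of_batch_means, sum_then_divide)
```

Here the two estimates are 3.0 and 1.8, and only the second equals the true per-sample average.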

Start Training

for epoch in range(1, EPOCHS + 1):
    train(model, DEVICE, train_loader, optimizer, epoch)
    test(model, DEVICE, test_loader)
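While this loop runs, the training function prints a progress line whenever (batch_idx + 1) % 30 == 0. The sketch below replays only that bookkeeping in plain Python, assuming full batches of 512 at the logged steps, to show which sample counts and percentages to expect. `LOG_EVERY` is a name introduced here for the literal 30 in the training function.

```python
import math

BATCH_SIZE = 512
TRAIN_SAMPLES = 60_000
LOG_EVERY = 30  # stands for the literal 30 in the training function

num_batches = math.ceil(TRAIN_SAMPLES / BATCH_SIZE)

progress = []
for batch_idx in range(num_batches):
    if (batch_idx + 1) % LOG_EVERY == 0:
        seen = batch_idx * BATCH_SIZE                # same expression as in train()
        percent = 100.0 * batch_idx / num_batches
        progress.append("[{}/{} ({:.0f}%)]".format(seen, TRAIN_SAMPLES, percent))

for line in progress:
    print(line)
```

Because batch_idx is zero-based, the printed sample count lags one batch behind: at the first log line 30 batches have already been processed, but the counter shows 29 * 512 = 14848.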

Results

Train Epoch: 1 [14848/60000 (25%)]    Loss: 0.375058
Train Epoch: 1 [30208/60000 (50%)]    Loss: 0.255248
Train Epoch: 1 [45568/60000 (75%)]    Loss: 0.128060
Test set: Average loss: 0.0992, Accuracy: 9690/10000 (97%)

Train Epoch: 2 [14848/60000 (25%)]    Loss: 0.093066
Train Epoch: 2 [30208/60000 (50%)]    Loss: 0.087888
Train Epoch: 2 [45568/60000 (75%)]    Loss: 0.068078
Test set: Average loss: 0.0599, Accuracy: 9816/10000 (98%)

Train Epoch: 3 [14848/60000 (25%)]    Loss: 0.043926
Train Epoch: 3 [30208/60000 (50%)]    Loss: 0.037321
Train Epoch: 3 [45568/60000 (75%)]    Loss: 0.068404
Test set: Average loss: 0.0416, Accuracy: 9859/10000 (99%)

Train Epoch: 4 [14848/60000 (25%)]    Loss: 0.031654
Train Epoch: 4 [30208/60000 (50%)]    Loss: 0.041341
Train Epoch: 4 [45568/60000 (75%)]    Loss: 0.036493
Test set: Average loss: 0.0361, Accuracy: 9873/10000 (99%)

Train Epoch: 5 [14848/60000 (25%)]    Loss: 0.027688
Train Epoch: 5 [30208/60000 (50%)]    Loss: 0.019488
Train Epoch: 5 [45568/60000 (75%)]    Loss: 0.018023
Test set: Average loss: 0.0344, Accuracy: 9875/10000 (99%)

Train Epoch: 6 [14848/60000 (25%)]    Loss: 0.024212
Train Epoch: 6 [30208/60000 (50%)]    Loss: 0.018689
Train Epoch: 6 [45568/60000 (75%)]    Loss: 0.040412
Test set: Average loss: 0.0350, Accuracy: 9879/10000 (99%)

Train Epoch: 7 [14848/60000 (25%)]    Loss: 0.030426
Train Epoch: 7 [30208/60000 (50%)]    Loss: 0.026939
Train Epoch: 7 [45568/60000 (75%)]    Loss: 0.010722
Test set: Average loss: 0.0287, Accuracy: 9892/10000 (99%)

Train Epoch: 8 [14848/60000 (25%)]    Loss: 0.021109
Train Epoch: 8 [30208/60000 (50%)]    Loss: 0.034845
Train Epoch: 8 [45568/60000 (75%)]    Loss: 0.011223
Test set: Average loss: 0.0299, Accuracy: 9904/10000 (99%)

Train Epoch: 9 [14848/60000 (25%)]    Loss: 0.011391
Train Epoch: 9 [30208/60000 (50%)]    Loss: 0.008091
Train Epoch: 9 [45568/60000 (75%)]    Loss: 0.039870
Test set: Average loss: 0.0341, Accuracy: 9890/10000 (99%)

Train Epoch: 10 [14848/60000 (25%)]    Loss: 0.026813
Train Epoch: 10 [30208/60000 (50%)]    Loss: 0.011159
Train Epoch: 10 [45568/60000 (75%)]    Loss: 0.024884
Test set: Average loss: 0.0286, Accuracy: 9901/10000 (99%)

Train Epoch: 11 [14848/60000 (25%)]    Loss: 0.006420
Train Epoch: 11 [30208/60000 (50%)]    Loss: 0.003641
Train Epoch: 11 [45568/60000 (75%)]    Loss: 0.003402
Test set: Average loss: 0.0377, Accuracy: 9894/10000 (99%)

Train Epoch: 12 [14848/60000 (25%)]    Loss: 0.006866
Train Epoch: 12 [30208/60000 (50%)]    Loss: 0.012617
Train Epoch: 12 [45568/60000 (75%)]    Loss: 0.008548
Test set: Average loss: 0.0311, Accuracy: 9908/10000 (99%)

Train Epoch: 13 [14848/60000 (25%)]    Loss: 0.010539
Train Epoch: 13 [30208/60000 (50%)]    Loss: 0.002952
Train Epoch: 13 [45568/60000 (75%)]    Loss: 0.002313
Test set: Average loss: 0.0293, Accuracy: 9905/10000 (99%)

Train Epoch: 14 [14848/60000 (25%)]    Loss: 0.002100
Train Epoch: 14 [30208/60000 (50%)]    Loss: 0.000779
Train Epoch: 14 [45568/60000 (75%)]    Loss: 0.005952
Test set: Average loss: 0.0335, Accuracy: 9897/10000 (99%)

Train Epoch: 15 [14848/60000 (25%)]    Loss: 0.006053
Train Epoch: 15 [30208/60000 (50%)]    Loss: 0.002559
Train Epoch: 15 [45568/60000 (75%)]    Loss: 0.002555
Test set: Average loss: 0.0357, Accuracy: 9894/10000 (99%)

Train Epoch: 16 [14848/60000 (25%)]    Loss: 0.000895
Train Epoch: 16 [30208/60000 (50%)]    Loss: 0.004923
Train Epoch: 16 [45568/60000 (75%)]    Loss: 0.002339
Test set: Average loss: 0.0400, Accuracy: 9893/10000 (99%)

Train Epoch: 17 [14848/60000 (25%)]    Loss: 0.004136
Train Epoch: 17 [30208/60000 (50%)]    Loss: 0.000927
Train Epoch: 17 [45568/60000 (75%)]    Loss: 0.002084
Test set: Average loss: 0.0353, Accuracy: 9895/10000 (99%)

Train Epoch: 18 [14848/60000 (25%)]    Loss: 0.004508
Train Epoch: 18 [30208/60000 (50%)]    Loss: 0.001272
Train Epoch: 18 [45568/60000 (75%)]    Loss: 0.000543
Test set: Average loss: 0.0380, Accuracy: 9894/10000 (99%)

Train Epoch: 19 [14848/60000 (25%)]    Loss: 0.001699
Train Epoch: 19 [30208/60000 (50%)]    Loss: 0.000661
Train Epoch: 19 [45568/60000 (75%)]    Loss: 0.000275
Test set: Average loss: 0.0339, Accuracy: 9905/10000 (99%)

Train Epoch: 20 [14848/60000 (25%)]    Loss: 0.000441
Train Epoch: 20 [30208/60000 (50%)]    Loss: 0.000695
Train Epoch: 20 [45568/60000 (75%)]    Loss: 0.000467
Test set: Average loss: 0.0396, Accuracy: 9894/10000 (99%)
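One way to read this log is to look for the epoch with the best test numbers. The sketch below copies the per-epoch test loss and correct counts from the output above into two lists and picks out the best epochs. Treat the takeaway as an observation about this particular run, not a general rule.

```python
# Per-epoch test metrics copied from the log above (index 0 is epoch 1).
test_loss = [0.0992, 0.0599, 0.0416, 0.0361, 0.0344, 0.0350, 0.0287, 0.0299,
             0.0341, 0.0286, 0.0377, 0.0311, 0.0293, 0.0335, 0.0357, 0.0400,
             0.0353, 0.0380, 0.0339, 0.0396]
correct = [9690, 9816, 9859, 9873, 9875, 9879, 9892, 9904, 9890, 9901,
           9894, 9908, 9905, 9897, 9894, 9893, 9895, 9894, 9905, 9894]

best_acc_epoch = max(range(len(correct)), key=correct.__getitem__) + 1
best_loss_epoch = min(range(len(test_loss)), key=test_loss.__getitem__) + 1
best_accuracy = 100.0 * correct[best_acc_epoch - 1] / 10_000

print("best accuracy: epoch", best_acc_epoch, "({:.2f}%)".format(best_accuracy))
print("lowest test loss: epoch", best_loss_epoch)
```

In this run the training loss keeps shrinking through epoch 20, while the test loss bottoms out around epoch 10 and the accuracy peaks at epoch 12. That pattern hints that the later epochs mostly fit the training set more tightly, so a smaller EPOCHS value might do about as well here.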

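Summary

The workflow of a real project: find a dataset, preprocess the data, define the model, tune the hyperparameters, train and test, and then adjust the hyperparameters or the model based on the training results.

That is everything we have to share on implementing MNIST handwritten digit recognition with PyTorch. We hope it serves as a useful reference, and we appreciate your support.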


Copyright © 2019 Hou Tizong. All rights reserved. 粤ICP备20027696号
