Hou Tizong's Blog (侯体宗的博客)

An Easy-to-Follow Hands-On PyTorch Example: The VGG Deep Network

Python  /  Posted by admin 5 years ago   246

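The model is VGG and the dataset is CIFAR-10. Walk through this code once and you will have a rough picture of how a PyTorch program works end to end: defining a model, loading data, training, evaluating, and saving checkpoints.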

Source

Define the model:

    '''VGG11/13/16/19 in PyTorch.'''
    import torch
    import torch.nn as nn

    # Each number is the output channel count of a 3x3 convolution; 'M' is a 2x2 max pool.
    cfg = {
        'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    }

    # A model must subclass nn.Module.
    class VGG(nn.Module):
        # Build the layers and register them as submodules.
        def __init__(self, vgg_name):
            super(VGG, self).__init__()
            self.features = self._make_layers(cfg[vgg_name])
            self.classifier = nn.Linear(512, 10)

        # The forward pass: the computation applied to every input batch.
        def forward(self, x):
            out = self.features(x)
            out = out.view(out.size(0), -1)  # flatten to (batch, 512)
            out = self.classifier(out)
            return out

        def _make_layers(self, cfg):
            layers = []
            in_channels = 3
            for x in cfg:
                if x == 'M':
                    layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
                else:
                    layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                               nn.BatchNorm2d(x),
                               nn.ReLU(inplace=True)]
                    in_channels = x
            layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
            return nn.Sequential(*layers)

    # Since PyTorch 0.4 a plain tensor can be passed in; Variable is no longer needed.
    # net = VGG('VGG11')
    # x = torch.randn(2, 3, 32, 32)
    # print(net(x).size())
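To see why the classifier is `nn.Linear(512, 10)`, it helps to track the tensor shape through the configuration. Each 3x3 convolution with `padding=1` keeps the spatial size, each `'M'` halves it, and the final 1x1 average pool changes nothing. The sketch below does this bookkeeping in plain Python without PyTorch; `flattened_size` is a hypothetical helper written for this illustration, not part of the original code.

```python
# Pure-Python shape bookkeeping for the VGG configs (no PyTorch required).
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
              512, 512, 512, 'M', 512, 512, 512, 'M'],
}


def flattened_size(layer_cfg, height=32, width=32):
    """Hypothetical helper: feature count after the conv stack is flattened."""
    channels = 3  # RGB input
    for x in layer_cfg:
        if x == 'M':
            # 2x2 max pool with stride 2 halves each spatial dimension.
            height //= 2
            width //= 2
        else:
            # 3x3 conv with padding=1 keeps height/width, sets channels to x.
            channels = x
    return channels * height * width


for name, layer_cfg in cfg.items():
    print(name, flattened_size(layer_cfg))
```

For 32x32 CIFAR-10 images the five pooling steps reduce the feature map to 1x1, so only the 512 channels remain. With a different input size the linear layer's input dimension would have to change accordingly.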

Define the training procedure:

    '''Train CIFAR10 with PyTorch.'''
    from __future__ import print_function

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    import torch.backends.cudnn as cudnn

    import torchvision
    import torchvision.transforms as transforms

    import os
    import argparse

    from models import *
    from utils import progress_bar

    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Load the dataset, applying preprocessing first.
    print('==> Preparing data..')
    # Image preprocessing and augmentation.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # Resume from a saved model, or build a new one.
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        # The checkpoint stores the whole pickled module. Recent PyTorch releases
        # default to weights_only=True, in which case weights_only=False must be
        # passed to torch.load to read it.
        checkpoint = torch.load('./checkpoint/ckpt.t7')
        net = checkpoint['net']
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']
    else:
        print('==> Building model..')
        net = VGG('VGG16')
        # net = ResNet18()
        # net = PreActResNet18()
        # net = GoogLeNet()
        # net = DenseNet121()
        # net = ResNeXt29_2x64d()
        # net = MobileNet()
        # net = MobileNetV2()
        # net = DPN92()
        # net = ShuffleNetG2()
        # net = SENet18()

    # Use the GPU if one is available.
    if use_cuda:
        # Move parameters and buffers to the GPU.
        net.cuda()
        # Replicate the model across all visible GPUs.
        # (The original used device_count()-1, which yields an empty list on a single-GPU machine.)
        net = torch.nn.DataParallel(net, device_ids=list(range(torch.cuda.device_count())))
        # Let cuDNN pick the fastest kernels; gives a slight speed-up.
        cudnn.benchmark = True

    # Define the loss function and the optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # Training phase.
    def train(epoch):
        print('\nEpoch: %d' % epoch)
        # Switch to training mode (affects BatchNorm and Dropout).
        net.train()
        train_loss = 0
        correct = 0
        total = 0
        # Iterate over mini-batches.
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            # Move the data to the GPU.
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            # Reset the gradients held by the optimizer's parameters to zero.
            optimizer.zero_grad()
            # Since PyTorch 0.4, tensors take part in the computation graph directly,
            # so no Variable wrapper is needed. The inputs are the leaves of the graph.
            outputs = net(inputs)
            # Compute the loss; this is the end point of the graph.
            loss = criterion(outputs, targets)
            # Backpropagate to compute gradients.
            loss.backward()
            # Update the parameters.
            optimizer.step()

            # To keep a running total of the loss, never add `loss` itself: it is part
            # of the computation graph, so the total would keep every batch's graph
            # alive and memory would grow. Use loss.item() (loss.data[0] before
            # PyTorch 0.4) to get a plain Python number.
            train_loss += loss.item()
            # Statistics.
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Testing phase.
    def test(epoch):
        global best_acc
        # Switch to evaluation mode first.
        net.eval()
        test_loss = 0
        correct = 0
        total = 0
        # torch.no_grad() replaces the removed Variable(..., volatile=True).
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                outputs = net(inputs)
                loss = criterion(outputs, targets)

                # Accumulate a Python float, not the loss tensor.
                test_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                    % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

        # Save a checkpoint when the test accuracy improves.
        acc = 100.*correct/total
        if acc > best_acc:
            print('Saving..')
            state = {
                'net': net.module if use_cuda else net,
                'acc': acc,
                'epoch': epoch,
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/ckpt.t7')
            best_acc = acc

    # Run training and testing.
    for epoch in range(start_epoch, start_epoch+200):
        train(epoch)
        test(epoch)
        # Release cached GPU memory that is no longer in use.
        torch.cuda.empty_cache()
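The numbers shown by `progress_bar` in the loops above are running statistics: the loss is the mean over the batches seen so far, and the accuracy is cumulative over all samples seen so far. The following is a minimal pure-Python sketch of that bookkeeping with made-up batch values; `RunningMetrics` is a hypothetical class introduced only for this illustration. It stores plain floats and ints, which is the same reason the training loop extracts a Python number from the loss before accumulating it.

```python
class RunningMetrics:
    """Hypothetical helper mirroring train_loss / correct / total in the script."""

    def __init__(self):
        self.loss_sum = 0.0
        self.batches = 0
        self.correct = 0
        self.total = 0

    def update(self, batch_loss, batch_correct, batch_size):
        # Store plain Python numbers only, never graph-attached tensors.
        self.loss_sum += float(batch_loss)
        self.batches += 1
        self.correct += int(batch_correct)
        self.total += int(batch_size)

    @property
    def mean_loss(self):
        # Same as train_loss / (batch_idx + 1) in the script.
        return self.loss_sum / self.batches

    @property
    def accuracy(self):
        # Same as 100. * correct / total in the script.
        return 100.0 * self.correct / self.total


metrics = RunningMetrics()
# Made-up (loss, correct, batch_size) triples standing in for two batches.
for batch_loss, batch_correct, batch_size in [(2.0, 32, 128), (1.0, 64, 128)]:
    metrics.update(batch_loss, batch_correct, batch_size)

print('Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (metrics.mean_loss, metrics.accuracy, metrics.correct, metrics.total))
```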

Run it:

Train a new model:
python main.py --lr=0.01
Resume training a saved model:
python main.py --resume --lr=0.01
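As a quick check of how these command lines map onto the script's settings, the sketch below rebuilds the same two arguments and parses explicit argument lists instead of reading `sys.argv`. The variable names `fresh`, `resumed`, and `defaults` are chosen here for illustration.

```python
import argparse

# Same two options as in the training script.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')

# Equivalent to: python main.py --lr=0.01
fresh = parser.parse_args(['--lr=0.01'])
# Equivalent to: python main.py --resume --lr=0.01
resumed = parser.parse_args(['--resume', '--lr=0.01'])
# Equivalent to: python main.py
defaults = parser.parse_args([])

for name, ns in [('fresh', fresh), ('resumed', resumed), ('defaults', defaults)]:
    print(name, 'lr =', ns.lr, 'resume =', ns.resume)
```

Note that `--lr` falls back to 0.1 when omitted, and `--resume` is a plain flag that is `False` unless it is given.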

Some utilities:

    '''Some helper functions for PyTorch, including:
        - get_mean_and_std: calculate the mean and std value of dataset.
        - init_params: net parameter initialization.
        - progress_bar: progress bar mimic xlua.progress.
    '''
    import os
    import shutil
    import sys
    import time
    import math

    import torch  # needed for torch.utils.data and torch.zeros below
    import torch.nn as nn
    import torch.nn.init as init


    def get_mean_and_std(dataset):
        '''Compute the mean and std value of dataset.'''
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
        mean = torch.zeros(3)
        std = torch.zeros(3)
        print('==> Computing mean and std..')
        # Accumulate per-image, per-channel statistics, then average over the dataset.
        for inputs, targets in dataloader:
            for i in range(3):
                mean[i] += inputs[:,i,:,:].mean()
                std[i] += inputs[:,i,:,:].std()
        mean.div_(len(dataset))
        std.div_(len(dataset))
        return mean, std

    def init_params(net):
        '''Init layer parameters.'''
        # The in-place initializers carry a trailing underscore since PyTorch 0.4.
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                # Test against None: truth-testing a multi-element tensor raises an error.
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=1e-3)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


    # The original read the width from `stty size`, which fails when stdout is not
    # a terminal; shutil.get_terminal_size() falls back to a default width instead.
    term_width = shutil.get_terminal_size().columns

    TOTAL_BAR_LENGTH = 65.
    last_time = time.time()
    begin_time = last_time

    def progress_bar(current, total, msg=None):
        global last_time, begin_time
        if current == 0:
            begin_time = time.time()  # Reset for new bar.

        cur_len = int(TOTAL_BAR_LENGTH*current/total)
        rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

        sys.stdout.write(' [')
        for i in range(cur_len):
            sys.stdout.write('=')
        sys.stdout.write('>')
        for i in range(rest_len):
            sys.stdout.write('.')
        sys.stdout.write(']')

        cur_time = time.time()
        step_time = cur_time - last_time
        last_time = cur_time
        tot_time = cur_time - begin_time

        L = []
        L.append(' Step: %s' % format_time(step_time))
        L.append(' | Tot: %s' % format_time(tot_time))
        if msg:
            L.append(' | ' + msg)

        msg = ''.join(L)
        sys.stdout.write(msg)
        for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
            sys.stdout.write(' ')

        # Go back to the center of the bar.
        for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
            sys.stdout.write('\b')
        sys.stdout.write(' %d/%d ' % (current+1, total))

        if current < total-1:
            sys.stdout.write('\r')
        else:
            sys.stdout.write('\n')
        sys.stdout.flush()

    def format_time(seconds):
        # Split a duration into days/hours/minutes/seconds/milliseconds
        # and print only the two largest non-zero units.
        days = int(seconds / 3600/24)
        seconds = seconds - days*3600*24
        hours = int(seconds / 3600)
        seconds = seconds - hours*3600
        minutes = int(seconds / 60)
        seconds = seconds - minutes*60
        secondsf = int(seconds)
        seconds = seconds - secondsf
        millis = int(seconds*1000)

        f = ''
        i = 1
        if days > 0:
            f += str(days) + 'D'
            i += 1
        if hours > 0 and i <= 2:
            f += str(hours) + 'h'
            i += 1
        if minutes > 0 and i <= 2:
            f += str(minutes) + 'm'
            i += 1
        if secondsf > 0 and i <= 2:
            f += str(secondsf) + 's'
            i += 1
        if millis > 0 and i <= 2:
            f += str(millis) + 'ms'
            i += 1
        if f == '':
            f = '0ms'
        return f
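One detail of `get_mean_and_std` is worth spelling out: it averages per-image statistics. For equally sized images the average of per-image means equals the dataset-wide mean, but the average of per-image standard deviations is in general not the dataset-wide standard deviation. The toy example below illustrates this with two made-up single-channel "images" stored as flat lists. It is a sketch using the standard library's population standard deviation (`statistics.pstdev`), which is an assumption made for simplicity; it is not the estimator PyTorch's `std()` uses by default.

```python
import statistics

# Two made-up single-channel "images": one all black, one all white.
images = [
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 1.0, 1.0, 1.0],
]

# What the helper's approach computes: average of per-image statistics.
per_image_mean = statistics.mean(statistics.mean(img) for img in images)
per_image_std = statistics.mean(statistics.pstdev(img) for img in images)

# Statistics over every pixel in the dataset at once.
all_pixels = [p for img in images for p in img]
global_mean = statistics.mean(all_pixels)
global_std = statistics.pstdev(all_pixels)

print('mean: per-image average =', per_image_mean, '| global =', global_mean)
print('std:  per-image average =', per_image_std, '| global =', global_std)
```

Each image here is perfectly uniform, so its own standard deviation is zero even though the dataset as a whole varies. On real photographs the gap is much smaller, so the averaged value is commonly used as an approximation for normalization constants.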

That is all for this article. I hope it helps with your studies, and thank you for your support.


Copyright © 2019 Hou Tizong (侯体宗). All rights reserved. 粤ICP备20027696号
