侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

pytorch之inception_v3的实现案例

Python  /  管理员 发布于 5年前   378

如下所示:

from __future__ import print_function from __future__ import divisionimport torchimport torch.nn as nnimport torch.optim as optimimport numpy as npimport torchvisionfrom torchvision import datasets, models, transformsimport matplotlib.pyplot as pltimport timeimport osimport copyimport argparseprint("PyTorch Version: ",torch.__version__)print("Torchvision Version: ",torchvision.__version__)# Top level data directory. Here we assume the format of the directory conforms #  to the ImageFolder structure

数据集路径,路径下的数据集分为训练集和测试集,也就是train 以及val,train下分为两类数据1,2,val集同理

data_dir = "/home/dell/Desktop/data/切割图像"# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]model_name = "inception" # Number of classes in the datasetnum_classes = 2#两类数据1,2# Batch size for training (change depending on how much memory you have)batch_size = 32#batchsize尽量选取合适,否则训练时会内存溢出# Number of epochs to train for num_epochs = 1000# Flag for feature extracting. When False, we finetune the whole model, #  when True we only update the reshaped layer paramsfeature_extract = True# 参数设置,使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多parser = argparse.ArgumentParser(description='PyTorch inception')parser.add_argument('--outf', default='/home/dell/Desktop/dj/inception/', help='folder to output images and model checkpoints') #输出结果保存路径parser.add_argument('--net', default='/home/dell/Desktop/dj/inception/inception.pth', help="path to net (to continue training)") #恢复训练时的模型路径args = parser.parse_args()


训练函数

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,is_inception=False):  since = time.time()  val_acc_history = []    best_model_wts = copy.deepcopy(model.state_dict())  best_acc = 0.0  print("Start Training, InceptionV3!")   with open("acc.txt", "w") as f1:    with open("log.txt", "w")as f2:      for epoch in range(num_epochs):        print('Epoch {}/{}'.format(epoch+1, num_epochs))        print('*' * 10)        # Each epoch has a training and validation phase        for phase in ['train', 'val']:          if phase == 'train':model.train() # Set model to training mode          else:model.eval()  # Set model to evaluate mode  running_loss = 0.0          running_corrects = 0  # Iterate over data.          for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)    # zero the parameter gradientsoptimizer.zero_grad()    # forward# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):    if is_inception and phase == 'train':    # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958    outputs, aux_outputs = model(inputs)    loss1 = criterion(outputs, labels)    loss2 = criterion(aux_outputs, labels)    loss = loss1 + 0.4*loss2  else:    outputs = model(inputs)    loss = criterion(outputs, labels)      _, preds = torch.max(outputs, 1)      # backward + optimize only if in training phase  if phase == 'train':    loss.backward()    optimizer.step()    # statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)          epoch_loss = running_loss / len(dataloaders[phase].dataset)          epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)  print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))          f2.write('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))          f2.write('\n')          f2.flush()         # deep copy the model          if phase == 'val':if (epoch+1)%50==0:  #print('Saving model......')  torch.save(model.state_dict(), '%s/inception_%03d.pth' % (args.outf, epoch + 1))f1.write("EPOCH=%03d,Accuracy= %.3f%%" % (epoch + 1, epoch_acc))f1.write('\n')f1.flush()          if phase == 'val' and epoch_acc > best_acc:f3 = open("best_acc.txt", "w")f3.write("EPOCH=%d,best_acc= %.3f%%" % (epoch + 1,epoch_acc))f3.close()best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())          if phase == 'val':val_acc_history.append(epoch_acc)  time_elapsed = time.time() - since  print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))  print('Best val Acc: {:4f}'.format(best_acc))  # load best model weights  model.load_state_dict(best_model_wts)  return model, val_acc_history #是否更新参数def set_parameter_requires_grad(model, feature_extracting):  if feature_extracting:    for param in model.parameters():      param.requires_grad = Falsedef initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):  # Initialize these variables which will be set in this if statement. Each of these  #  variables is model specific.  model_ft = None  input_size = 0  if model_name == "resnet":    """ Resnet18    """    model_ft = models.resnet18(pretrained=use_pretrained)    set_parameter_requires_grad(model_ft, feature_extract)    num_ftrs = model_ft.fc.in_features    model_ft.fc = nn.Linear(num_ftrs, num_classes)    input_size = 224  elif model_name == "alexnet":    """ Alexnet    """    model_ft = models.alexnet(pretrained=use_pretrained)    set_parameter_requires_grad(model_ft, feature_extract)    num_ftrs = model_ft.classifier[6].in_features    model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)    input_size = 224  elif model_name == "vgg":    """ VGG11_bn    """    model_ft = models.vgg11_bn(pretrained=use_pretrained)    set_parameter_requires_grad(model_ft, feature_extract)    num_ftrs = model_ft.classifier[6].in_features    model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)    input_size = 224  elif model_name == "squeezenet":    """ Squeezenet    """    model_ft = models.squeezenet1_0(pretrained=use_pretrained)    set_parameter_requires_grad(model_ft, feature_extract)    model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))    model_ft.num_classes = num_classes    input_size = 224  elif model_name == "densenet":    """ Densenet    """    model_ft = models.densenet121(pretrained=use_pretrained)    set_parameter_requires_grad(model_ft, feature_extract)    num_ftrs = model_ft.classifier.in_features    model_ft.classifier = nn.Linear(num_ftrs, num_classes)     input_size = 224  elif model_name == "inception":    """ Inception v3     Be careful, expects (299,299) sized images and has auxiliary output    """    model_ft = models.inception_v3(pretrained=use_pretrained)    set_parameter_requires_grad(model_ft, feature_extract)    # Handle the auxilary net    num_ftrs = model_ft.AuxLogits.fc.in_features    model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)    # Handle the primary net    num_ftrs = model_ft.fc.in_features    model_ft.fc = nn.Linear(num_ftrs,num_classes)    input_size = 299  else:    print("Invalid model name, exiting...")    exit()    return model_ft, input_size# Initialize the model for this runmodel_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)# Print the model we just instantiated#print(model_ft) #准备数据data_transforms = {  'train': transforms.Compose([    transforms.RandomResizedCrop(input_size),    transforms.RandomHorizontalFlip(),    transforms.ToTensor(),    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  ]),  'val': transforms.Compose([    transforms.Resize(input_size),    transforms.CenterCrop(input_size),    transforms.ToTensor(),    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  ]),}print("Initializing Datasets and Dataloaders...")# Create training and validation datasetsimage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}# Create training and validation dataloadersdataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=0) for x in ['train', 'val']}# Detect if we have a GPU availabledevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")'''是否加载之前训练过的模型we='/home/dell/Desktop/dj/inception_050.pth'model_ft.load_state_dict(torch.load(we))'''# Send the model to GPUmodel_ft = model_ft.to(device)params_to_update = model_ft.parameters()print("Params to learn:")if feature_extract:  params_to_update = []  for name,param in model_ft.named_parameters():    if param.requires_grad == True:      params_to_update.append(param)      print("\t",name)else:  for name,param in model_ft.named_parameters():    if param.requires_grad == True:      print("\t",name)# Observe that all parameters are being optimizedoptimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs#exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.95)# Setup the loss fxncriterion = nn.CrossEntropyLoss()# Train and evaluatemodel_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"))'''#随机初始化时的训练程序# Initialize the non-pretrained version of the model used for this runscratch_model,_ = initialize_model(model_name, num_classes, feature_extract=False, use_pretrained=False)scratch_model = scratch_model.to(device)scratch_optimizer = optim.SGD(scratch_model.parameters(), lr=0.001, momentum=0.9)scratch_criterion = nn.CrossEntropyLoss()_,scratch_hist = train_model(scratch_model, dataloaders_dict, scratch_criterion, scratch_optimizer, num_epochs=num_epochs, is_inception=(model_name=="inception"))# Plot the training curves of validation accuracy vs. number # of training epochs for the transfer learning method and# the model trained from scratchohist = []shist = []ohist = [h.cpu().numpy() for h in hist]shist = [h.cpu().numpy() for h in scratch_hist]plt.title("Validation Accuracy vs. Number of Training Epochs")plt.xlabel("Training Epochs")plt.ylabel("Validation Accuracy")plt.plot(range(1,num_epochs+1),ohist,label="Pretrained")plt.plot(range(1,num_epochs+1),shist,label="Scratch")plt.ylim((0,1.))plt.xticks(np.arange(1, num_epochs+1, 1.0))plt.legend()plt.show()'''

以上这篇pytorch之inception_v3的实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


  • 上一条:
    pytorch之ImageFolder使用详解
    下一条:
    pytorch之添加BN的实现
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • 在python语言中Flask框架的学习及简单功能示例(0个评论)
    • 在Python语言中实现GUI全屏倒计时代码示例(0个评论)
    • Python + zipfile库实现zip文件解压自动化脚本示例(0个评论)
    • python爬虫BeautifulSoup快速抓取网站图片(1个评论)
    • vscode 配置 python3开发环境的方法(0个评论)
    • 近期文章
    • 智能合约Solidity学习CryptoZombie第三课:组建僵尸军队(高级Solidity理论)(0个评论)
    • 智能合约Solidity学习CryptoZombie第二课:让你的僵尸猎食(0个评论)
    • 智能合约Solidity学习CryptoZombie第一课:生成一只你的僵尸(0个评论)
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2018-04
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2022-01
    • 2023-07
    • 2023-10
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客