侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

PyTorch搭建一维线性回归模型(二)

Python  /  管理员 发布于 5年前   473

PyTorch基础入门二:PyTorch搭建一维线性回归模型

1)一维线性回归模型的理论基础

给定数据集,线性回归希望能够优化出一个好的函数,使得能够和尽可能接近。

如何才能学习到参数和呢?很简单,只需要确定如何衡量与之间的差别,我们一般通过损失函数(Loss Funciton)来衡量:。取平方是因为距离有正有负,我们于是将它们变为全是正的。这就是著名的均方误差。我们要做的事情就是希望能够找到和,使得:

均方差误差非常直观,也有着很好的几何意义,对应了常用的欧式距离。现在要求解这个连续函数的最小值,我们很自然想到的方法就是求它的偏导数,让它的偏导数等于0来估计它的参数,即:

求解以上两式,我们就可以得到最优解。

2)代码实现

首先,我们需要“制造”出一些数据集:

import torchimport matplotlib.pyplot as plt  x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)y = 3*x + 10 + torch.rand(x.size())# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音 # 画图plt.scatter(x.data.numpy(), y.data.numpy())plt.show()

我们想要拟合的一维回归模型是。上面制造的数据集也是比较接近这个模型的,但是为了达到学习效果,人为地加上了torch.rand()值增加一些干扰。

上面人为制造出来的数据集的分布如下:

有了数据,我们就要开始定义我们的模型,这里定义的是一个输入层和输出层都只有一维的模型,并且使用了“先判断后使用”的基本结构来合理使用GPU加速。

class LinearRegression(nn.Module):  def __init__(self):    super(LinearRegression, self).__init__()    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1  def forward(self, x):    out = self.linear(x)    return out if torch.cuda.is_available():  model = LinearRegression().cuda()else:  model = LinearRegression()

然后我们定义出损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降进行优化:

criterion = nn.MSELoss()optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

接下来,开始进行模型的训练。

num_epochs = 1000for epoch in range(num_epochs):  if torch.cuda.is_available():    inputs = Variable(x).cuda()    target = Variable(y).cuda()  else:    inputs = Variable(x)    target = Variable(y)   # 向前传播  out = model(inputs)  loss = criterion(out, target)   # 向后传播  optimizer.zero_grad() # 注意每次迭代都需要清零  loss.backward()  optimizer.step()   if (epoch+1) %20 == 0:    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))

首先定义了迭代的次数,这里为1000次,先向前传播计算出损失函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要记得将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于看到结果,每隔一段时间输出当前的迭代轮数和损失函数。

接下来,我们通过model.eval()函数将模型变为测试模式,然后将数据放入模型中进行预测。最后,通过画图工具matplotlib看一下我们拟合的结果,代码如下:

model.eval()if torch.cuda.is_available():  predict = model(Variable(x).cuda())  predict = predict.data.cpu().numpy()else:  predict = model(Variable(x))  predict = predict.data.numpy()plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')plt.plot(x.numpy(), predict, label='Fitting Line')plt.show()

其拟合结果如下图:

附上完整代码:

# !/usr/bin/python# coding: utf8# @Time  : 2018-07-28 18:40# @Author : Liam# @Email  : [email protected]# @Software: PyCharm#.::::.#           .::::::::.#           :::::::::::#         ..:::::::::::'#        '::::::::::::'#         .::::::::::#      '::::::::::::::..#         ..::::::::::::.#        ``::::::::::::::::#        ::::``:::::::::'    .:::.#        ::::'  ':::::'    .::::::::.#       .::::'   ::::   .:::::::'::::.#      .:::'    ::::: .:::::::::' ':::::.#      .::'    :::::.:::::::::'   ':::::.#     .::'     ::::::::::::::'     ``::::.#   ...:::      ::::::::::::'       ``::.#   ```` ':.     ':::::::::'         ::::..#'.:::::'          ':'````..#           美女保佑 永无BUG import torchfrom torch.autograd import Variableimport numpy as npimport randomimport matplotlib.pyplot as pltfrom torch import nn  x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)y = 3*x + 10 + torch.rand(x.size())# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音 # 画图# plt.scatter(x.data.numpy(), y.data.numpy())# plt.show()class LinearRegression(nn.Module):  def __init__(self):    super(LinearRegression, self).__init__()    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1  def forward(self, x):    out = self.linear(x)    return out if torch.cuda.is_available():  model = LinearRegression().cuda()else:  model = LinearRegression() criterion = nn.MSELoss()optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) num_epochs = 1000for epoch in range(num_epochs):  if torch.cuda.is_available():    inputs = Variable(x).cuda()    target = Variable(y).cuda()  else:    inputs = Variable(x)    target = Variable(y)   # 向前传播  out = model(inputs)  loss = criterion(out, target)   # 向后传播  optimizer.zero_grad() # 注意每次迭代都需要清零  loss.backward()  optimizer.step()   if (epoch+1) %20 == 0:    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))model.eval()if torch.cuda.is_available():  predict = model(Variable(x).cuda())  predict = predict.data.cpu().numpy()else:  predict = model(Variable(x))  predict = predict.data.numpy()plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')plt.plot(x.numpy(), predict, label='Fitting Line')plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。


  • 上一条:
    PyQt5的PyQtGraph实践系列3之实时数据更新绘制图形
    下一条:
    PyTorch基本数据类型(一)
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • 在python语言中Flask框架的学习及简单功能示例(0个评论)
    • 在Python语言中实现GUI全屏倒计时代码示例(0个评论)
    • Python + zipfile库实现zip文件解压自动化脚本示例(0个评论)
    • python爬虫BeautifulSoup快速抓取网站图片(1个评论)
    • vscode 配置 python3开发环境的方法(0个评论)
    • 近期文章
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • Laravel从Accel获得5700万美元A轮融资(0个评论)
    • 在go + gin中gorm实现指定搜索/区间搜索分页列表功能接口实例(0个评论)
    • 在go语言中实现IP/CIDR的ip和netmask互转及IP段形式互转及ip是否存在IP/CIDR(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2018-04
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2022-01
    • 2023-07
    • 2023-10
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客