侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

使用pytorch实现可视化中间层的结果

Python  /  管理员 发布于 5年前   465

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

下面是完整的代码

import cv2import numpy as npimport torchfrom torch.autograd import Variablefrom torchvision import models#创建30个文件夹def mkdir(path): # 判断是否存在指定文件夹,不存在则创建  # 引入模块  import os  # 去除首位空格  path = path.strip()  # 去除尾部 \ 符号  path = path.rstrip("\\")  # 判断路径是否存在  # 存在   True  # 不存在  False  isExists = os.path.exists(path)  # 判断结果  if not isExists:    # 如果不存在则创建目录    # 创建目录操作函数    os.makedirs(path)    return True  else:    return Falsedef preprocess_image(cv2im, resize_im=True):  """    Processes image for CNNs  Args:    PIL_img (PIL_img): Image to process    resize_im (bool): Resize to 224 or not  returns:    im_as_var (Pytorch variable): Variable that contains processed float tensor  """  # mean and std list for channels (Imagenet)  mean = [0.485, 0.456, 0.406]  std = [0.229, 0.224, 0.225]  # Resize image  if resize_im:    cv2im = cv2.resize(cv2im, (224, 224))  im_as_arr = np.float32(cv2im)  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H  # Normalize the channels  for channel, _ in enumerate(im_as_arr):    im_as_arr[channel] /= 255    im_as_arr[channel] -= mean[channel]    im_as_arr[channel] /= std[channel]  # Convert to float tensor  im_as_ten = torch.from_numpy(im_as_arr).float()  # Add one more channel to the beginning. Tensor shape = 1,3,224,224  im_as_ten.unsqueeze_(0)  # Convert to Pytorch variable  im_as_var = Variable(im_as_ten, requires_grad=True)  return im_as_varclass FeatureVisualization():  def __init__(self,img_path,selected_layer):    self.img_path=img_path    self.selected_layer=selected_layer    self.pretrained_model = models.vgg16(pretrained=True).features    #print( self.pretrained_model)  def process_image(self):    img=cv2.imread(self.img_path)    img=preprocess_image(img)    return img  def get_feature(self):    # input = Variable(torch.randn(1, 3, 224, 224))    input=self.process_image()    print("input shape",input.shape)    x=input    for index,layer in enumerate(self.pretrained_model):      #print(index)      #print(layer)      x=layer(x)      if (index == self.selected_layer):        return x  def get_single_feature(self):    features=self.get_feature()    print("features.shape",features.shape)    feature=features[:,0,:,:]    print(feature.shape)    feature=feature.view(feature.shape[1],feature.shape[2])    print(feature.shape)    return features  def save_feature_to_img(self):    #to numpy    features=self.get_single_feature()    for i in range(features.shape[1]):      feature = features[:, i, :, :]      feature = feature.view(feature.shape[1], feature.shape[2])      feature = feature.data.numpy()      # use sigmod to [0,1]      feature = 1.0 / (1 + np.exp(-1 * feature))      # to [0,255]      feature = np.round(feature * 255)      print(feature[0])      mkdir('./feature/' + str(self.selected_layer))      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)if __name__=='__main__':  # get class  for k in range(30):    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)    print (myClass.pretrained_model)    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


  • 上一条:
    pytorch 批次遍历数据集打印数据的例子
    下一条:
    在Pytorch中计算自己模型的FLOPs方式
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • 在python语言中Flask框架的学习及简单功能示例(0个评论)
    • 在Python语言中实现GUI全屏倒计时代码示例(0个评论)
    • Python + zipfile库实现zip文件解压自动化脚本示例(0个评论)
    • python爬虫BeautifulSoup快速抓取网站图片(1个评论)
    • vscode 配置 python3开发环境的方法(0个评论)
    • 近期文章
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • Laravel从Accel获得5700万美元A轮融资(0个评论)
    • 在go + gin中gorm实现指定搜索/区间搜索分页列表功能接口实例(0个评论)
    • 在go语言中实现IP/CIDR的ip和netmask互转及IP段形式互转及ip是否存在IP/CIDR(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2018-04
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2022-01
    • 2023-07
    • 2023-10
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客