侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

使用tensorflow实现AlexNet

技术  /  管理员 发布于 7年前   180

AlexNet是2012年ImageNet比赛的冠军,虽然过去了很长时间,但是作为深度学习中的经典模型,AlexNet不但有助于我们理解其中所使用的很多技巧,而且非常有助于提升我们使用深度学习工具箱的熟练度。尤其是我刚入门深度学习,迫切需要一个能让自己熟悉tensorflow的小练习,于是就有了这个小玩意儿......

先放上我的代码:https://github.com/hjptriplebee/AlexNet_with_tensorflow

如果想运行代码,详细的配置要求都在上面链接的readme文件中了。本文建立在一定的tensorflow基础上,不会对太细的点进行说明。

模型结构

关于模型结构网上的文献很多,我这里不赘述,一会儿都在代码里解释。

有一点需要注意,AlexNet将网络分成了上下两个部分,在论文中两部分结构完全相同,唯一不同的是他们放在不同GPU上训练,因为每一层的feature map之间都是独立的(除了全连接层),所以这相当于是提升训练速度的一种方法。很多AlexNet的复现都将上下两部分合并了,因为他们都是在单个GPU上运行的。虽然我也是在单个GPU上运行,但是我还是很想将最原始的网络结构还原出来,所以我的代码里也是分开的。

模型定义

def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding = "SAME"):   """max-pooling"""   return tf.nn.max_pool(x, ksize = [1, kHeight, kWidth, 1],  strides = [1, strideX, strideY, 1], padding = padding, name = name)  def dropout(x, keepPro, name = None):   """dropout"""   return tf.nn.dropout(x, keepPro, name)  def LRN(x, R, alpha, beta, name = None, bias = 1.0):   """LRN"""   return tf.nn.local_response_normalization(x, depth_radius = R, alpha = alpha,beta = beta, bias = bias, name = name)  def fcLayer(x, inputD, outputD, reluFlag, name):   """fully-connect"""   with tf.variable_scope(name) as scope:     w = tf.get_variable("w", shape = [inputD, outputD], dtype = "float")     b = tf.get_variable("b", [outputD], dtype = "float")     out = tf.nn.xw_plus_b(x, w, b, name = scope.name)     if reluFlag:       return tf.nn.relu(out)     else:       return out  def convLayer(x, kHeight, kWidth, strideX, strideY,        featureNum, name, padding = "SAME", groups = 1):#group为2时等于AlexNet中分上下两部分   """convlutional"""   channel = int(x.get_shape()[-1])#获取channel   conv = lambda a, b: tf.nn.conv2d(a, b, strides = [1, strideY, strideX, 1], padding = padding)#定义卷积的匿名函数   with tf.variable_scope(name) as scope:     w = tf.get_variable("w", shape = [kHeight, kWidth, channel/groups, featureNum])     b = tf.get_variable("b", shape = [featureNum])      xNew = tf.split(value = x, num_or_size_splits = groups, axis = 3)#划分后的输入和权重     wNew = tf.split(value = w, num_or_size_splits = groups, axis = 3)      featureMap = [conv(t1, t2) for t1, t2 in zip(xNew, wNew)] #分别提取feature map     mergeFeatureMap = tf.concat(axis = 3, values = featureMap) #feature map整合     # print mergeFeatureMap.shape     out = tf.nn.bias_add(mergeFeatureMap, b)     return tf.nn.relu(tf.reshape(out, mergeFeatureMap.get_shape().as_list()), name = scope.name) #relu后的结果

定义了卷积、pooling、LRN、dropout、全连接五个模块,其中卷积模块因为将网络的上下两部分分开了,所以比较复杂。接下来定义AlexNet。

class alexNet(object):   """alexNet model"""   def __init__(self, x, keepPro, classNum, skip, modelPath = "bvlc_alexnet.npy"):     self.X = x     self.KEEPPRO = keepPro     self.CLASSNUM = classNum     self.SKIP = skip     self.MODELPATH = modelPath     #build CNN     self.buildCNN()    def buildCNN(self):     """build model"""     conv1 = convLayer(self.X, 11, 11, 4, 4, 96, "conv1", "VALID")     pool1 = maxPoolLayer(conv1, 3, 3, 2, 2, "pool1", "VALID")     lrn1 = LRN(pool1, 2, 2e-05, 0.75, "norm1")      conv2 = convLayer(lrn1, 5, 5, 1, 1, 256, "conv2", groups = 2)     pool2 = maxPoolLayer(conv2, 3, 3, 2, 2, "pool2", "VALID")     lrn2 = LRN(pool2, 2, 2e-05, 0.75, "lrn2")      conv3 = convLayer(lrn2, 3, 3, 1, 1, 384, "conv3")      conv4 = convLayer(conv3, 3, 3, 1, 1, 384, "conv4", groups = 2)      conv5 = convLayer(conv4, 3, 3, 1, 1, 256, "conv5", groups = 2)     pool5 = maxPoolLayer(conv5, 3, 3, 2, 2, "pool5", "VALID")      fcIn = tf.reshape(pool5, [-1, 256 * 6 * 6])     fc1 = fcLayer(fcIn, 256 * 6 * 6, 4096, True, "fc6")     dropout1 = dropout(fc1, self.KEEPPRO)      fc2 = fcLayer(dropout1, 4096, 4096, True, "fc7")     dropout2 = dropout(fc2, self.KEEPPRO)      self.fc3 = fcLayer(dropout2, 4096, self.CLASSNUM, True, "fc8")    def loadModel(self, sess):     """load model"""     wDict = np.load(self.MODELPATH, encoding = "bytes").item()     #for layers in model     for name in wDict:       if name not in self.SKIP:         with tf.variable_scope(name, reuse = True):           for p in wDict[name]: if len(p.shape) == 1:    #bias 只有一维   sess.run(tf.get_variable('b', trainable = False).assign(p)) else:   #weights    sess.run(tf.get_variable('w', trainable = False).assign(p)) 

buildCNN函数完全按照alexnet的结构搭建网络。
loadModel函数从模型文件中读取参数,采用的模型文件见github上的readme说明。
至此,我们定义了完整的模型,下面开始测试模型。

模型测试

ImageNet训练的AlexNet有很多类,几乎包含所有常见的物体,因此我们随便从网上找几张图片测试。比如我直接用了之前做项目的渣土车图片:

然后编写测试代码:

#some params dropoutPro = 1 classNum = 1000 skip = [] #get testImage testPath = "testModel" testImg = [] for f in os.listdir(testPath):   testImg.append(cv2.imread(testPath + "/" + f))  imgMean = np.array([104, 117, 124], np.float) x = tf.placeholder("float", [1, 227, 227, 3])  model = alexnet.alexNet(x, dropoutPro, classNum, skip) score = model.fc3 softmax = tf.nn.softmax(score)  with tf.Session() as sess:   sess.run(tf.global_variables_initializer())   model.loadModel(sess) #加载模型    for i, img in enumerate(testImg):     #img preprocess     test = cv2.resize(img.astype(np.float), (227, 227)) #resize成网络输入大小     test -= imgMean #去均值     test = test.reshape((1, 227, 227, 3)) #拉成tensor     maxx = np.argmax(sess.run(softmax, feed_dict = {x: test}))     res = caffe_classes.class_names[maxx] #取概率最大类的下标     #print(res)     font = cv2.FONT_HERSHEY_SIMPLEX     cv2.putText(img, res, (int(img.shape[0]/3), int(img.shape[1]/3)), font, 1, (0, 255, 0), 2)#绘制类的名字     cv2.imshow("demo", img)      cv2.waitKey(5000) #显示5秒 

如上代码所示,首先需要设置一些参数,然后读取指定路径下的测试图像,再对模型做一个初始化,最后是真正测试代码。测试结果如下:

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。


  • 上一条:
    详解supervisor使用教程
    下一条:
    shell脚本实现的网站日志分析统计(可以统计9种数据)
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 2024.07.09日OpenAI将终止对中国等国家和地区API服务(0个评论)
    • 2024/6/9最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 国外服务器实现api.openai.com反代nginx配置(0个评论)
    • 2024/4/28最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 近期文章
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • Laravel从Accel获得5700万美元A轮融资(0个评论)
    • 在go + gin中gorm实现指定搜索/区间搜索分页列表功能接口实例(0个评论)
    • 在go语言中实现IP/CIDR的ip和netmask互转及IP段形式互转及ip是否存在IP/CIDR(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2017-07
    • 2017-08
    • 2017-09
    • 2018-01
    • 2018-07
    • 2018-08
    • 2018-09
    • 2018-12
    • 2019-01
    • 2019-02
    • 2019-03
    • 2019-04
    • 2019-05
    • 2019-06
    • 2019-07
    • 2019-08
    • 2019-09
    • 2019-10
    • 2019-11
    • 2019-12
    • 2020-01
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2020-07
    • 2020-08
    • 2020-09
    • 2020-10
    • 2020-11
    • 2021-04
    • 2021-05
    • 2021-06
    • 2021-07
    • 2021-08
    • 2021-09
    • 2021-10
    • 2021-12
    • 2022-01
    • 2022-02
    • 2022-03
    • 2022-04
    • 2022-05
    • 2022-06
    • 2022-07
    • 2022-08
    • 2022-09
    • 2022-10
    • 2022-11
    • 2022-12
    • 2023-01
    • 2023-02
    • 2023-03
    • 2023-04
    • 2023-05
    • 2023-06
    • 2023-07
    • 2023-08
    • 2023-09
    • 2023-10
    • 2023-12
    • 2024-02
    • 2024-04
    • 2024-05
    • 2024-06
    • 2025-02
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客