侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

Tensorflow分类器项目自定义数据读入的实现

技术  /  管理员 发布于 7年前   160

在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

import osimport kerasimport matplotlib.pyplot as pltfrom PIL import Imagefrom keras.preprocessing.image import ImageDataGeneratorfrom sklearn.model_selection import train_test_split

图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

IMG_SIZE_X = 30
IMG_SIZE_Y = 30

其次确定你图片的方式目录:

image_path = r'D:\Projects\ImageClassifier\data\set'path = ".\data"# 你也可以使用相对路径的方式# image_path =os.path.join(path, "set")

目录下的结构如下:

相应的label.txt如下:

动漫
风景
美女
物语
樱花

接下来是接在labels.txt,如下:

label_name = "labels.txt"label_path = os.path.join(path, label_name)class_names = np.loadtxt(label_path, type(""))

这里简便起见,直接利用了numpy的loadtxt函数直接加载。

之后便是正式处理图片数据了,注释就写在里面了:

re_load = Falsere_build = False# re_load = Truere_build = Truedata_name = "data.npz"data_path = os.path.join(path, data_name)model_name = "model.h5"model_path = os.path.join(path, model_name)count = 0# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。if not os.path.exists(data_path) or re_load:  labels = []  images = []  print('Handle images')  # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取  for index, name in enumerate(class_names):    # 这里是拼接后的子目录path    classpath = os.path.join(image_path, name)    # 先判断一下是否是目录    if not os.path.isdir(classpath):      continue    # limit是测试时候用的这里可以去除    limit = 0    for image_name in os.listdir(classpath):      if limit >= max_size:        break      # 这里是拼接后的待处理的图片path      imagepath = os.path.join(classpath, image_name)      count = count + 1      limit = limit + 1      # 利用Image打开图片      img = Image.open(imagepath)      # 缩放到你最初确定要处理的图片分辨率大小      img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))      # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量      img = img.convert("L")      # 转为numpy数组      img = np.array(img)      # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放      img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))      # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引      labels.append([index])      # 添加到images中,最后统一处理      images.append(img)      # 循环中一些状态的输出,可以去除      print("{} class: {} {} limit: {} {}"         .format(count, index + 1, class_names[index], limit, imagepath))  # 最后一次性将images和labels都转换成numpy数组  npy_data = np.array(images)  npy_labels = np.array(labels)  # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储  np.savez(data_path, x=npy_data, y=npy_labels)  print("Save images by npz")else:  # 如果存在序列化号的数据,便直接读取,提高速度  npy_data = np.load(data_path)["x"]  npy_labels = np.load(data_path)["y"]  print("Load images by npz")image_data = npy_datalabels_data = npy_labels

到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

# 最后一步就是将原始数据分成训练数据和测试数据train_images, test_images, train_labels, test_labels = \  train_test_split(image_data, labels_data, test_size=0.2, random_state=6)

这里将相关信息打印的方法也附上:

print("_________________________________________________________________")print("%-28s %-s" % ("Name", "Shape"))print("=================================================================")print("%-28s %-s" % ("Image Data", image_data.shape))print("%-28s %-s" % ("Labels Data", labels_data.shape))print("=================================================================")print('Split train and test data,p=%')print("_________________________________________________________________")print("%-28s %-s" % ("Name", "Shape"))print("=================================================================")print("%-28s %-s" % ("Train Images", train_images.shape))print("%-28s %-s" % ("Test Images", test_images.shape))print("%-28s %-s" % ("Train Labels", train_labels.shape))print("%-28s %-s" % ("Test Labels", test_labels.shape))print("=================================================================")

之后别忘了归一化哟:

print("Normalize images")train_images = train_images / 255.0test_images = test_images / 255.0

最后附上读取自定义数据的完整代码:

import osimport kerasimport matplotlib.pyplot as pltfrom PIL import Imagefrom keras.layers import *from keras.models import *from keras.optimizers import Adamfrom keras.preprocessing.image import ImageDataGeneratorfrom sklearn.model_selection import train_test_splitos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'# 支持中文plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号re_load = Falsere_build = False# re_load = Truere_build = Trueepochs = 50batch_size = 5count = 0max_size = 2000000000

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。


  • 上一条:
    使用TensorFlow实现二分类的方法示例
    下一条:
    将string类型的数据类型转换为spark rdd时报错的解决方法
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 2024.07.09日OpenAI将终止对中国等国家和地区API服务(0个评论)
    • 2024/6/9最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 国外服务器实现api.openai.com反代nginx配置(0个评论)
    • 2024/4/28最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 近期文章
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • Laravel从Accel获得5700万美元A轮融资(0个评论)
    • 在go + gin中gorm实现指定搜索/区间搜索分页列表功能接口实例(0个评论)
    • 在go语言中实现IP/CIDR的ip和netmask互转及IP段形式互转及ip是否存在IP/CIDR(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2017-07
    • 2017-08
    • 2017-09
    • 2018-01
    • 2018-07
    • 2018-08
    • 2018-09
    • 2018-12
    • 2019-01
    • 2019-02
    • 2019-03
    • 2019-04
    • 2019-05
    • 2019-06
    • 2019-07
    • 2019-08
    • 2019-09
    • 2019-10
    • 2019-11
    • 2019-12
    • 2020-01
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2020-07
    • 2020-08
    • 2020-09
    • 2020-10
    • 2020-11
    • 2021-04
    • 2021-05
    • 2021-06
    • 2021-07
    • 2021-08
    • 2021-09
    • 2021-10
    • 2021-12
    • 2022-01
    • 2022-02
    • 2022-03
    • 2022-04
    • 2022-05
    • 2022-06
    • 2022-07
    • 2022-08
    • 2022-09
    • 2022-10
    • 2022-11
    • 2022-12
    • 2023-01
    • 2023-02
    • 2023-03
    • 2023-04
    • 2023-05
    • 2023-06
    • 2023-07
    • 2023-08
    • 2023-09
    • 2023-10
    • 2023-12
    • 2024-02
    • 2024-04
    • 2024-05
    • 2024-06
    • 2025-02
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客