侯体宗的博客
  • 首页
  • Hyperf版
  • beego仿版
  • 人生(杂谈)
  • 技术
  • 关于我
  • 更多分类
    • 文件下载
    • 文字修仙
    • 中国象棋ai
    • 群聊
    • 九宫格抽奖
    • 拼图
    • 消消乐
    • 相册

使用TensorFlow-Slim进行图像分类的实现

技术  /  管理员 发布于 7年前   419

参考 https://github.com/tensorflow/models/tree/master/slim

使用TensorFlow-Slim进行图像分类

准备

安装TensorFlow

参考 https://www.tensorflow.org/install/

如在Ubuntu下安装TensorFlow with GPU support, python 2.7版本

wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whlpip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl

下载TF-slim图像模型库

cd $WORKSPACEgit clone https://github.com/tensorflow/models/

准备数据

有不少公开数据集,这里以官网提供的Flowers为例。

官网提供了下载和转换数据的代码,为了理解代码并能使用自己的数据,这里参考官方提供的代码进行修改。

cd $WORKSPACE/datawget http://download.tensorflow.org/example_images/flower_photos.tgztar zxf flower_photos.tgz

数据集文件夹结构如下:

flower_photos├── daisy│  ├── 100080576_f52e8ee070_n.jpg│  └── ...├── dandelion├── LICENSE.txt├── roses├── sunflowers└── tulips

由于实际情况中我们自己的数据集并不一定把图片按类别放在不同的文件夹里,故我们生成list.txt来表示图片路径与标签的关系。

Python代码:

import osclass_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}data_dir = 'flower_photos/'output_path = 'list.txt'fd = open(output_path, 'w')for class_name in class_names_to_ids.keys():  images_list = os.listdir(data_dir + class_name)  for image_name in images_list:    fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))fd.close()

为了方便后期查看label标签,也可以定义labels.txt:

daisydandelionrosessunflowerstulips

随机生成训练集与验证集:

Python代码:

import random_NUM_VALIDATION = 350_RANDOM_SEED = 0list_path = 'list.txt'train_list_path = 'list_train.txt'val_list_path = 'list_val.txt'fd = open(list_path)lines = fd.readlines()fd.close()random.seed(_RANDOM_SEED)random.shuffle(lines)fd = open(train_list_path, 'w')for line in lines[_NUM_VALIDATION:]:  fd.write(line)fd.close()fd = open(val_list_path, 'w')for line in lines[:_NUM_VALIDATION]:  fd.write(line)fd.close()

生成TFRecord数据:

Python代码:

import syssys.path.insert(0, '../models/slim/')from datasets import dataset_utilsimport mathimport osimport tensorflow as tfdef convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):  fd = open(list_path)  lines = [line.split() for line in fd]  fd.close()  num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))  with tf.Graph().as_default():    decode_jpeg_data = tf.placeholder(dtype=tf.string)    decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)    with tf.Session('') as sess:      for shard_id in range(_NUM_SHARDS):        output_path = os.path.join(output_dir,          'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))        tfrecord_writer = tf.python_io.TFRecordWriter(output_path)        start_ndx = shard_id * num_per_shard        end_ndx = min((shard_id + 1) * num_per_shard, len(lines))        for i in range(start_ndx, end_ndx):          sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(i + 1, len(lines), shard_id))          sys.stdout.flush()          image_data = tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb').read()          image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})          height, width = image.shape[0], image.shape[1]          example = dataset_utils.image_to_tfexample(image_data, b'jpg', height, width, int(lines[i][1]))          tfrecord_writer.write(example.SerializeToString())        tfrecord_writer.close()  sys.stdout.write('\n')  sys.stdout.flush()os.system('mkdir -p train')convert_dataset('list_train.txt', 'flower_photos', 'train/')os.system('mkdir -p val')convert_dataset('list_val.txt', 'flower_photos', 'val/')

得到的文件夹结构如下:

data├── flower_photos├── labels.txt├── list_train.txt├── list.txt├── list_val.txt├── train│  ├── data_00000-of-00005.tfrecord│  ├── ...│  └── data_00004-of-00005.tfrecord└── val  ├── data_00000-of-00005.tfrecord  ├── ...  └── data_00004-of-00005.tfrecord

(可选)下载模型

官方提供了不少预训练模型,这里以Inception-ResNet-v2以例。

cd $WORKSPACE/checkpointswget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gztar zxf inception_resnet_v2_2016_08_30.tar.gz

训练

读入数据

官方提供了读入Flowers数据集的代码models/slim/datasets/flowers.py,同样这里也是参考并修改成能读入上面定义的通用数据集。

把下面代码写入models/slim/datasets/dataset_classification.py。

import osimport tensorflow as tfslim = tf.contrib.slimdef get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern='*.tfrecord'):  file_pattern = os.path.join(dataset_dir, file_pattern)  keys_to_features = {    'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),    'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),    'image/class/label': tf.FixedLenFeature(      [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),  }  items_to_handlers = {    'image': slim.tfexample_decoder.Image(),    'label': slim.tfexample_decoder.Tensor('image/class/label'),  }  decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)  items_to_descriptions = {    'image': 'A color image of varying size.',    'label': 'A single integer between 0 and ' + str(num_classes - 1),  }  labels_to_names = None  if labels_to_names_path is not None:    fd = open(labels_to_names_path)    labels_to_names = {i : line.strip() for i, line in enumerate(fd)}    fd.close()  return slim.dataset.Dataset(      data_sources=file_pattern,      reader=tf.TFRecordReader,      decoder=decoder,      num_samples=num_samples,      items_to_descriptions=items_to_descriptions,      num_classes=num_classes,      labels_to_names=labels_to_names)

构建模型

官方提供了许多模型在models/slim/nets/。

如需要自定义模型,则参考官方提供的模型并放在对应的文件夹即可。

开始训练

官方提供了训练脚本,如果使用官方的数据读入和处理,可使用以下方式开始训练。

cd $WORKSPACE/models/slimCUDA_VISIBLE_DEVICES="0" python train_image_classifier.py \  --train_dir=train_logs \  --dataset_name=flowers \  --dataset_split_name=train \  --dataset_dir=../../data/flowers \  --model_name=inception_resnet_v2 \  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \  --max_number_of_steps=1000 \  --batch_size=32 \  --learning_rate=0.01 \  --learning_rate_decay_type=fixed \  --save_interval_secs=60 \  --save_summaries_secs=60 \  --log_every_n_steps=10 \  --optimizer=rmsprop \  --weight_decay=0.00004

不fine-tune把--checkpoint_path, --checkpoint_exclude_scopes和--trainable_scopes删掉。

fine-tune所有层把--checkpoint_exclude_scopes和--trainable_scopes删掉。

如果只使用CPU则加上--clone_on_cpu=True。

其它参数可删掉用默认值或自行修改。

使用自己的数据则需要修改models/slim/train_image_classifier.py:

把

from datasets import dataset_factory

修改为

from datasets import dataset_classification

把

dataset = dataset_factory.get_dataset(  FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

修改为

dataset = dataset_classification.get_dataset(  FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

在

tf.app.flags.DEFINE_string(  'dataset_dir', None, 'The directory where the dataset files are stored.')

后加入

tf.app.flags.DEFINE_integer(  'num_samples', 3320, 'Number of samples.')tf.app.flags.DEFINE_integer(  'num_classes', 5, 'Number of classes.')tf.app.flags.DEFINE_string(  'labels_to_names_path', None, 'Label names file path.')

训练时执行以下命令即可:

cd $WORKSPACE/models/slimpython train_image_classifier.py \  --train_dir=train_logs \  --dataset_dir=../../data/train \  --num_samples=3320 \  --num_classes=5 \  --labels_to_names_path=../../data/labels.txt \  --model_name=inception_resnet_v2 \  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits

可视化log

可一边训练一边可视化训练的log,可看到Loss趋势。

tensorboard --logdir train_logs/

验证

官方提供了验证脚本。

python eval_image_classifier.py \  --checkpoint_path=train_logs \  --eval_dir=eval_logs \  --dataset_name=flowers \  --dataset_split_name=validation \  --dataset_dir=../../data/flowers \  --model_name=inception_resnet_v2

同样,如果是使用自己的数据集,则需要修改models/slim/eval_image_classifier.py:

把

from datasets import dataset_factory

修改为

from datasets import dataset_classification

把

dataset = dataset_factory.get_dataset(  FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

修改为

dataset = dataset_classification.get_dataset(  FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

在

tf.app.flags.DEFINE_string(  'dataset_dir', None, 'The directory where the dataset files are stored.')

后加入

tf.app.flags.DEFINE_integer(  'num_samples', 350, 'Number of samples.')tf.app.flags.DEFINE_integer(  'num_classes', 5, 'Number of classes.')tf.app.flags.DEFINE_string(  'labels_to_names_path', None, 'Label names file path.')

验证时执行以下命令即可:

python eval_image_classifier.py \  --checkpoint_path=train_logs \  --eval_dir=eval_logs \  --dataset_dir=../../data/val \  --num_samples=350 \  --num_classes=5 \  --model_name=inception_resnet_v2

可以一边训练一边验证,,注意使用其它的GPU或合理分配显存。

同样也可以可视化log,如果已经在可视化训练的log则建议使用其它端口,如:

tensorboard --logdir eval_logs/ --port 6007

测试

参考models/slim/eval_image_classifier.py,可编写读取图片用模型进行推导的脚本models/slim/test_image_classifier.py

from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport osimport mathimport tensorflow as tffrom nets import nets_factoryfrom preprocessing import preprocessing_factoryslim = tf.contrib.slimtf.app.flags.DEFINE_string(  'master', '', 'The address of the TensorFlow master to use.')tf.app.flags.DEFINE_string(  'checkpoint_path', '/tmp/tfmodel/',  'The directory where the model was written to or an absolute path to a '  'checkpoint file.')tf.app.flags.DEFINE_string(  'test_path', '', 'Test image path.')tf.app.flags.DEFINE_integer(  'num_classes', 5, 'Number of classes.')tf.app.flags.DEFINE_integer(  'labels_offset', 0,  'An offset for the labels in the dataset. This flag is primarily used to '  'evaluate the VGG and ResNet architectures which do not use a background '  'class for the ImageNet dataset.')tf.app.flags.DEFINE_string(  'model_name', 'inception_v3', 'The name of the architecture to evaluate.')tf.app.flags.DEFINE_string(  'preprocessing_name', None, 'The name of the preprocessing to use. If left '  'as `None`, then the model_name flag is used.')tf.app.flags.DEFINE_integer(  'test_image_size', None, 'Eval image size')FLAGS = tf.app.flags.FLAGSdef main(_):  if not FLAGS.test_list:    raise ValueError('You must supply the test list with --test_list')  tf.logging.set_verbosity(tf.logging.INFO)  with tf.Graph().as_default():    tf_global_step = slim.get_or_create_global_step()    ####################    # Select the model #    ####################    network_fn = nets_factory.get_network_fn(      FLAGS.model_name,      num_classes=(FLAGS.num_classes - FLAGS.labels_offset),      is_training=False)    #####################################    # Select the preprocessing function #    #####################################    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name    image_preprocessing_fn = preprocessing_factory.get_preprocessing(      preprocessing_name,      is_training=False)    test_image_size = FLAGS.test_image_size or network_fn.default_image_size    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)    else:      checkpoint_path = FLAGS.checkpoint_path    tf.Graph().as_default()    with tf.Session() as sess:      image = open(FLAGS.test_path, 'rb').read()      image = tf.image.decode_jpeg(image, channels=3)      processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)      processed_images = tf.expand_dims(processed_image, 0)      logits, _ = network_fn(processed_images)      predictions = tf.argmax(logits, 1)      saver = tf.train.Saver()      saver.restore(sess, checkpoint_path)      np_image, network_input, predictions = sess.run([image, processed_image, predictions])      print('{} {}'.format(FLAGS.test_path, predictions[0]))if __name__ == '__main__':  tf.app.run()

测试时执行以下命令即可:

python test_image_classifier.py \  --checkpoint_path=train_logs/ \  --test_path=../../data/flower_photos/tulips/6948239566_0ac0a124ee_n.jpg \  --num_classes=5 \  --model_name=inception_resnet_v2

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。


  • 上一条:
    pandas 对group进行聚合的例子
    下一条:
    使用pickle存储数据dump 和 load实例讲解
  • 昵称:

    邮箱:

    0条评论 (评论内容有缓存机制,请悉知!)
    最新最热
    • 分类目录
    • 人生(杂谈)
    • 技术
    • linux
    • Java
    • php
    • 框架(架构)
    • 前端
    • ThinkPHP
    • 数据库
    • 微信(小程序)
    • Laravel
    • Redis
    • Docker
    • Go
    • swoole
    • Windows
    • Python
    • 苹果(mac/ios)
    • 相关文章
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 2024.07.09日OpenAI将终止对中国等国家和地区API服务(0个评论)
    • 2024/6/9最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 国外服务器实现api.openai.com反代nginx配置(0个评论)
    • 2024/4/28最新免费公益节点SSR/V2ray/Shadowrocket/Clash节点分享|科学上网|免费梯子(1个评论)
    • 近期文章
    • 在go中实现一个常用的先进先出的缓存淘汰算法示例代码(0个评论)
    • 在go+gin中使用"github.com/skip2/go-qrcode"实现url转二维码功能(0个评论)
    • 在go语言中使用api.geonames.org接口实现根据国际邮政编码获取地址信息功能(1个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf分页文件功能(0个评论)
    • gmail发邮件报错:534 5.7.9 Application-specific password required...解决方案(0个评论)
    • 欧盟关于强迫劳动的规定的官方举报渠道及官方举报网站(0个评论)
    • 在go语言中使用github.com/signintech/gopdf实现生成pdf文件功能(0个评论)
    • Laravel从Accel获得5700万美元A轮融资(0个评论)
    • 在go + gin中gorm实现指定搜索/区间搜索分页列表功能接口实例(0个评论)
    • 在go语言中实现IP/CIDR的ip和netmask互转及IP段形式互转及ip是否存在IP/CIDR(0个评论)
    • 近期评论
    • 122 在

      学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..
    • 123 在

      Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..
    • 原梓番博客 在

      在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..
    • 博主 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..
    • 1111 在

      佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
    • 2016-10
    • 2016-11
    • 2017-07
    • 2017-08
    • 2017-09
    • 2018-01
    • 2018-07
    • 2018-08
    • 2018-09
    • 2018-12
    • 2019-01
    • 2019-02
    • 2019-03
    • 2019-04
    • 2019-05
    • 2019-06
    • 2019-07
    • 2019-08
    • 2019-09
    • 2019-10
    • 2019-11
    • 2019-12
    • 2020-01
    • 2020-03
    • 2020-04
    • 2020-05
    • 2020-06
    • 2020-07
    • 2020-08
    • 2020-09
    • 2020-10
    • 2020-11
    • 2021-04
    • 2021-05
    • 2021-06
    • 2021-07
    • 2021-08
    • 2021-09
    • 2021-10
    • 2021-12
    • 2022-01
    • 2022-02
    • 2022-03
    • 2022-04
    • 2022-05
    • 2022-06
    • 2022-07
    • 2022-08
    • 2022-09
    • 2022-10
    • 2022-11
    • 2022-12
    • 2023-01
    • 2023-02
    • 2023-03
    • 2023-04
    • 2023-05
    • 2023-06
    • 2023-07
    • 2023-08
    • 2023-09
    • 2023-10
    • 2023-12
    • 2024-02
    • 2024-04
    • 2024-05
    • 2024-06
    • 2025-02
    Top

    Copyright·© 2019 侯体宗版权所有· 粤ICP备20027696号 PHP交流群

    侯体宗的博客