Hou Tizong's Blog

Recognizing your own handwritten digits with TensorFlow

Tech / posted by admin 7 years ago / 266 views

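TensorFlow, Google's open-source deep learning framework, has overtaken Caffe and is arguably now the most popular deep learning framework. Writing TensorFlow code also feels more concrete: the network is defined directly in code, whereas Caffe builds the network from configuration files. The environment used here is TensorFlow 0.10. Some statements may fail on other versions because the API is not always compatible between TensorFlow releases.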

You also need PIL, installed through Pillow: pip install Pillow

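Format of the MNIST images:

• The digits are size-normalized to fit in a 20×20 pixel box while preserving their aspect ratio.
• They are then centered in a 28×28 image.
• Pixels are stored row by row (row-major order). Pixel values range from 0 to 255, where 0 is background (white) and 255 is foreground (black).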
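A small pure-Python sketch of these conventions (the helper names are hypothetical and no TensorFlow is needed): pixel (r, c) of a 28×28 image sits at index r * 28 + c of the flattened 784-value vector, which is the layout that tf.reshape(x, [-1, 28, 28, 1]) assumes in the training code. center_of_mass computes the intensity-weighted centroid that, according to the MNIST site, was used to center the original digits; the prediction script later in this post centers the resized image on its canvas instead, so this is only for comparison.

```python
# Sketch of the MNIST pixel conventions on a plain 28x28 grid
# (lists of ints, 0 = background/white, 255 = foreground/black).
SIZE = 28


def flatten_row_major(grid):
    """Flatten row by row: pixel (row r, column c) goes to index r * 28 + c."""
    return [grid[r][c] for r in range(SIZE) for c in range(SIZE)]


def pixel_at(flat, r, c):
    """Read pixel (r, c) back out of a flattened 784-element vector."""
    return flat[r * SIZE + c]


def center_of_mass(grid):
    """Intensity-weighted centroid (row, col) of the foreground pixels."""
    total = sum(sum(row) for row in grid)
    if total == 0:
        raise ValueError("blank image has no center of mass")
    r_bar = sum(r * v for r, row in enumerate(grid) for v in row) / float(total)
    c_bar = sum(c * v for row in grid for c, v in enumerate(row)) / float(total)
    return r_bar, c_bar


if __name__ == "__main__":
    grid = [[0] * SIZE for _ in range(SIZE)]
    grid[5][7] = 255                 # one black pixel in row 5, column 7
    flat = flatten_row_major(grid)
    print(flat.index(255))           # 147 (= 5 * 28 + 7)
    print(center_of_mass(grid))      # (5.0, 7.0)
```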

Create a .png file with a white background and the digit written in black.

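Below is the training code: a convolutional neural network with two convolutional layers (followed by a fully connected layer and a softmax output), trained on MNIST and then saved with tf.train.Saver.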

```python
# coding: UTF-8
from __future__ import print_function

import tensorflow as tf
import input_data

# Load the data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
training = mnist.train.images
trainlabel = mnist.train.labels
testing = mnist.test.images
testlabel = mnist.test.labels
print("MNIST loaded")

# Start an interactive session
sess = tf.InteractiveSession()

# Input placeholders and variables
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
# W and b are not used by the CNN, but they are saved in the checkpoint.
# The prediction script creates them in the same order, so the default
# variable names (Variable, Variable_1, ...) line up when restoring.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))


def weight_variable(shape):
    """Create a weight variable; shape is the shape of the tensor."""
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    """Create a bias variable; shape is the shape of the tensor."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')


# Conv layer 1: 5x5 kernels, 1 -> 32 channels, then 2x2 pooling (28x28 -> 14x14)
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1, 28, 28, 1])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

# Conv layer 2: 32 -> 64 channels, then 2x2 pooling (14x14 -> 7x7)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# Fully connected layer with 1024 units
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# Dropout: keep_prob is 0.5 while training, 1.0 when evaluating
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# Output layer: softmax over the 10 digit classes
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# Note: tf.log can yield NaN if y_conv reaches 0;
# tf.nn.softmax_cross_entropy_with_logits is the numerically stable variant.
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# Saver for the trained network parameters
saver = tf.train.Saver()
# TF 0.10 API; TF 0.12+ renames this to tf.global_variables_initializer()
sess.run(tf.initialize_all_variables())
for i in range(8000):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={
            x: batch[0], y_: batch[1], keep_prob: 1.0})
        print("step %d, training accuracy %g" % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

save_path = saver.save(sess, "model_mnist.ckpt")
print("Model saved in file:", save_path)

print("test accuracy %g" % accuracy.eval(feed_dict={
    x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
```
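The 7 * 7 * 64 in W_fc1 follows from the two 2×2 pooling layers: with padding='SAME', each layer's spatial output size is ceil(input size / stride), so 28 → 14 → 7. A pure-Python sketch of that arithmetic, assuming the layer sizes used in the script above (the function names are hypothetical, and no TensorFlow is needed):

```python
import math


def same_out(size, stride):
    """Spatial output size for padding='SAME': ceil(size / stride)."""
    return int(math.ceil(size / float(stride)))


def trace_shapes(size=28):
    """Follow (name, height, width, channels) through conv1/pool1/conv2/pool2."""
    h = w = size
    shapes = [("input", h, w, 1)]
    for name, channels in (("conv1", 32), ("conv2", 64)):
        h, w = same_out(h, 1), same_out(w, 1)   # 5x5 conv, stride 1: size unchanged
        shapes.append((name, h, w, channels))
        h, w = same_out(h, 2), same_out(w, 2)   # 2x2 max pool, stride 2: size halved
        shapes.append((name + "_pool", h, w, channels))
    return shapes


def param_count():
    """Weights + biases of the CNN layers (excluding the unused W and b)."""
    conv1 = 5 * 5 * 1 * 32 + 32
    conv2 = 5 * 5 * 32 * 64 + 64
    fc1 = 7 * 7 * 64 * 1024 + 1024
    fc2 = 1024 * 10 + 10
    return conv1 + conv2 + fc1 + fc2


for shape in trace_shapes():
    print(shape)
_, h, w, c = trace_shapes()[-1]
print("flattened size:", h * w * c)   # 3136 (= 7 * 7 * 64)
print("parameters:", param_count())   # 3274634
```

Most of the roughly 3.27 million parameters sit in the 3136×1024 fully connected layer; the unused W and b add another 7,850 values to the checkpoint.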

input_data.py, shown below, downloads and reads the MNIST dataset. It is the official version of the script provided with the TensorFlow MNIST tutorial.

```python
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os

import tensorflow.python.platform

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


def maybe_download(filename, work_directory):
    """Download the data from Yann's website, unless it's already here."""
    if not os.path.exists(work_directory):
        os.mkdir(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    return filepath


def _read32(bytestream):
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]


def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        buf = bytestream.read(rows * cols * num_images)
        data = numpy.frombuffer(buf, dtype=numpy.uint8)
        data = data.reshape(num_images, rows, cols, 1)
        return data


def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert class labels from scalars to one-hot vectors."""
    num_labels = labels_dense.shape[0]
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot


def extract_labels(filename, one_hot=False):
    """Extract the labels into a 1D uint8 numpy array [index]."""
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
        if one_hot:
            return dense_to_one_hot(labels)
        return labels


class DataSet(object):

    def __init__(self, images, labels, fake_data=False, one_hot=False,
                 dtype=tf.float32):
        """Construct a DataSet.

        one_hot arg is used only if fake_data is true. `dtype` can be either
        `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
        `[0, 1]`.
        """
        dtype = tf.as_dtype(dtype).base_dtype
        if dtype not in (tf.uint8, tf.float32):
            raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                            dtype)
        if fake_data:
            self._num_examples = 10000
            self.one_hot = one_hot
        else:
            assert images.shape[0] == labels.shape[0], (
                'images.shape: %s labels.shape: %s' % (images.shape,
                                                       labels.shape))
            self._num_examples = images.shape[0]

            # Convert shape from [num examples, rows, columns, depth]
            # to [num examples, rows*columns] (assuming depth == 1)
            assert images.shape[3] == 1
            images = images.reshape(images.shape[0],
                                    images.shape[1] * images.shape[2])
            if dtype == tf.float32:
                # Convert from [0, 255] -> [0.0, 1.0].
                images = images.astype(numpy.float32)
                images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False):
        """Return the next `batch_size` examples from this data set."""
        if fake_data:
            fake_image = [1] * 784
            if self.one_hot:
                fake_label = [1] + [0] * 9
            else:
                fake_label = 0
            return [fake_image for _ in xrange(batch_size)], [
                fake_label for _ in xrange(batch_size)]
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Shuffle the data
            perm = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]


def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):
    class DataSets(object):
        pass
    data_sets = DataSets()

    if fake_data:
        def fake():
            return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
        data_sets.train = fake()
        data_sets.validation = fake()
        data_sets.test = fake()
        return data_sets

    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    VALIDATION_SIZE = 5000

    local_file = maybe_download(TRAIN_IMAGES, train_dir)
    train_images = extract_images(local_file)

    local_file = maybe_download(TRAIN_LABELS, train_dir)
    train_labels = extract_labels(local_file, one_hot=one_hot)

    local_file = maybe_download(TEST_IMAGES, train_dir)
    test_images = extract_images(local_file)

    local_file = maybe_download(TEST_LABELS, train_dir)
    test_labels = extract_labels(local_file, one_hot=one_hot)

    validation_images = train_images[:VALIDATION_SIZE]
    validation_labels = train_labels[:VALIDATION_SIZE]
    train_images = train_images[VALIDATION_SIZE:]
    train_labels = train_labels[VALIDATION_SIZE:]

    data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
    data_sets.validation = DataSet(validation_images, validation_labels,
                                   dtype=dtype)
    data_sets.test = DataSet(test_images, test_labels, dtype=dtype)

    return data_sets
```
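For reference, a stdlib-only sketch of the file layout that extract_images() and extract_labels() parse: each IDX file starts with a big-endian 32-bit magic number (2051 for images, 2049 for labels, matching the checks above), followed by big-endian 32-bit dimension sizes and then the raw uint8 data. The helper names read_idx_header and one_hot are hypothetical, and the tiny in-memory file stands in for a real *-idx1-ubyte.gz download.

```python
import gzip
import io
import struct


def read_idx_header(stream):
    """Return (magic, dims) from an open IDX byte stream."""
    magic, = struct.unpack('>I', stream.read(4))
    if magic == 2051:        # images: count, rows, cols
        ndims = 3
    elif magic == 2049:      # labels: count
        ndims = 1
    else:
        raise ValueError('Invalid magic number %d' % magic)
    dims = struct.unpack('>' + 'I' * ndims, stream.read(4 * ndims))
    return magic, dims


def one_hot(labels, num_classes=10):
    """Same result as dense_to_one_hot(), written with plain lists."""
    rows = []
    for label in labels:
        row = [0] * num_classes
        row[label] = 1
        rows.append(row)
    return rows


# Build a tiny gzipped label file in memory, then parse it back.
raw = struct.pack('>II', 2049, 3) + bytes(bytearray([7, 0, 2]))
buf = io.BytesIO()
with gzip.GzipFile(fileobj=buf, mode='wb') as gz:
    gz.write(raw)
buf.seek(0)
with gzip.GzipFile(fileobj=buf, mode='rb') as gz:
    magic, (count,) = read_idx_header(gz)
    labels = list(bytearray(gz.read(count)))
print(magic, count, labels)    # 2049 3 [7, 0, 2]
print(one_hot(labels)[0])      # [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
```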

Then use the trained model to recognize your own image:

```python
# import modules
import sys
import tensorflow as tf
from PIL import Image, ImageFilter


def predictint(imvalue):
    """
    This function returns the predicted integer.
    The input is the pixel values from the imageprepare() function.
    """

    # Define the model (same as when creating the model file, in the same
    # order, so the default variable names match the checkpoint)
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))

    def weight_variable(shape):
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)

    def bias_variable(shape):
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)

    def conv2d(x, W):
        return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

    def max_pool_2x2(x):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])

    x_image = tf.reshape(x, [-1, 28, 28, 1])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])

    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])

    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])

    y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

    # TF 0.10 API; TF 0.12+ renames this to tf.global_variables_initializer()
    init_op = tf.initialize_all_variables()
    saver = tf.train.Saver()

    # Load the model_mnist.ckpt file, which is stored in the directory
    # from which this Python script is started.
    # Use the model to predict the integer. The integer is returned as a list.
    # Based on the documentation at
    # https://www.tensorflow.org/versions/master/how_tos/variables/index.html
    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, "model_mnist.ckpt")
        # print("Model restored.")

        prediction = tf.argmax(y_conv, 1)
        return prediction.eval(feed_dict={x: [imvalue], keep_prob: 1.0}, session=sess)


def imageprepare(argv):
    """
    This function returns the pixel values.
    The input is a png file location.
    """
    im = Image.open(argv).convert('L')
    width = float(im.size[0])
    height = float(im.size[1])
    newImage = Image.new('L', (28, 28), (255))  # creates white canvas of 28x28 pixels

    if width > height:  # check which dimension is bigger
        # Width is bigger. Width becomes 20 pixels.
        nheight = int(round((20.0 / width * height), 0))  # resize height according to ratio width
        if nheight == 0:  # rare case but minimum is 1 pixel
            nheight = 1
        # resize and sharpen (Image.LANCZOS is the filter the old Image.ANTIALIAS
        # alias pointed to; ANTIALIAS was removed in Pillow 10)
        img = im.resize((20, nheight), Image.LANCZOS).filter(ImageFilter.SHARPEN)
        wtop = int(round(((28 - nheight) / 2), 0))  # calculate vertical position
        newImage.paste(img, (4, wtop))  # paste resized image on white canvas
    else:
        # Height is bigger. Height becomes 20 pixels.
        nwidth = int(round((20.0 / height * width), 0))  # resize width according to ratio height
        if nwidth == 0:  # rare case but minimum is 1 pixel
            nwidth = 1
        # resize and sharpen
        img = im.resize((nwidth, 20), Image.LANCZOS).filter(ImageFilter.SHARPEN)
        wleft = int(round(((28 - nwidth) / 2), 0))  # calculate horizontal position
        newImage.paste(img, (wleft, 4))  # paste resized image on white canvas

    # newImage.save("sample.png")

    tv = list(newImage.getdata())  # get pixel values

    # normalize pixels to 0 and 1. 0 is pure white, 1 is pure black.
    tva = [(255 - x) * 1.0 / 255.0 for x in tv]
    return tva


def main(argv):
    """
    Main function.
    """
    imvalue = imageprepare(argv)
    predint = predictint(imvalue)
    print(predint[0])  # first value in list


if __name__ == "__main__":
    # image path from the command line, defaulting to 2.png
    main(sys.argv[1] if len(sys.argv) > 1 else '2.png')
```

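[Figure: the handwritten digit image I used for testing]

Save your image to the same directory as the script (by default the script reads 2.png), then run the test. The image is preprocessed as follows: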

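(1) Load the image of my handwritten digit.
(2) Convert the image to grayscale (mode "L", 8-bit pixels).
(3) Determine which dimension of the original image is larger.
(4) Resize the image so that the larger dimension (height or width) becomes 20 pixels, scaling the smaller dimension by the same ratio.
(5) Sharpen the image. This improves the results considerably.
(6) Paste the image onto a 28×28 white canvas. Along the larger dimension, place the image 4 pixels from the top or left edge: the larger dimension is always 20 pixels, and 4 + 20 + 4 = 28. Along the smaller dimension, offset the image by half the difference between 28 and its new scaled size, which centers it.
(7) Get the pixel values of the new image (canvas plus centered image).
(8) Normalize the pixel values to between 0 and 1 (as is also done in the TensorFlow MNIST tutorial), where 0 is white and 1 is pure black. The values from step 7 use the opposite convention (255 is white, 0 is black), so they must be inverted. The following formula inverts and normalizes in one step: (255 - x) * 1.0 / 255.0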
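The geometry in steps (4) and (6) and the formula in step (8) can be checked without Pillow or TensorFlow. A pure-Python sketch that mirrors the arithmetic in imageprepare(); the names layout and normalize are hypothetical:

```python
def layout(width, height, box=20, canvas=28):
    """Return (new_w, new_h, left, top) for placing a width x height image."""
    margin = (canvas - box) // 2                  # 4 pixels: 4 + 20 + 4 = 28
    if width > height:                            # width becomes 20 pixels
        new_w = box
        new_h = max(1, int(round(box * height / float(width))))
        left, top = margin, int(round((canvas - new_h) / 2.0))
    else:                                         # height becomes 20 pixels
        new_h = box
        new_w = max(1, int(round(box * width / float(height))))
        left, top = int(round((canvas - new_w) / 2.0)), margin
    return new_w, new_h, left, top


def normalize(pixels):
    """Invert and scale: 255 (white) -> 0.0, 0 (black) -> 1.0."""
    return [(255 - p) * 1.0 / 255.0 for p in pixels]


print(layout(100, 50))       # (20, 10, 4, 9)
print(layout(60, 120))       # (10, 20, 9, 4)
print(normalize([255, 0]))   # [0.0, 1.0]
```

The sketch centers with true division, as Python 3 does; under Python 2 the script's (28 - nheight) / 2 is integer division, so odd sizes can land one pixel differently.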


Copyright © 2019 Hou Tizong. All rights reserved. 粤ICP备20027696号