使用TensorFlow提供的slim模型来训练数据模型供iOS使用

    xiaoxiao2022-07-15  129

    1、下载slim模型包:

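    cd /Users/javalong/Downloads

    git clone https://github.com/tensorflow/models/

    注:git clone得到的目录名为models;下文使用的models-master是从GitHub下载zip包解压后的目录名,请按实际目录名替换路径(较新版本的仓库中slim位于research/slim目录下)。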

    2、数据可以使用slim提供的数据集,也可以使用自己采集的图片

    2.1、下载slim提供的数据集flowers

    2.1.1、设置下载目录命令:

    DATA_DIR=/Users/javalong/Desktop/Test/output/flowers

    2.1.2、进入到slim模型目录命令:

    cd /Users/javalong/Downloads/models-master/slim

    2.1.3、下载数据集命令:

    python3 download_and_convert_data.py \
        --dataset_name=flowers \
        --dataset_dir="${DATA_DIR}"

    2.1.4、查看目录下的文件命令:

    ls ${DATA_DIR}

    得到:

    flowers_train-00000-of-00005.tfrecord

    ...

    flowers_train-00004-of-00005.tfrecord

    flowers_validation-00000-of-00005.tfrecord

    ...

    flowers_validation-00004-of-00005.tfrecord

    labels.txt

    2.2、可以看到,slim提供的数据集文件是tfrecord格式,所以要训练自己采集的图片,第一步需要先将图片转换成tfrecord格式。

    2.2.1、将图片转换成TFRecord文件需要安装以下软件:

    pip3 install Pillow

    pip3 install matplotlib

    2.2.2、在/Users/javalong/Downloads/models-master/slim下创建一个fu_img_to_tfrecord.py文件。

    如图:

    2.2.3、fu_img_to_tfrecord.py的内容为:

    """Convert a folder of class-labelled images into TFRecord files for slim.

    Usage:
        python3 fu_img_to_tfrecord.py <image_dir> <output_dir>

    <image_dir> must contain one sub-directory per class; the sub-directory
    names become the class names. For each class, every 5th image (in sorted
    filename order) is written to iss_test.tfrecord and the rest to
    iss_train.tfrecord. labels.txt ("<id>:<class name>" per line) and
    meta_info.txt (train/test counts) are also written to <output_dir>.

    Adapted from slim/datasets/download_and_convert_flowers.py.
    """
    import os
    import os.path
    import tensorflow as tf
    from PIL import Image
    import matplotlib.pyplot as plt
    import sys
    import pprint


    def int64_feature(values):
        """Wrap an int (or a list/tuple of ints) as an Int64List feature."""
        if not isinstance(values, (tuple, list)):
            values = [values]
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


    def bytes_feature(values):
        """Wrap a single bytes value as a BytesList feature."""
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


    def image_to_tfexample(image_data, image_format, height, width, class_id):
        """Build a tf.train.Example holding one encoded image and its label."""
        return tf.train.Example(features=tf.train.Features(feature={
            'image/encoded': bytes_feature(image_data),
            'image/format': bytes_feature(image_format),
            'image/class/label': int64_feature(class_id),
            'image/height': int64_feature(height),
            'image/width': int64_feature(width),
        }))


    def get_extension(path):
        """Return the extension of *path*, including the leading dot."""
        return os.path.splitext(path)[1]


    class ImageReader(object):
        """Helper class that provides TensorFlow image coding utilities."""

        def __init__(self):
            # Graph ops that decode encoded image bytes into a 3-channel array.
            self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
            self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data,
                                                     channels=3)

        def read_image_dims(self, sess, image_data):
            """Return (height, width) of the encoded image *image_data*."""
            image = self.decode_jpeg(sess, image_data)
            return image.shape[0], image.shape[1]

        def decode_jpeg(self, sess, image_data):
            """Decode *image_data* in *sess*; the result has shape (H, W, 3)."""
            image = sess.run(self._decode_jpeg,
                             feed_dict={self._decode_jpeg_data: image_data})
            assert len(image.shape) == 3
            assert image.shape[2] == 3
            return image


    def write_label_file(labels_to_class_names, dataset_dir,
                         filename='labels.txt'):
        """Writes a file with the list of class names.

        Args:
            labels_to_class_names: A map of (integer) labels to class names.
            dataset_dir: The directory in which the labels file should be written.
            filename: The filename where the class names are written. Defaults
                to 'labels.txt', the name slim's dataset readers look for.
        """
        labels_filename = os.path.join(dataset_dir, filename)
        with tf.gfile.Open(labels_filename, 'w') as f:
            for label in labels_to_class_names:
                class_name = labels_to_class_names[label]
                f.write('%d:%s\n' % (label, class_name))


    def _list_classes(data_dir):
        """Return the sorted names of the non-hidden sub-directories of data_dir."""
        # Sorted so class ids are reproducible: os.listdir order is arbitrary.
        return sorted(
            name for name in os.listdir(data_dir)
            if not name.startswith('.')
            and os.path.isdir(os.path.join(data_dir, name)))


    def _list_images(class_path):
        """Return the sorted names of the non-hidden regular files in class_path."""
        # Hidden files such as macOS's .DS_Store are not images and would make
        # decode_jpeg fail part-way through the conversion.
        return sorted(
            name for name in os.listdir(class_path)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(class_path, name)))


    def main(argv):
        """Convert <image_dir> (argv[1]) into TFRecords under <output_dir> (argv[2])."""
        if len(argv) != 3:
            sys.exit('usage: python3 %s <image_dir> <output_dir>' % argv[0])
        data_dir, train_dir = argv[1], argv[2]
        os.makedirs(train_dir, exist_ok=True)

        classes = _list_classes(data_dir)
        write_label_file(dict(enumerate(classes)), train_dir)

        train_num = 0
        test_num = 0
        train_path = os.path.join(train_dir, 'iss_train.tfrecord')
        test_path = os.path.join(train_dir, 'iss_test.tfrecord')
        with tf.Graph().as_default():
            image_reader = ImageReader()
            with tf.Session('') as sess, \
                    tf.python_io.TFRecordWriter(train_path) as train, \
                    tf.python_io.TFRecordWriter(test_path) as test:
                for class_id, class_name in enumerate(classes):
                    class_path = os.path.join(data_dir, class_name)
                    for num, img_name in enumerate(_list_images(class_path)):
                        img_path = os.path.join(class_path, img_name)
                        with tf.gfile.FastGFile(img_path, 'rb') as f:
                            image_data = f.read()
                        height, width = image_reader.read_image_dims(
                            sess, image_data)
                        # NOTE(review): every image is tagged 'jpg' and stored
                        # as-is; confirm non-JPEG inputs decode downstream.
                        example = image_to_tfexample(
                            image_data, b'jpg', height, width, class_id)
                        # Hold out every 5th image of each class for testing.
                        if num % 5 == 0:
                            test_num += 1
                            test.write(example.SerializeToString())
                        else:
                            train_num += 1
                            train.write(example.SerializeToString())

        with open(os.path.join(train_dir, 'meta_info.txt'), 'w') as info:
            info.write('train_num:' + str(train_num) + '\n')
            info.write('test_num:' + str(test_num) + '\n')


    if __name__ == '__main__':
        main(sys.argv)

    2.2.4、执行转换命令:

    python3 /Users/javalong/Downloads/models-master/slim/fu_img_to_tfrecord.py /Users/javalong/Desktop/flowers /Users/javalong/Desktop/flower_record

    注:

    2.2.5、/Users/javalong/Desktop/flowers是存放采集图片的目录,每个子目录存放一个类别的图片,如图:

    2.2.6、/Users/javalong/Desktop/flower_record是生成的tfrecord格式文件的存放目录。最终生成的文件如图:

    2.2.7、/Users/javalong/Desktop/flowers目录下的子目录名会作为分类名称,存储到生成的标签文件中。如图:

    2.2.8、fu_img_to_tfrecord.py的实现参考了/Users/javalong/Downloads/models-master/slim/datasets/download_and_convert_flowers.py文件。

    3、用inception_v3预训练模型(checkpoint)在flowers数据集上进行微调训练

    3.1、设置相应的目录(CHECKPOINT_PATH指向的inception_v3.ckpt需要先从slim提供的预训练模型列表中下载并解压):

    DATASET_DIR=/Users/javalong/Desktop/Test/output/flowers

    CHECKPOINT_PATH=/Users/javalong/Desktop/Test/output/inception/inception_v3.ckpt

    TRAIN_DIR=/Users/javalong/Desktop/Test/output/tran

    3.2、训练命令:

    python3 train_image_classifier.py \
        --train_dir=${TRAIN_DIR} \
        --dataset_dir=${DATASET_DIR} \
        --dataset_name=flowers \
        --dataset_split_name=train \
        --model_name=inception_v3 \
        --checkpoint_path=${CHECKPOINT_PATH} \
        --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
        --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
        --clone_on_cpu=true

    4、生成.pb文件

    4.1、在/Users/javalong/Downloads/models-master/slim下创建一个bbb.py文件。

    如图:

    4.2、bbb.py的内容为:

    """Freeze the latest slim inception_v3 training checkpoint into a .pb file.

    Restores the newest checkpoint in TRAIN_DIR into a freshly built
    inception_v3 inference graph. It converts all variables to constants and
    writes two files to OUTPUT_DIR:
      * inception_v3_final.pb       -- the frozen GraphDef (binary)
      * output_graph_nodes.txt      -- the same GraphDef in protobuf text form
    """
    import os
    import tensorflow as tf
    import tensorflow.contrib.slim as slim
    from nets import inception
    from nets import inception_v1
    from nets import inception_v3
    from nets import nets_factory
    from tensorflow.python.framework import graph_util
    from tensorflow.python.platform import gfile
    from google.protobuf import text_format

    # Directory passed as --train_dir to train_image_classifier.py.
    TRAIN_DIR = '/Users/javalong/Desktop/Test/output/tran'
    # Directory that receives the frozen graph and its text dump.
    OUTPUT_DIR = '/Users/javalong/Desktop/Test/output'
    # Must equal the class count the checkpoint was trained with (flowers: 5).
    NUM_CLASSES = 5
    # Spatial size of the input placeholder (H == W).
    IMAGE_SIZE = 299
    # Comma-separated output node(s) kept when freezing.
    # NOTE(review): name depends on the slim inception_v3 version; verify it
    # against output_graph_nodes.txt if freezing fails.
    OUTPUT_NODE_NAMES = 'InceptionV3/Predictions/Reshape_1'


    def main():
        """Restore the latest checkpoint, freeze the graph and write it out."""
        checkpoint_path = tf.train.latest_checkpoint(TRAIN_DIR)
        if checkpoint_path is None:
            # latest_checkpoint returns None when TRAIN_DIR holds no checkpoint.
            raise SystemExit('No checkpoint found in %s; run '
                             'train_image_classifier.py first.' % TRAIN_DIR)

        with tf.Graph().as_default() as graph:
            input_tensor = tf.placeholder(
                tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3),
                name='input_image')
            with tf.Session() as sess:
                # Build the inference graph; only its node names matter here.
                with slim.arg_scope(inception.inception_v3_arg_scope()):
                    inception.inception_v3(input_tensor,
                                           num_classes=NUM_CLASSES,
                                           is_training=False)
                saver = tf.train.Saver()
                saver.restore(sess, checkpoint_path)
                output_graph_def = graph_util.convert_variables_to_constants(
                    sess, graph.as_graph_def(), OUTPUT_NODE_NAMES.split(','))

        # NOTE: the text dump embeds every frozen weight, so it is very large.
        nodes_path = os.path.join(OUTPUT_DIR, 'output_graph_nodes.txt')
        with open(nodes_path, 'w') as f:
            f.write(text_format.MessageToString(output_graph_def))

        output_graph = os.path.join(OUTPUT_DIR, 'inception_v3_final.pb')
        with gfile.FastGFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())


    if __name__ == '__main__':
        main()

    5、优化模型并去掉iOS不支持的算子 

    参考此篇文章

    相关资源:敏捷开发V1.0.pptx
    最新回复(0)