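With TensorFlow's TFRecord storage format, we can save data that has already been preprocessed, so that later runs read it directly and skip the preprocessing step. The workflow is roughly as follows:
To write, convert each field into a tf.train.Feature and collect the features in a dict; wrap the dict in a tf.train.Example, serialize it, and write it out with tf.python_io.TFRecordWriter. To read, load the file with tf.data.TFRecordDataset, restore the original structure of each record with tf.parse_single_example, and then fetch the data in batches with the batch function mentioned last time. (For details on batch, shuffle, and repeat, see the earlier post on TensorFlow dataset.shuffle, batch, and repeat usage.)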
Writing with TFRecordWriter
import collections

import numpy as np
import tensorflow as tf

# Set the dtype explicitly: NumPy's default integer type is platform-dependent,
# and the raw bytes must later be decoded with the same dtype (tf.int32).
inputs_1 = np.array([
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]]
], dtype=np.int32)
inputs_2 = [
    [1.1, 2.2, 3.3],
    [4.4, 5.5, 6.6]
]
labels = [0, 1]


def create_int_feature(values):
    # Wrap an iterable of integers in an Int64List feature.
    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return f


def create_float_feature(values):
    # Wrap an iterable of floats in a FloatList feature.
    f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
    return f


def create_bytes_feature(values):
    # Wrap a single bytes object in a BytesList feature.
    f = tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
    return f


writer = tf.python_io.TFRecordWriter('test.tfrecord')
for i1, i2, l in zip(inputs_1, inputs_2, labels):
    features = collections.OrderedDict()
    # A multi-dimensional array is stored as raw bytes; its shape is restored on read.
    features['inputs_1'] = create_bytes_feature(i1.tobytes())
    features['inputs_2'] = create_float_feature(i2)
    features['labels'] = create_int_feature([l])
    example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(example.SerializeToString())
writer.close()
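
Before relying on the file, it can help to check that a record survives serialization. The sketch below is an illustration, not part of the original workflow: it builds one Example in memory, serializes it, and parses it back with the protobuf method FromString, without a file or a session. The helper name make_example is hypothetical, and the int32 dtype is an assumption chosen for this example.

```python
import numpy as np
import tensorflow as tf


def make_example(i1, i2, label):
    # Hypothetical helper: pack one sample into a tf.train.Example.
    feature = {
        "inputs_1": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[i1.tobytes()])),
        "inputs_2": tf.train.Feature(
            float_list=tf.train.FloatList(value=list(i2))),
        "labels": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Assumed sample data; the dtype is fixed so the bytes can be decoded reliably.
sample = np.array([[1, 2], [3, 4]], dtype=np.int32)
serialized = make_example(sample, [1.1, 2.2, 3.3], 0).SerializeToString()

# Parse the serialized bytes back into a protobuf message and read each field.
restored = tf.train.Example.FromString(serialized)
feat = restored.features.feature
labels_out = list(feat["labels"].int64_list.value)
inputs_2_out = list(feat["inputs_2"].float_list.value)  # stored as float32
inputs_1_out = np.frombuffer(
    feat["inputs_1"].bytes_list.value[0], dtype=np.int32).reshape(2, 2)
```

Note that FloatList stores 32-bit floats, so the values read back are only approximately equal to the Python floats that were written.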
Reading and parsing into a dataset
# Describe the type and shape of every feature so each record can be parsed.
name_to_features = {
    "inputs_1": tf.FixedLenFeature([], tf.string),
    "inputs_2": tf.FixedLenFeature([3], tf.float32),
    "labels": tf.FixedLenFeature([], tf.int64)
}

d = tf.data.TFRecordDataset('test.tfrecord')
d = d.repeat()
d = d.shuffle(buffer_size=2)
# Parse each serialized record and batch the results in one step.
d = d.apply(tf.contrib.data.map_and_batch(
    lambda record: tf.parse_single_example(record, name_to_features),
    batch_size=1))

iters = d.make_one_shot_iterator()
batch = iters.get_next()

# inputs_1 was stored as raw bytes: decode it with the dtype of the array
# that was written, then restore the original 2x2 shape.
inputs_1_batch = tf.decode_raw(batch['inputs_1'], tf.int32)
inputs_1_batch = tf.reshape(inputs_1_batch, [-1, 2, 2])
inputs_2_batch = batch['inputs_2']
labels_batch = batch['labels']

sess = tf.Session()
sess.run([inputs_1_batch, inputs_2_batch, labels_batch])
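
One thing to watch in the reading code is that tf.decode_raw only reinterprets bytes, so the dtype passed to it has to match the dtype of the array whose bytes were stored. The NumPy-only sketch below mimics this with np.frombuffer, used here as a stand-in for tf.decode_raw (an analogy, not the same function). The variable names are illustrative.

```python
import numpy as np

# Four integers stored as int64 take 4 * 8 = 32 bytes.
original = np.array([[1, 2], [3, 4]], dtype=np.int64)
raw = original.tobytes()

# Decoding with a narrower dtype yields twice as many elements,
# so a later reshape to (2, 2) would no longer describe one sample.
wrong = np.frombuffer(raw, dtype=np.int32)

# Decoding with the matching dtype recovers the original values.
right = np.frombuffer(raw, dtype=np.int64).reshape(2, 2)
```

Setting the dtype explicitly when the array is created, and using the same one when decoding, avoids depending on the platform's default integer type.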