CRNN学习笔记

    xiaoxiao2022-07-05  158

     

    最近学习了CRNN网络,大体训练流程如下:

    1、准备输入数据和标签,标签为稀疏矩阵

    inputs = tf.placeholder(tf.float32, [batch_size, input_height, input_width, 1]) label = tf.sparse_placeholder(tf.int32, name='label') seq_len = tf.placeholder(tf.int32, [None], name='seq_len')

    2、通过CNN网络提取特征

    cnn_out = self._cnn(inputs)

    3、通过2次双向RNN,得到神经网络输出结果

    crnn_model = self._rnn(cnn_out, self._seq_len)

    4、根据最终字符的类别得到最终的输出

    logits = tf.reshape(crnn_model, [-1, 512]) W = tf.Variable(tf.truncated_normal([512, self._class_num], stddev=0.1), name="W") b = tf.Variable(tf.constant(0., shape=[self._class_num]), name="b") logits = tf.matmul(logits, W) + b logits = tf.reshape(logits, [self._batch_size, -1, self._class_num]) # 网络层输出 net_output = tf.transpose(logits, (1, 0, 2))

    5、解析网络输出,其中decoded[0]是一个稀疏张量,类型和label一样

    decoded, log_prob = tf.nn.ctc_greedy_decoder(net_output, self._seq_len)

    6、损失函数loss

    with tf.name_scope('loss'): loss = tf.nn.ctc_loss(self._label, self._net_output, self._seq_len) loss = tf.reduce_mean(loss)

    7、优化器optimizer

    with tf.name_scope('optimizer'): train_op = tf.train.AdamOptimizer(self._learning_rate).minimize(loss)

    8、准确率accuracy

    with tf.name_scope('accuracy'): accuracy = 1 - tf.reduce_mean(tf.edit_distance(tf.cast(self._decoded[0], tf.int32), self._label)) accuracy_broad = tf.summary.scalar("accuracy", accuracy)

    9、喂数据进行训练

    feed_dict = {self._inputs: batch_data,self._label: batch_label, \ self._seq_len: [self._max_char_count] * self.batch_size} sess.run(train_op, feed_dict=feed_dict)

     

    最新回复(0)