CRNN代码解析

    xiaoxiao2022-07-04  124

    CRNN 主要分为四步：1. 特征提取；2. 序列转换；3. 执行 LSTM 获取序列输出；4. 进行 CTC 转录（解码）。

    CRNN使用

    以下代码完成前三步：1. 特征提取；2. 序列转换；3. 执行 LSTM 获取序列输出。

    def inference(self, inputdata, name, reuse=False):
        """Build the CRNN forward graph up to the per-frame class logits.

        Three stages run inside a single variable scope:
        CNN feature extraction -> map-to-sequence -> stacked BiLSTM labeller.

        :param inputdata: input image batch; divided by 255.0 before use
        :param name: name of the enclosing variable scope
        :param reuse: forwarded to ``tf.variable_scope`` so weights can be shared
        :return: logits from ``_sequence_label``, laid out as
            [width, batch, n_classes]
        """
        with tf.variable_scope(name_or_scope=name, reuse=reuse):
            # Rescale pixel values (presumably 8-bit, i.e. into [0, 1]).
            scaled_images = tf.divide(inputdata, 255.0)

            # Stage 1: CNN feature map (batch*1*25*512 for a 32x100 input).
            feature_map = self._feature_sequence_extraction(
                inputdata=scaled_images,
                name='feature_extraction_module',
            )

            # Stage 2: drop the height axis, batch*1*25*512 -> batch*25*512.
            frame_sequence = self._map_to_sequence(
                inputdata=feature_map,
                name='map_to_sequence_module',
            )

            # Stage 3: sequence labelling. The argmax prediction that is also
            # returned ([batch, width]) is not needed by this method.
            time_major_logits, _ = self._sequence_label(
                inputdata=frame_sequence,
                name='sequence_rnn_module',
            )
            return time_major_logits

    以下代码进行 CTC 解码（第二个参数为批内每个样本的序列长度，这里都取 CFG.ARCH.SEQ_LENGTH）：

    # Beam-search CTC decoding of the network output (train_inference_ret is
    # assumed to be the time-major logits returned by inference()).
    # sequence_length is documented as a 1-D int32 vector with one entry per
    # batch element; build it as int32 explicitly instead of relying on an
    # implicit float64 -> int32 conversion of np.ones().
    train_decoded, train_log_prob = tf.nn.ctc_beam_search_decoder(
        train_inference_ret,
        CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TRAIN.BATCH_SIZE, dtype=np.int32),
        merge_repeated=False,
    )

    下面我们详细看看各个部分是如何实现的

    1.特征提取层

    使用 VGG 结构的网络提取特征。输入为 `batchsize*32*100*3`（NHWC 格式，即 batchsize*h*w*c），输出为 `batchsize*1*25*512`（同样是 batchsize*h*w*c）。

    def _feature_sequence_extraction(self, inputdata, name):
        """Implement section 2.1 of the paper: "Feature Sequence Extraction".

        VGG-style CNN. Two ``_conv_stage`` blocks (conv + bn + relu + max_pool)
        are followed by conv/bn/relu layers. Their max-pooling uses a [2, 1]
        window and stride, so the later stages shrink only the height and keep
        the width, which becomes the sequence length.

        Shape comments below assume a batch*32*100*3 input.

        :param inputdata: image tensor in NHWC format, e.g. batch*32*100*3
        :param name: name of the enclosing variable scope
        :return: feature map, batch*1*25*512 for a batch*32*100*3 input
        """
        with tf.variable_scope(name_or_scope=name):
            # batch*32*100*3 -> batch*16*50*64
            conv1 = self._conv_stage(
                inputdata=inputdata, out_dims=64, name='conv1'
            )
            # -> batch*8*25*128
            conv2 = self._conv_stage(
                inputdata=conv1, out_dims=128, name='conv2'
            )

            # -> batch*8*25*256
            conv3 = self.conv2d(
                inputdata=conv2, out_channel=256, kernel_size=3, stride=1,
                use_bias=False, name='conv3'
            )
            bn3 = self.layerbn(
                inputdata=conv3, is_training=self._is_training, name='bn3'
            )
            relu3 = self.relu(inputdata=bn3, name='relu3')

            # -> batch*8*25*256
            conv4 = self.conv2d(
                inputdata=relu3, out_channel=256, kernel_size=3, stride=1,
                use_bias=False, name='conv4'
            )
            bn4 = self.layerbn(
                inputdata=conv4, is_training=self._is_training, name='bn4'
            )
            relu4 = self.relu(inputdata=bn4, name='relu4')
            # Pool the height only: -> batch*4*25*256
            max_pool4 = self.maxpooling(
                inputdata=relu4, kernel_size=[2, 1], stride=[2, 1],
                padding='VALID', name='max_pool4'
            )

            # -> batch*4*25*512
            conv5 = self.conv2d(
                inputdata=max_pool4, out_channel=512, kernel_size=3, stride=1,
                use_bias=False, name='conv5'
            )
            bn5 = self.layerbn(
                inputdata=conv5, is_training=self._is_training, name='bn5'
            )
            # Named after the layer it is; reusing 'bn5' here would collide
            # with the batch-norm layer's name just above.
            relu5 = self.relu(inputdata=bn5, name='relu5')

            # -> batch*4*25*512
            conv6 = self.conv2d(
                inputdata=relu5, out_channel=512, kernel_size=3, stride=1,
                use_bias=False, name='conv6'
            )
            bn6 = self.layerbn(
                inputdata=conv6, is_training=self._is_training, name='bn6'
            )
            relu6 = self.relu(inputdata=bn6, name='relu6')
            # Pool the height only: -> batch*2*25*512
            max_pool6 = self.maxpooling(
                inputdata=relu6, kernel_size=[2, 1], stride=[2, 1],
                name='max_pool6'
            )

            # 2x2 kernel with stride [2, 1] collapses the height:
            # -> batch*1*25*512
            conv7 = self.conv2d(
                inputdata=max_pool6, out_channel=512, kernel_size=2,
                stride=[2, 1], use_bias=False, name='conv7'
            )
            bn7 = self.layerbn(
                inputdata=conv7, is_training=self._is_training, name='bn7'
            )
            # Same naming fix as relu5 ('bn7' would collide with the BN layer).
            relu7 = self.relu(inputdata=bn7, name='relu7')

            return relu7

    2.特征转换为序列

    提取出的特征不能直接输入 LSTM，需要先进行序列转换：输入为 `batchsize*1*25*512`，输出为 `batchsize*25*512`（去掉高度这一维）。

    def _map_to_sequence(self, inputdata, name):
        """Convert the CNN feature map into the sequence used by the stacked LSTM.

        Implements the "map to sequence" part of the network. The feature
        map's height (axis 1, NHWC) must already be 1 and is squeezed away, so
        each column becomes one time step. This therefore determines the
        sequence length the LSTM layers expect.

        :param inputdata: feature map, e.g. batch*1*25*512
        :param name: name of the enclosing variable scope
        :return: sequence tensor, e.g. batch*25*512
        :raises ValueError: if the static height of ``inputdata`` is not 1
        """
        with tf.variable_scope(name_or_scope=name):
            shape = inputdata.get_shape().as_list()
            # The LSTM consumes one feature vector per column, so the height
            # must be exactly 1. Use an explicit check rather than ``assert``,
            # which is stripped when Python runs with -O.
            if shape[1] != 1:
                raise ValueError(
                    'H of the feature map must equal to 1, got shape {}'.format(shape)
                )
            return self.squeeze(inputdata=inputdata, axis=1, name='squeeze')

    其中用到了 self.squeeze函数,我们看看它做了什么

    @staticmethod
    def squeeze(inputdata, axis=None, name=None):
        """Thin wrapper around ``tf.squeeze``.

        Declared static because it takes no ``self`` yet is invoked as
        ``self.squeeze(inputdata=..., axis=..., name=...)``. As a plain method
        that call would bind the instance to ``inputdata`` and fail.

        :param inputdata: tensor to squeeze
        :param axis: forwarded to ``tf.squeeze`` as its ``axis`` argument
        :param name: optional op name, forwarded to ``tf.squeeze``
        :return: the squeezed tensor
        """
        return tf.squeeze(input=inputdata, axis=axis, name=name)

    3.进行LSTM获取输出序列

    输入为 `batchsize*25*512`（序列转换后的结果），输出为 `width*batchsize*n_classes`（时间维在前）。

    def _sequence_label(self, inputdata, name):
        """Implement the sequence-labelling part of the network.

        Runs ``self._layers_nums`` stacked bidirectional LSTM layers
        (``self._hidden_nums`` units per direction) over the input sequence and
        applies dropout. Each time step is then projected onto
        ``self._num_classes`` logits with one shared weight matrix (no bias).

        :param inputdata: batch-major sequence tensor, e.g. batch*25*512
        :param name: name of the enclosing variable scope
        :return: ``(rnn_out, raw_pred)``: ``rnn_out`` holds the logits as
            [width, batch, n_classes]; ``raw_pred`` is the per-step argmax
            class index, [batch, width]
        """
        with tf.variable_scope(name_or_scope=name):
            # One LSTM cell per stacked layer, for each direction.
            fw_cell_list = [
                tf.nn.rnn_cell.LSTMCell(self._hidden_nums, forget_bias=1.0)
                for _ in range(self._layers_nums)
            ]
            bw_cell_list = [
                tf.nn.rnn_cell.LSTMCell(self._hidden_nums, forget_bias=1.0)
                for _ in range(self._layers_nums)
            ]
            stack_lstm_layer, _, _ = rnn.stack_bidirectional_dynamic_rnn(
                fw_cell_list, bw_cell_list, inputdata, dtype=tf.float32
            )
            stack_lstm_layer = self.dropout(
                inputdata=stack_lstm_layer,
                keep_prob=0.5,
                is_training=self._is_training,
                name='sequence_drop_out'
            )

            # Forward and backward outputs are concatenated, so each step has
            # 2 * hidden_nums features. Size the projection from that, not from
            # the input's feature dimension: the two only coincide when the CNN
            # output depth equals 2 * hidden_nums.
            lstm_out_dims = 2 * self._hidden_nums

            # [batch, width, 2*hidden] -> [batch*width, 2*hidden] for one matmul.
            shape = tf.shape(stack_lstm_layer)
            rnn_reshaped = tf.reshape(
                stack_lstm_layer, [shape[0] * shape[1], shape[2]]
            )

            w = tf.get_variable(
                name='w',
                shape=[lstm_out_dims, self._num_classes],
                initializer=tf.truncated_normal_initializer(stddev=0.02),
                trainable=True
            )

            # Affine projection, then restore [batch, width, n_classes].
            logits = tf.matmul(rnn_reshaped, w, name='logits')
            logits = tf.reshape(
                logits, [shape[0], shape[1], self._num_classes],
                name='logits_reshape'
            )

            raw_pred = tf.argmax(
                tf.nn.softmax(logits), axis=2, name='raw_prediction'
            )

            # Swap the batch and time axes -> [width, batch, n_classes],
            # presumably for the CTC ops, which take time-major input.
            rnn_out = tf.transpose(
                logits, [1, 0, 2], name='transpose_time_major'
            )

            return rnn_out, raw_pred
    最新回复(0)