NLP (9) Attention

xiaoxiao · 2025-07-09

The Attention module has to be written as a custom Keras layer. Briefly, the Attention module combines the LSTM outputs from the n time steps into a single vector, which is then fed into the next RNN. When I was going through Andrew Ng's course earlier, I drew a diagram of this.
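
In formulas (my own notation, just restating what the custom layer below computes): given the LSTM outputs h_1, ..., h_n and the layer's single trainable weight vector w,

e_t = \tanh(w^\top h_t), \qquad
\alpha_t = \frac{\exp(e_t)}{\sum_{k=1}^{n} \exp(e_k)}, \qquad
c = \sum_{t=1}^{n} \alpha_t h_t

The weighted sum c is the vector that gets passed on to the next part of the network.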

from keras import backend as K
from keras import initializers
from keras.layers import Layer, Bidirectional, LSTM


class AttentionLayer(Layer):
    def __init__(self, **kwargs):
        self.init = initializers.get('normal')
        super(AttentionLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # one scoring weight per feature dimension
        self.W = self.add_weight(name='att_weight',
                                 shape=(input_shape[-1],),
                                 initializer=self.init,
                                 trainable=True)
        super(AttentionLayer, self).build(input_shape)

    def call(self, x, mask=None):
        # x: (batch, time_steps, features)
        e = K.tanh(K.dot(x, K.expand_dims(self.W)))       # (batch, time_steps, 1)
        ai = K.exp(e)
        weights = ai / K.sum(ai, axis=1, keepdims=True)   # softmax over time steps
        weighted_input = x * weights                      # broadcast over features
        return K.sum(weighted_input, axis=1)              # (batch, features)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])


# It is used just like any other layer
l_lstm = Bidirectional(LSTM(100, return_sequences=True))(embedded_seq)
attention = AttentionLayer()(l_lstm)
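
For context, here is a minimal sketch of how the layer could be wired into a complete classification model; the sequence length, vocabulary size, embedding dimension and sigmoid output below are made-up placeholders for illustration, not values from the original post.

from keras.layers import Input, Embedding, Dense
from keras.models import Model

# placeholder sizes, purely for illustration
MAX_LEN = 100
VOCAB_SIZE = 20000
EMB_DIM = 128

seq_input = Input(shape=(MAX_LEN,), dtype='int32')
embedded_seq = Embedding(VOCAB_SIZE, EMB_DIM, input_length=MAX_LEN)(seq_input)
l_lstm = Bidirectional(LSTM(100, return_sequences=True))(embedded_seq)  # (batch, MAX_LEN, 200)
attention = AttentionLayer()(l_lstm)                                    # (batch, 200)
preds = Dense(1, activation='sigmoid')(attention)

model = Model(seq_input, preds)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])

Because return_sequences=True keeps every time step, the attention layer receives the full (batch, MAX_LEN, 200) tensor and collapses the time axis itself, so no pooling layer is needed before the Dense output.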