Text Classification: BiRNN + Attention (TensorFlow 2.0 Implementation)

BiRNN+Attention

The complete code is on GitHub.
The attention mechanism here follows the paper Feed-Forward Networks with Attention Can Solve Some Long-Term Memory Problems.

The network architecture implemented here:

A Keras implementation based on TensorFlow 2.0

Custom Attention Layer

This is the approach recommended in TensorFlow 2.0: subclass Layer to define your own custom layer.

A few points to note:

  • If you need other Layer structures or a Sequential block, create them in __init__()
  • Create the weight parameters in build(); every weight must be given a name
    • If a weight has no name, training crashes at the second epoch with: AttributeError: 'NoneType' object has no attribute 'replace'
  • Write the computation logic in call()
  • The Attention implemented here treats the GRU output at each step as key and value, and adds a trainable parameter vector W as the query. Its purpose is to compute a weight for each step's GRU output; the weighted sum of those outputs is the Attention result (the corresponding formulas are sketched right after this list).
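For reference, the computation in call() below can be written as follows (feed-forward attention in the sense of Raffel et al.; $h_t$ is the GRU output at step $t$, and $W$, $b$ are the weights created in build()):

$$e_t = \tanh(h_t^{\top} W + b_t), \qquad \alpha_t = \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)}, \qquad c = \sum_{t=1}^{T} \alpha_t h_t$$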
# -*- coding: utf-8 -*-
# @Time : 2020/4/21 13:55
# @Author : zdqzyx
# @File : attention.py
# @Software: PyCharm

from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras.layers import Layer
import tensorflow as tf

class Attention(Layer):

    def __init__(self,
                 W_regularizer=None,
                 b_regularizer=None,
                 W_constraint=None,
                 b_constraint=None,
                 bias=True,
                 **kwargs):
        """
        Keras Layer that implements an Attention mechanism for temporal data.
        Supports Masking.
        Follows the work of Raffel et al. [https://arxiv.org/abs/1512.08756]
        # Input shape
            3D tensor with shape: `(samples, steps, features)`.
        # Output shape
            2D tensor with shape: `(samples, features)`.
        :param kwargs:
        Just put it on top of an RNN Layer (GRU/LSTM/SimpleRNN) with return_sequences=True.
        The dimensions are inferred based on the output shape of the RNN.
        Example:
            # 1
            model.add(LSTM(64, return_sequences=True))
            model.add(Attention())
            # next add a Dense layer (for classification/regression) or whatever...
            # 2
            hidden = LSTM(64, return_sequences=True)(words)
            sentence = Attention()(hidden)
            # next add a Dense layer (for classification/regression) or whatever...
        """
        super(Attention, self).__init__()
        self.bias = bias
        self.init = initializers.get('glorot_uniform')

    def build(self, input_shape):
        '''
        :param input_shape:
        :return:
        '''
        self.output_dim = input_shape[-1]
        self.W = self.add_weight(
            name='{}_W'.format(self.name),
            shape=(input_shape[2], 1),
            initializer=self.init,
            trainable=True
        )
        if self.bias:
            self.b = self.add_weight(
                name='{}_b'.format(self.name),
                shape=(input_shape[1], 1),
                initializer='zero',
                trainable=True
            )
        else:
            self.b = None

        self.built = True

    def compute_mask(self, inputs, mask=None):
        return None

    def call(self, inputs, mask=None):
        # (N, step, d) x (d, 1) ==> (N, step, 1)
        e = tf.matmul(inputs, self.W)
        if self.bias:
            e += self.b
        e = tf.tanh(e)
        a = tf.nn.softmax(e, axis=1)
        # (N, step, d) * (N, step, 1) ==> (N, step, d)
        c = inputs * a
        # (N, d)
        c = tf.reduce_sum(c, axis=1)
        return c

    def get_config(self):
        return {'units': self.output_dim}


if __name__ == '__main__':
    x = tf.ones((2, 5, 10))
    att = Attention()
    y = att(x)
    print(y.shape)
    print(y)
    print(att.get_config())

Building a Custom Model

  • One thing worth noting: you are allowed to define a Sequential that wraps a frequently used block. The point_wise_feed_forward_network() function below, for example, wraps n fully connected layers; the wrapped block is then instantiated in the custom model's __init__() (a small equivalence sketch follows).
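As a quick illustration (the sizes [128, 64] here are hypothetical, not part of the original code): calling the helper with dense_size=[128, 64] builds the same stack as writing the Sequential out by hand, and because it is assigned to self.ffn inside __init__(), Keras tracks its weights as part of the parent model automatically.

import tensorflow as tf
from tensorflow.keras.layers import Dense

# What point_wise_feed_forward_network([128, 64]) expands to:
ffn = tf.keras.Sequential([
    Dense(128, activation='relu'),
    Dense(64, activation='relu'),
])

# Applied to the attention output: (batch, 256) -> (batch, 64),
# where 256 = 2 * 128 from the concatenated bidirectional GRU.
x = tf.ones((2, 256))
print(ffn(x).shape)  # (2, 64)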
# -*- coding: utf-8 -*-
# @Time : 2020/4/21 13:50
# @Author : zdqzyx
# @File : text_birnn_att.py
# @Software: PyCharm


import tensorflow as tf
from tensorflow.keras.layers import Embedding, Dense, GRU, Bidirectional
from tensorflow.keras import Model
from attention import Attention

def point_wise_feed_forward_network(dense_size):
    ffn = tf.keras.Sequential()
    for size in dense_size:
        ffn.add(Dense(size, activation='relu'))
    return ffn

class TextBiRNNAtt(Model):

    def __init__(self,
                 maxlen,
                 max_features,
                 embedding_dims,
                 class_num,
                 last_activation='softmax',
                 dense_size=None):
        '''
        :param maxlen: maximum text length
        :param max_features: vocabulary size
        :param embedding_dims: embedding dimension
        :param class_num:
        :param last_activation:
        '''
        super(TextBiRNNAtt, self).__init__()
        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.class_num = class_num
        self.last_activation = last_activation
        self.dense_size = dense_size

        self.embedding = Embedding(input_dim=self.max_features, output_dim=self.embedding_dims, input_length=self.maxlen)
        self.bi_rnn = Bidirectional(layer=GRU(units=128, activation='tanh', return_sequences=True), merge_mode='concat')  # LSTM or GRU
        self.attention = Attention()
        if self.dense_size is not None:
            self.ffn = point_wise_feed_forward_network(dense_size)
        self.classifier = Dense(self.class_num, activation=self.last_activation)

    def call(self, inputs, training=None, mask=None):
        if len(inputs.get_shape()) != 2:
            raise ValueError('The rank of inputs of TextBiRNNAtt must be 2, but now is {}'.format(inputs.get_shape()))
        if inputs.get_shape()[1] != self.maxlen:
            raise ValueError('The maxlen of inputs of TextBiRNNAtt must be %d, but now is %d' % (self.maxlen, inputs.get_shape()[1]))

        emb = self.embedding(inputs)
        x = self.bi_rnn(emb)
        x = self.attention(x)
        if self.dense_size is not None:
            x = self.ffn(x)
        output = self.classifier(x)
        return output

    def build_graph(self, input_shape):
        # Workaround so that summary() can be printed for a subclassed model
        input_shape_nobatch = input_shape[1:]
        self.build(input_shape)
        inputs = tf.keras.Input(shape=input_shape_nobatch)
        if not hasattr(self, 'call'):
            raise AttributeError("User should define 'call' method in sub-class model!")
        _ = self.call(inputs)

if __name__ == '__main__':
    model = TextBiRNNAtt(maxlen=400,
                         max_features=5000,
                         embedding_dims=100,
                         class_num=2,
                         last_activation='softmax',
                         # dense_size=[128, 64],
                         dense_size=None)
    model.build_graph(input_shape=(None, 400))
    model.summary()
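As a minimal end-to-end sketch (not part of the original repository): the model above can be trained on the built-in Keras IMDB dataset. The preprocessing below (num_words=5000, maxlen=400) is chosen to match the __main__ example, and sparse_categorical_crossentropy is assumed so that the integer 0/1 labels work with the two-way softmax output.

import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from text_birnn_att import TextBiRNNAtt

max_features, maxlen = 5000, 400

# Load the IMDB reviews and pad every sequence to a fixed length
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)
x_train = pad_sequences(x_train, maxlen=maxlen)
x_test = pad_sequences(x_test, maxlen=maxlen)

model = TextBiRNNAtt(maxlen=maxlen,
                     max_features=max_features,
                     embedding_dims=100,
                     class_num=2,
                     last_activation='softmax',
                     dense_size=None)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train,
          batch_size=128,
          epochs=3,
          validation_data=(x_test, y_test))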