Text Classification: TextCNN (a TensorFlow 2.0 implementation)


The complete code is on GitHub.
Original TextCNN paper: Convolutional Neural Networks for Sentence Classification

The TextCNN network structure: the input sequence is embedded, passed through several parallel Conv1D layers with different kernel sizes, each followed by global max pooling; the pooled features are concatenated and fed to a final Dense softmax classifier.

A Keras implementation based on TensorFlow 2.0

Custom model

This is the approach recommended in TensorFlow 2.0: inherit from Model and define the model by subclassing.

A few points to note:

  • If you need other Layer objects or a Sequential sub-structure, they must be created and assigned in __init__().
  • If you want summary() to display the shape of each layer before the model has been fit, you need a custom function that builds the model first, similar to the build_graph function in the code below (see the minimal sketch after this list).
  • summary() lists the layers in the order in which they were assigned in __init__().
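
To illustrate the second point, here is a minimal toy sketch. TinyModel is a hypothetical example class, not part of the TextCNN code: calling summary() on a subclassed model that has never been built raises an error, so we first run a forward pass on a symbolic Keras Input, which is exactly what build_graph does below.

import tensorflow as tf

class TinyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # layers must be created and assigned in __init__
        self.dense1 = tf.keras.layers.Dense(16, activation='relu')
        self.dense2 = tf.keras.layers.Dense(2, activation='softmax')

    def call(self, inputs):
        return self.dense2(self.dense1(inputs))

model = TinyModel()
# model.summary() here would fail: the model has never been called, so shapes are unknown
_ = model(tf.keras.Input(shape=(8,)))   # forward pass on a symbolic input builds the layers
model.summary()                         # now works; layers appear in __init__ assignment order

The full textcnn.py follows.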
# -*- coding: utf-8 -*-
# @Time : 2020/4/20 14:44
# @Author : zdqzyx
# @File : textcnn.py
# @Software: PyCharm

import tensorflow as tf
from tensorflow.keras.layers import Embedding, Conv1D, Dense, Concatenate, GlobalMaxPooling1D
from tensorflow.keras import Model

class TextCNN(Model):

    def __init__(self,
                 maxlen,
                 max_features,
                 embedding_dims,
                 class_num,
                 kernel_sizes=[1, 2, 3],
                 kernel_regularizer=None,
                 last_activation='softmax'
                 ):
        '''
        :param maxlen: maximum text length
        :param max_features: vocabulary size
        :param embedding_dims: embedding dimension
        :param kernel_sizes: list of convolution window sizes, e.g. [1, 2, 3]
        :param kernel_regularizer: e.g. tf.keras.regularizers.l2(0.001)
        :param class_num: number of output classes
        :param last_activation: activation of the final Dense layer
        '''
        super(TextCNN, self).__init__()
        self.maxlen = maxlen
        self.kernel_sizes = kernel_sizes
        self.class_num = class_num
        self.embedding = Embedding(input_dim=max_features, output_dim=embedding_dims, input_length=maxlen)
        self.conv1s = []
        self.maxpools = []
        for kernel_size in kernel_sizes:
            self.conv1s.append(Conv1D(filters=128, kernel_size=kernel_size, activation='relu', kernel_regularizer=kernel_regularizer))
            self.maxpools.append(GlobalMaxPooling1D())
        self.classifier = Dense(class_num, activation=last_activation)

    def call(self, inputs, training=None, mask=None):
        if len(inputs.get_shape()) != 2:
            raise ValueError('The rank of inputs of TextCNN must be 2, but now is %d' % len(inputs.get_shape()))
        if inputs.get_shape()[1] != self.maxlen:
            raise ValueError('The maxlen of inputs of TextCNN must be %d, but now is %d' % (self.maxlen, inputs.get_shape()[1]))

        emb = self.embedding(inputs)
        conv1s = []
        for i in range(len(self.kernel_sizes)):
            c = self.conv1s[i](emb)    # (batch_size, maxlen - kernel_size + 1, filters)
            c = self.maxpools[i](c)    # (batch_size, filters)
            conv1s.append(c)
        x = Concatenate()(conv1s)      # (batch_size, len(self.kernel_sizes) * filters)
        output = self.classifier(x)
        return output

    def build_graph(self, input_shape):
        '''Custom helper: call this before model.summary() so that the layer shapes are known.
        '''
        input_shape_nobatch = input_shape[1:]
        self.build(input_shape)
        inputs = tf.keras.Input(shape=input_shape_nobatch)
        if not hasattr(self, 'call'):
            raise AttributeError("User should define 'call' method in sub-class model!")
        _ = self.call(inputs)
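
As a quick check, the class above can be used like this; the hyperparameter values match the main.py script below (this usage snippet is an added illustration, not part of the original file).

from textcnn import TextCNN

model = TextCNN(maxlen=400,
                max_features=5000,
                embedding_dims=200,
                class_num=2,
                kernel_sizes=[2, 3, 5])
model.build_graph(input_shape=(None, 400))   # build first, as noted above
model.summary()                              # layers listed in __init__ assignment order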

main

A model helper (ModelHepler) that builds the model and defines and manages the various callbacks.

  • The three main callbacks are EarlyStopping, TensorBoard, and ModelCheckpoint (a condensed sketch of the callback wiring follows this list; the full main.py is below).
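
For quick reference, this is roughly how the three callbacks are wired together before being passed to fit(). The monitor, patience, log directory, and checkpoint path values mirror those used in main.py below; paths are simplified to forward slashes here.

from tensorflow.keras.callbacks import EarlyStopping, TensorBoard, ModelCheckpoint

callbacks = [
    # stop training when val_accuracy has not improved for `patience` epochs
    EarlyStopping(monitor='val_accuracy', patience=7, mode='max'),
    # write logs for: tensorboard --logdir logs
    TensorBoard(log_dir='logs/TextCNN-epoch-10-emb-200', histogram_freq=1),
    # keep only the best weights seen so far, by val_accuracy
    ModelCheckpoint(filepath='save_model_dir/TextCNN-epoch-10-emb-200/cp-{epoch:04d}.ckpt',
                    monitor='val_accuracy', mode='max',
                    save_best_only=True, save_weights_only=True, verbose=1),
]
# model.fit(..., callbacks=callbacks, validation_data=(x_val, y_val))

The complete main.py follows.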
# -*- coding: utf-8 -*-
# @Time : 2020/4/20 14:43
# @Author : zdqzyx
# @File : main.py
# @Software: PyCharm

# ===================== set random seeds ===========================
import numpy as np
import tensorflow as tf
import random as rn
np.random.seed(0)
rn.seed(0)
tf.random.set_seed(0)
# ===================================================================

import os
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard, ModelCheckpoint
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing.sequence import pad_sequences
from textcnn import TextCNN


def checkout_dir(dir_path, do_delete=False):
    import shutil
    if do_delete and os.path.exists(dir_path):
        shutil.rmtree(dir_path)
    if not os.path.exists(dir_path):
        print(dir_path, 'make dir ok')
        os.makedirs(dir_path)


class ModelHepler:
    def __init__(self, class_num, maxlen, max_features, embedding_dims, epochs, batch_size):
        self.class_num = class_num
        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.epochs = epochs
        self.batch_size = batch_size
        self.callback_list = []
        print('Build Model...')
        self.create_model()

    def create_model(self):
        model = TextCNN(maxlen=self.maxlen,
                        max_features=self.max_features,
                        embedding_dims=self.embedding_dims,
                        class_num=self.class_num,
                        kernel_sizes=[2, 3, 5],
                        kernel_regularizer=None,
                        last_activation='softmax')
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=['accuracy'],
        )
        model.build_graph(input_shape=(None, self.maxlen))
        model.summary()
        self.model = model

    def get_callback(self, use_early_stop=True, tensorboard_log_dir='logs\\TextCNN-epoch-5', checkpoint_path="save_model_dir\\cp-model.ckpt"):
        callback_list = []
        if use_early_stop:
            # stop training when val_accuracy stops improving
            early_stopping = EarlyStopping(monitor='val_accuracy', patience=7, mode='max')
            callback_list.append(early_stopping)
        if checkpoint_path is not None:
            # save model weights
            checkpoint_dir = os.path.dirname(checkpoint_path)
            checkout_dir(checkpoint_dir, do_delete=True)
            # callback that saves the model's weights
            cp_callback = ModelCheckpoint(filepath=checkpoint_path,
                                          monitor='val_accuracy',
                                          mode='max',
                                          save_best_only=True,
                                          save_weights_only=True,
                                          verbose=1,
                                          period=2,  # save every 2 epochs (`period` is deprecated in newer TF versions)
                                          )
            callback_list.append(cp_callback)
        if tensorboard_log_dir is not None:
            # view with: tensorboard --logdir logs/TextCNN-epoch-5
            checkout_dir(tensorboard_log_dir, do_delete=True)
            tensorboard_callback = TensorBoard(log_dir=tensorboard_log_dir, histogram_freq=1)
            callback_list.append(tensorboard_callback)
        self.callback_list = callback_list

    def fit(self, x_train, y_train, x_val, y_val):
        print('Train...')
        self.model.fit(x_train, y_train,
                       batch_size=self.batch_size,
                       epochs=self.epochs,
                       verbose=2,
                       callbacks=self.callback_list,
                       validation_data=(x_val, y_val))

    def load_model(self, checkpoint_path):
        checkpoint_dir = os.path.dirname(checkpoint_path)
        latest = tf.train.latest_checkpoint(checkpoint_dir)
        print('restoring model from:', latest)
        # create a fresh model instance if needed
        # model = self.create_model()
        # load the previously saved weights
        self.model.load_weights(latest)


# ================ params =========================
class_num = 2
maxlen = 400
embedding_dims = 200
epochs = 10
batch_size = 128
max_features = 5000

MODEL_NAME = 'TextCNN-epoch-10-emb-200'

use_early_stop = True
tensorboard_log_dir = 'logs\\{}'.format(MODEL_NAME)
# checkpoint_path = "save_model_dir\\{}\\cp-{epoch:04d}.ckpt".format(MODEL_NAME, '')
checkpoint_path = 'save_model_dir\\' + MODEL_NAME + '\\cp-{epoch:04d}.ckpt'
# ====================================================================

print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print('Pad sequences (samples x time)...')
x_train = pad_sequences(x_train, maxlen=maxlen, padding='post')
x_test = pad_sequences(x_test, maxlen=maxlen, padding='post')
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)

model_hepler = ModelHepler(class_num=class_num,
                           maxlen=maxlen,
                           max_features=max_features,
                           embedding_dims=embedding_dims,
                           epochs=epochs,
                           batch_size=batch_size
                           )
model_hepler.get_callback(use_early_stop=use_early_stop, tensorboard_log_dir=tensorboard_log_dir, checkpoint_path=checkpoint_path)
model_hepler.fit(x_train=x_train, y_train=y_train, x_val=x_test, y_val=y_test)
print('Test...')
result = model_hepler.model.predict(x_test)
test_score = model_hepler.model.evaluate(x_test, y_test,
                                         batch_size=batch_size)
print("test loss:", test_score[0], "test accuracy", test_score[1])


model_hepler = ModelHepler(class_num=class_num,
                           maxlen=maxlen,
                           max_features=max_features,
                           embedding_dims=embedding_dims,
                           epochs=epochs,
                           batch_size=batch_size
                           )
model_hepler.load_model(checkpoint_path=checkpoint_path)
# re-evaluate the restored model
loss, acc = model_hepler.model.evaluate(x_test, y_test, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))