任書瑋
Data Scientists Playground
9 min readDec 18, 2019

--

驗證碼辨識實作-part2 model

這篇會介紹用的 nn model, 主要思考方向是用 CNN 先處理過圖片, 再用CNN得到的feature 做後續處理, 在這部分我用了三個不同的方法, 包含 Lstm 和Transformer, 最後使用 CTC 的方始訓練, 如果對 CTC 不了解的可以先去參考

如果對Transformer有不了解的可以參考, 也可以參考我之前寫的文章

以下的程式基於tensorflow 1.14寫的

首先我們的batch 圖片讀進來的維度是(batch, 120, 20, 1), 這邊 Channel 只有1 的原因是前處理時直接把圖片變灰階, CNN的處理為

conv = tf.layers.conv2d(x, filters=8, kernel_size=5, strides=(1, 1), padding="same", activation=tf.nn.relu)
conv = tf.nn.max_pool(conv, ksize=3, strides=2, padding="VALID")
conv = tf.nn.lrn(conv)
conv = tf.layers.conv2d(conv, filters=16, kernel_size=3, strides=(1, 1), padding="same", activation=tf.nn.relu)
conv = tf.nn.max_pool(conv, ksize=2, strides=2, padding="VALID")
conv = tf.nn.lrn(conv)
conv = tf.layers.conv2d(conv, filters=32, kernel_size=3, strides=(1, 1), padding="same", activation=tf.nn.relu)
conv = tf.nn.max_pool(conv, ksize=2, strides=2, padding="VALID")
conv = tf.nn.lrn(conv)
conv = tf.layers.conv2d(conv, filters=32, kernel_size=3, strides=(1, 1), padding="same", activation=tf.nn.relu)
conv = tf.nn.max_pool(conv, ksize=2, strides=2, padding="VALID")
conv = tf.nn.lrn(conv)
conv = tf.layers.conv2d(conv, filters=32, kernel_size=3, strides=(1, 1), padding="same", activation=tf.nn.relu)
conv = tf.nn.max_pool(conv, ksize=2, strides=(2,1), padding="VALID")
conv = tf.nn.lrn(conv)

使得維度變成(batch, 3, 2, 32), 接下來我們先reshap處理

_, w, h, c = conv.get_shape().as_list()
enc = tf.reshape(conv, (-1, w * h, c))

變成維度(batch, 6, 32), 這裏的6就是我們的序列長度

接下來使用三種方法, 但最後維度都會是(batch, 6, 27), 這裏用27的原因是多了 blank 符號, numLabel 也就是27, 且blank 使用 index 0 , 這裏額外說一下label需要 padding 時也是用 index 0

首先是法一, 直接用dense

f = tf.layers.dense(enc, 64, activation='relu')
dec2word = tf.layers.dense(f, numLabel)

法二使用bi-Lstm, 最後把 output_fw, output_bw 結合在一起, 再dense

cell_fw = get_lstm_cell(64)
cell_bw = get_lstm_cell(64)
init_state_fw = cell_fw.zero_state(tf.shape(enc)[0], dtype=tf.float32)
init_state_bw = cell_bw.zero_state(tf.shape(enc)[0], dtype=tf.float32)
(output_fw, output_bw), (_, _) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, enc, initial_state_fw=init_state_fw, initial_state_bw=init_state_bw, dtype=tf.float32)
encoder_output = tf.concat([output_fw, output_bw], axis=-1)
dec2word = tf.layers.dense(encoder_output, numLabel)

法三使用 Transformer, 這裏可以加入位置編碼, 可以先參考這篇論文裡的Adding pixel coordinates to image features.

以下是我位置編碼的實作方式

batchSize = tf.shape(conv)[0]
idxH = tf.range(0, h, 1)
idxH = tf.tile(idxH, [w])
idxH = tf.cast(idxH, tf.int32)
# [0~(h-1)] * w
idxW = tf.range(0, w, 1)
idxW = tf.tile(idxW, [h])
idxW = tf.sort(idxW)
idxW = tf.cast(idxW, tf.int32)
# [0]*h, [1]*h, ... [(w-1)]*h
idxH = tf.one_hot(idxH, depth=h, dtype=tf.float32) # [w*h, h]
idxH = tf.expand_dims(idxH, 0)
idxW = tf.one_hot(idxW, depth=w, dtype=tf.float32) # [w*h, w]
idxW = tf.expand_dims(idxW, 0)
batchSize = tf.shape(enc)[0]
idxW = tf.tile(idxW, [batchSize, 1, 1])
idxH = tf.tile(idxH, [batchSize, 1, 1])
# Adding pixel coordinates to image features.
# https://arxiv.org/abs/1704.03549
enc = tf.concat([enc, idxW, idxH], axis=-1)

接下來

for i in range(2):
enc, att = multihead_attention(enc, enc, enc)
enc = ff(enc, 32)
dec2word = tf.layers.dense(enc, numLabel)

最後計算loss時, 因為圖片大小固定, 所以logit_length裡的值也是固定, label_length 會依據每張圖片裡的字母數量不同而改變

yl = length(labels)
_, l, _ = dec2word.get_shape().as_list()
loss = tf.nn.ctc_loss_v2(
labels=labels,
logits=tf.transpose(dec2word, [1, 0, 2]),
label_length=yl,
logit_length=tf.ones_like(yl) * l,
logits_time_major=True,
)
loss = tf.reduce_mean(loss)

怎樣計算 label_length 可以使用以下的方式, 先把非0的變成1, 最後加總就是字母數了

'''
Args:
sequence: A tensor with (N, maxlengh). padding is 0
Returns:
A tensor with the (N).
'''
used = tf.sign(sequence)
length = tf.reduce_sum(used, 1)
length = tf.cast(length, tf.int32)

上面使用 CTC loss訓練, 其實也可以直接使用padding來訓練, model 最後的維度是(batch, 6, 27), 這裏把 index 0 當成 padding, 假設有一張圖片的label原本是rdg, 實際在訓練時就用rdg000代表, 這裏直接補padding到序列長度, 也就是6, 我們的loss就用最常見的 softmax_cross_entropy_with_logits 就好

loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels_, logits=dec2word)

Reference

https://arxiv.org/abs/1704.03549

--

--