TensorFlow 2.0 中文手寫字識別(漢字OCR)
TensorFlow 2.0 中文手寫字識別(漢字OCR)
在開始之前,必須要說明的是,本教程完全基於TensorFlow2.0 介面編寫,請誤與其他古老的教程混為一談,本教程除了手把手教大家完成這個挑戰性任務之外,更多的會教大家如何分析整個調參過程的思考過程,力求把人工智慧演算法工程師日常的工作通過這個例子毫無保留的展示給大家。另外,我們建立了一個高端演算法分享平台,希望得到大家的支持:http://manaai.cn , 也歡迎大家來我們的AI社區交流: http://talk.strangeai.pro
還在玩minist?fashionmnist?不如來嘗試一下類別多大3000+的漢字手寫識別吧!!雖然以前有一些文章教大家如何操作,但是大多比較古老,這篇文章將用全新的TensorFlow 2.0 來教大家如何搭建一個中文OCR系統!
讓我們來看一下,相比於簡單minist識別,漢字識別具有哪些難點:
- 搜索空間空前巨大,我們使用的數據集1.0版本漢字就多大3755個,如果加上1.1版本一起,總共漢字可以分為多達7599+個類別!這比10個阿拉伯字母識別難度大很多!
- 數據集處理挑戰更大,相比於mnist和fasionmnist來說,漢字手寫字體識別數據集非常少,而且僅有的數據集數據預處理難度非常大,非常不直觀,但是,千萬別嚇到,相信你看完本教程一定會收貨滿滿!
- 漢字識別更考驗選手的建模能力,還在分類花?分類貓和狗?隨便搭建的幾層在搜索空間巨大的漢字手寫識別里根本不work!你現在是不是想用很深的網路躍躍欲試?更深的網路在這個任務上可能根本不可行!!看完本教程我們就可以一探究竟!總之一句話,模型太簡單和太複雜都不好,甚至會發散!(想親身體驗模型訓練發散抓狂的可以來嘗試一下!)。
但是,挑戰這個任務也有很多好處:
- 本教程基於TensorFlow2.0,從數據預處理,圖片轉Tensor以及Tensor的一系列騷操作都包含在內!做完本任務相信你會對TensorFlow2.0 API有一個很深刻的認識!
- 如果你是新手,通過這個教程你完全可以深入體會一下調參(或者說隨意修改網路)的糾結性和蛋疼性!
本項目實現了基於CNN的中文手寫字識別,並且採用標準的tensorflow 2.0 api 來構建!相比對簡單的字母手寫識別,本項目更能體現模型設計的精巧性和數據增強的熟練操作性,並且最終設計出來的模型可以直接應用於工業場合,比如 票據識別, 手寫文本自動掃描 等,相比於百度api介面或者QQ介面等,具有可優化性、免費性、本地性等優點。
數據準備
在開始之前,先介紹一下本項目所採用的數據信息。我們的數據全部來自於CASIA的開源中文手寫字數據集,該數據集分為兩部分:
- CASIA-HWDB:離線的HWDB,我們僅僅使用1.0-1.2,這是單字的數據集,2.0-2.2是整張文本的數據集,我們暫時不用,單字裡面包含了約7185個漢字以及171個英文字母、數字、標點符號等;
- CASIA-OLHWDB:在線的HWDB,格式一樣,包含了約7185個漢字以及171個英文字母、數字、標點符號等,我們不用。
其實你下載1.0的train和test差不多已經夠了,可以直接運行 dataset/get_hwdb_1.0_1.1.sh 下載。原始數據下載鏈接點擊這裡.由於原始數據過於複雜,我們使用一個類來封裝數據讀取過程,這是我們展示的效果:
<p align="center">

</p>
看到這麼密密麻麻的文字相信連人類都.... 開始頭疼了,這些複雜的文字能夠通過一個神經網路來識別出來??答案是肯定的.... 不有得感嘆一下神經網路的強大。。上面的部分文字識別出來的結果是這樣的:
<p align="center">

</p>
關於數據的處理部分,從伺服器下載到的原始數據是 trn_gnt.zip 解壓之後是 gnt.alz, 需要再次解壓得到一個包含 gnt文件的文件夾。裡面每一個gnt文件都包含了若干個漢字及其標註。直接處理比較麻煩,也不方便抽取出圖片再進行操作,雖然轉為圖片存入文件夾比較直觀,但是不適合批量讀取和訓練, 後面我們統一轉為tfrecord進行訓練。
更新:實際上,由於單個漢字圖片其實很小,差不多也就最大80x80的大小,這個大小不適合轉成圖片保存到本地,因此我們將hwdb原始的二進位保存為tfrecord。同時也方便後面訓練,可以直接從tfrecord讀取圖片進行訓練。

在我們存儲完成的時候大概處理了89萬個漢字,總共漢字的空間是3755個漢字。由於我們暫時僅僅使用了1.0,所以還有大概3000個漢字沒有加入進來,但是處理是一樣。使用本倉庫來生成你的tfrecord步驟如下:
cd dataset && python3 convert_to_tfrecord.py, 請注意我們使用的是tf2.0;- 你需要修改對應的路徑,等待生成完成,大概有89萬個example,如果1.0和1.1都用,那估計得double。
模型構建
關於我們採用的OCR模型的構建,我們構建了3個模型分別做測試,三個模型的複雜度逐漸的複雜,網路層數逐漸深入。但是到最後發現,最複雜的那個模型竟然不收斂。這個其中一個稍微簡單模型的訓練過程:

大家可以看到,準確率可以在短時間內達到87%非常不錯,測試集的準確率大概在40%,由於測試集中的樣本在訓練集中完全沒有出現,相對訓練集的準確率來講偏低。可能原因無外乎兩個,一個事模型泛化性能不強,另外一個原因是訓練還不夠。
不過好在這個簡單的模型也能達到訓練集90%的準確率,its a good start. 讓我們來看一下如何快速的構建一個OCR網路模型:
def build_net_003(input_shape, n_classes):
model = tf.keras.Sequential([
layers.Conv2D(input_shape=input_shape, filters=32, kernel_size=(3, 3), strides=(1, 1),
padding=same, activation=relu),
layers.MaxPool2D(pool_size=(2, 2), padding=same),
layers.Conv2D(filters=64, kernel_size=(3, 3), padding=same),
layers.MaxPool2D(pool_size=(2, 2), padding=same),
?
layers.Flatten(),
layers.Dense(n_classes, activation=softmax)
])
return model
這是我們使用keras API構建的一個模型,它足夠簡單,僅僅包含兩個卷積層以及兩個maxpool層。下面我們讓大家知道,即便是再簡單的模型,有時候也能發揮出巨大的用處,對於某些特定的問題可能比更深的網路更有用途。關於這部分模型構建大家只要知道這麼幾點:
- 如果你只是構建序列模型,沒有太fancy的跳躍鏈接,你可以直接用
keras.Sequential來構建你的模型; - Conv2D中最好指定每個參數的名字,不要省略,否則別人不知道你的寫的事輸入的通道數還是filters。
最後,在你看完本篇博客後,並準備自己動手復現這個教程的時候, 可以思考一下為什麼下面這個模型就發散了呢?(僅僅稍微複雜一點):
?
def build_net_002(input_shape, n_classes):
model = tf.keras.Sequential([
layers.Conv2D(input_shape=input_shape, filters=64, kernel_size=(3, 3), strides=(1, 1),
padding=same, activation=relu),
layers.MaxPool2D(pool_size=(2, 2), padding=same),
layers.Conv2D(filters=128, kernel_size=(3, 3), padding=same),
layers.MaxPool2D(pool_size=(2, 2), padding=same),
layers.Conv2D(filters=256, kernel_size=(3, 3), padding=same),
layers.MaxPool2D(pool_size=(2, 2), padding=same),
?
layers.Flatten(),
layers.Dense(1024, activation=relu),
layers.Dense(n_classes, activation=softmax)
])
return model
數據輸入
其實最複雜的還是數據準備過程啊。這裡著重說一下,我們的數據存入tfrecords中的事image和label,也就是這麼一個example:
example = tf.train.Example(features=tf.train.Features(
feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
image: tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
width: tf.train.Feature(int64_list=tf.train.Int64List(value=[w])),
height: tf.train.Feature(int64_list=tf.train.Int64List(value=[h])),
}))
然後讀取的時候相應的讀取即可,這裡告訴大家幾點坑爹的地方:
- 將numpyarray的bytes存入tfrecord跟將文件的bytes直接存入tfrecord解碼的方式事不同的,由於我們的圖片數據不是來自於本地文件,所以我們使用了一個tobytes()方法存入的事numpy array的bytes格式,它實際上並不包含維度信息,所以這就是坑爹的地方之一,如果你不同時存儲width和height,你後面讀取的時候便無法知道維度,存儲tfrecord順便存儲圖片長寬事一個好的習慣.
- 關於不同的存儲方式解碼的方法有坑爹的地方,比如這裡我們存儲numpy array的bytes,通常情況下,你很難知道如何解碼。。(不看本教程應該很多人不知道)
最後load tfrecord也就比較直觀了:
def parse_example(record):
features = tf.io.parse_single_example(record,
features={
label:
tf.io.FixedLenFeature([], tf.int64),
image:
tf.io.FixedLenFeature([], tf.string),
})
img = tf.io.decode_raw(features[image], out_type=tf.uint8)
img = tf.cast(tf.reshape(img, (64, 64)), dtype=tf.float32)
label = tf.cast(features[label], tf.int64)
return {image: img, label: label}
?
?
def parse_example_v2(record):
"""
latest version format
:param record:
:return:
"""
features = tf.io.parse_single_example(record,
features={
width:
tf.io.FixedLenFeature([], tf.int64),
height:
tf.io.FixedLenFeature([], tf.int64),
label:
tf.io.FixedLenFeature([], tf.int64),
image:
tf.io.FixedLenFeature([], tf.string),
})
img = tf.io.decode_raw(features[image], out_type=tf.uint8)
# we can not reshape since it stores with original size
w = features[width]
h = features[height]
img = tf.cast(tf.reshape(img, (w, h)), dtype=tf.float32)
label = tf.cast(features[label], tf.int64)
return {image: img, label: label}
?
?
def load_ds():
input_files = [dataset/HWDB1.1trn_gnt.tfrecord]
ds = tf.data.TFRecordDataset(input_files)
ds = ds.map(parse_example)
return ds
這個v2的版本就是兼容了新的存入長寬的方式,因為我第一次生成的時候就沒有保存。。。最後入坑了。注意這行代碼:
img = tf.io.decode_raw(features[image], out_type=tf.uint8)
它是對raw bytes進行解碼,這個解碼跟從文件讀取bytes存入tfrecord的有著本質的不同。同時注意type的變化,這裡以unit8的方式解碼,因為我們存儲進去的就是uint8.
訓練過程
不瞞你說,我一開始寫了一個很複雜的模型,訓練了大概一個晚上結果準確率0.00012, 發散了。後面改成了更簡單的模型才收斂。整個過程的訓練pipleline:
def train():
all_characters = load_characters()
num_classes = len(all_characters)
logging.info(all characters: {}.format(num_classes))
train_dataset = load_ds()
train_dataset = train_dataset.shuffle(100).map(preprocess).batch(32).repeat()
?
val_ds = load_val_ds()
val_ds = val_ds.shuffle(100).map(preprocess).batch(32).repeat()
?
for data in train_dataset.take(2):
print(data)
?
# init model
model = build_net_003((64, 64, 1), num_classes)
model.summary()
logging.info(model loaded.)
?
start_epoch = 0
latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
if latest_ckpt:
start_epoch = int(latest_ckpt.split(-)[1].split(.)[0])
model.load_weights(latest_ckpt)
logging.info(model resumed from: {}, start at epoch: {}.format(latest_ckpt, start_epoch))
else:
logging.info(passing resume since weights not there. training from scratch)
?
if use_keras_fit:
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[accuracy])
callbacks = [
tf.keras.callbacks.ModelCheckpoint(ckpt_path,
save_weights_only=True,
verbose=1,
period=500)
]
try:
model.fit(
train_dataset,
validation_data=val_ds,
validation_steps=1000,
epochs=15000,
steps_per_epoch=1024,
callbacks=callbacks)
except KeyboardInterrupt:
model.save_weights(ckpt_path.format(epoch=0))
logging.info(keras model saved.)
model.save_weights(ckpt_path.format(epoch=0))
model.save(os.path.join(os.path.dirname(ckpt_path), cn_ocr.h5))
在本系列教程開篇之際,我們就立下了幾條準則,其中一條就是handle everything, 從這裡就能看出,它事一個很穩健的訓練代碼,同事也很自動化:
- 自動尋找之前保存的最新模型;
- 自動保存模型;
- 捕捉ctrl + c事件保存模型。
- 支持斷點續訓練
大家在以後編寫訓練代碼的時候其實可以保持這個好的習慣。
OK,整個模型訓練起來之後,可以在短時間內達到95%的準確率:

效果還是很不錯的!
模型測試
最後模型訓練完了,時候測試一下模型效果到底咋樣。我們使用了一些簡單的文字來測試:
這個字寫的還真的。。。。具有鬼神之勢。相信普通人類大部分字都能認出來,不過有些字還真的。。。。不好認。看看神經網路的表現怎麼樣!

這是大概2000次訓練的結果, 基本上能識別出來了!神經網路的認字能力還不錯的! 收工!
總結
通過本教程,我們完成了使用tensorflow 2.0全新的API搭建一個中文漢字手寫識別系統。模型基本能夠實現我們想要的功能。要知道,這個模型可是在搜索空間多大3755的類別當中準確的找到最相似的類別!!通過本實驗,我們有幾點心得:
- 神經網路不僅僅是在學習,它具有一定的想像力!!比如它的一些看著很像的字:拜-佯, 扮-撈,笨-苯.... 這些字如果手寫出來,連人都比較難以辨認!!但是大家要知道這些字在類別上並不是相領的!也就是說,模型具有一定的聯想能力!
- 不管問題多複雜,要敢於動手、善於動手。
最後希望大家對本文點個贊,編寫教程不容易。希望大家多多支持。笨教程將支持為大家輸出全新的tensorflow2.0教程!歡迎關注!!
本文所有代碼開源在 (如果由於git倉庫定期清理原因找不到項目,參考這個平台:http://manaai.cn):
https://github.com/jinfagang/ocrcn_tf2.git記得隨手star哦!!
我們的AI社區:
奇點AI社區
全球最大的開源AI代碼平台:
神力AI(MANA)-國內最大的AI代碼平台
推薦閱讀:
※三分鐘訓練眼球追蹤術,AI就知道你在看哪個妹子
※3-04 TensorFlow入門-輸入與優化
※神經網路參數選擇(keras,tensorflow)
※[L2]TensorFlow模型持久化~模型載入
※TensorFlow小介紹 從0到1學習Unity開發AI應用
TAG:人工智慧 | 深度學習(DeepLearning) | TensorFlow |
