[L2]TensorFlow模型持久化~模型載入
前面介紹了模型的保存:
觸摸壹縷陽光:[L1]TensorFlow模型持久化~模型保存
通過TensorFlow提供tf.train.Saver類提供的save函數保存模型,生成對應的四個文件,因為TensorFlow將計算圖的結構以及圖上的變數參數值分開保存,這樣能夠為模型的載入提供方便的擴展。
1.模型載入
由於保存模型的時候TensorFlow將計算圖的結構以及計算圖上的變數參數值分開保存。所以載入模型我從計算圖的結構和計算圖上的變數參數值分別考慮。
下面還是使用[L1]TensorFlow模型持久化~模型保存中簡單的加法程序作為案例:
import tensorflow as tf
#聲明兩個變數並計算他們的和
a = tf.Variable(tf.constant(1.0,shape = [1]),name = "a")
b = tf.Variable(tf.constant(2.0,shape = [1]),name = "b")
result = a + b
#聲明tf.train.Saver()類
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#將模型保存到指定的文件中
saver.save(sess,"./model/add_model.ckpt")
對應生成的四個文件如下圖所示:

- 僅載入模型中保存的變數
在[L1]TensorFlow模型持久化~模型保存中我們也提到了,add_model.ckpt.data-00000-of-00001文件是保存TensorFlow當前變數值,而add_model.ckpt.index文件中保存的是TensorFlow當前的變數名,所以如果要載入模型中保存的變數的時候,一定不要刪除這兩個文件。
TensorFlow同樣提供了tf.train.Saver類的restore函數來載入保存的變數。前面提到保存模型時候的變數參數是依附在計算圖的結構上的,但此時我們僅僅將保存模型的變數參數載入進來,並沒有載入模型的計算圖,所以如果我們想要正常的載入保存模型的變數參數的話,就需要定義一個和保存模型時候一模一樣的計算圖結構。
所以如果想要載入變數的話,首先要定義一個和保存時候模型的結構相同的計算圖:
import tensorflow as tf
# 聲明兩個值為常數0.0的變數
a = tf.Variable(tf.constant(0.0, shape=[1]), name="a")
b = tf.Variable(tf.constant(0.0, shape=[1]), name="b")
result = a + b
saver = tf.train.Saver()
with tf.Session() as sess:
# 全局變數進行初始化
sess.run(tf.global_variables_initializer())
#將指定文件中的變數載入到模型中
saver.restore(sess, "./model/add_model.ckpt")
print("a = %d,b = %d, a + b = %d"%(sess.run(a),sess.run(b),sess.run(result)))
a = 1,b = 2, a + b = 3
關於全局變數初始化的說明:
我們知道sess.run(tf.global_variables_initializer())這句話可以對全局變數進行初始化,在運行程序的時候不能不加,所以在保存模型的時候,無論如何都要進行全局變數的初始化的。那現在有一個問題,載入模型的時候,還用不用再次執行這段話呢?
其實是不需要的,如果在上面的代碼中刪掉sess.run(tf.global_variables_initializer())這句話,依然能夠正常載入。也就是說保存模型的時候,已經對變數進行初始化了,所以不需要在載入模型的時候進行全局變數的初始化操作了。下面看一下,到底sess.run(tf.global_variables_initializer())此時是沒有作用還是起了作用但是被取代了:
import tensorflow as tf
#聲明兩個值為常數0.0的變數
a = tf.Variable(tf.constant(0.0,shape = [1]),name = "a")
b = tf.Variable(tf.constant(0.0,shape = [1]),name = "b")
result = a + b
saver = tf.train.Saver()
with tf.Session() as sess:
#全局變數進行初始化
sess.run(tf.global_variables_initializer())
print("a = %d,b = %d, a + b = %d" % (sess.run(a), sess.run(b), sess.run(result)))
print("-"*20)
#將指定文件中的變數載入到模型中
saver.restore(sess,"./model/add_model.ckpt")
print("a = %d,b = %d, a + b = %d" % (sess.run(a), sess.run(b), sess.run(result)))
a = 0,b = 0, a + b = 0
--------------------
a = 1,b = 2, a + b = 3
下面交換顯示的全局初始化變數與載入模型代碼交換:
import tensorflow as tf
#聲明兩個值為常數0.0的變數
a = tf.Variable(tf.constant(0.0,shape = [1]),name = "a")
b = tf.Variable(tf.constant(0.0,shape = [1]),name = "b")
result = a + b
saver = tf.train.Saver()
with tf.Session() as sess:
#將模型保存到指定的文件中
saver.restore(sess,"./model/add_model.ckpt")
print("a = %d,b = %d, a + b = %d" % (sess.run(a), sess.run(b), sess.run(result)))
print("-" * 20)
# 全局變數進行初始化
sess.run(tf.global_variables_initializer())
print("a = %d,b = %d, a + b = %d" % (sess.run(a), sess.run(b), sess.run(result)))
a = 1,b = 2, a + b = 3
--------------------
a = 0,b = 0, a + b = 0
通過上面的兩段代碼,我們知道其實在當前執行全局變數的初始化還是會對當前計算圖上的變數進行初始化的,因為此時我們並沒有載入保存的計算圖結構,所以此時我們必須在載入變數的模型中手動的創建一個和保存的模型一模一樣的計算圖結構。當然此時執行全局變數進行初始化是對當前計算圖上的變數進行初始化操作。
只不過我們執行了saver.restore(sess,"./model/add_model.ckpt")代碼,也就是將保存模型的變數載入了進來,如果在全局初始化變數的代碼後面,那麼此時載入進來的已經初始化之後的變數會覆蓋此前被初始化的值,就本例來說也就是a = 0,會被a = 1所覆蓋。
首先說明一點,對於a = tf.Variable(tf.constant(1.0,shape = [1]),name = "a")代碼:
1.a叫做變數名;2.name屬性指定的參數叫做變數名稱;
我們在保存模型的時候知道,在保存模型的時候,我們可以給tf.train.Saver()中傳遞參數實現一些高級的實現,比如:
- 參數指定一個列表,指定部分變數進行保存,列表中的元素是變數名;
- 參數指定一個變數名與變數名稱對應的字典來指定保存時候的對應關係,因為此時保存的時候和變數名沒有關係了,而是以變數名稱作為唯一的標識;
保存的時候可以這樣指定,其實在載入模型的時候,同樣可以這樣操作:
import tensorflow as tf
#定義和保存模型時候相同的計算圖結構
a = tf.Variable(tf.constant(0.0,shape = [1]),name = "a")
b = tf.Variable(tf.constant(0.0,shape = [1]),name = "b")
result = a + b
#只載入a這個變數
saver = tf.train.Saver([a])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, "./model/add_model.ckpt")
print("a = %d,b = %d, a + b = %d"%(sess.run(a),sess.run(b),sess.run(result)))
a = 1,b = 0, a + b = 1
說明:
1.此時如果不加sess.run(tf.global_variables_initializer()),會出現下面的異常,也就是沒有對b變數進行初始化:

因為此時我們只載入了a,saver.restore(sess, "./model/add_model.ckpt")初始化的也只有a變數,但是因為此時的計算圖結構還有定義的變數b,所以會拋出沒有對變數b進行初始化的異常。
其實載入模型就相當於從保存的文件中取出變數名稱以及變數值的(key,value)列表,此時的key也就是變數名稱,value表示的就是value。下面展示一下載入部分變數的大致流程:

載入部分變數的大致流程如下:
- 通過tf.train.Svaer參數list中的變數名找到當前計算圖上定義的變數名;
- 通過變數名找到對應定義的變數名稱;
- 通過變數名稱找到保存在add_model.ckpt.data-00000-of-00001和add_model.ckpt.index兩個文件中,簡單來說就是(key,value)的列表中的key,也就是文件中保存的變數名稱a;
- 通過key也就是變數名稱a找到對應的value值,也就是變數值,然後將此時的變數值覆蓋掉原來變數值,也就是用1.0替換掉了0.0;
通過上面的分析,保存的文件中存的是(a,1.0)和(b,2.0),那麼現在我改變當前計算圖的變數名稱代碼如下:
import tensorflow as tf
#定義和保存模型時候相同的計算圖結構,此時該了變數名a的變數名稱
a = tf.Variable(tf.constant(0.0,shape = [1]),name = "add_1")
b = tf.Variable(tf.constant(0.0,shape = [1]),name = "b")
result = a + b
#只載入a這個變數
saver = tf.train.Saver([a])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, "./model/add_model.ckpt")
print("a = %d,b = %d, a + b = %d"%(sess.run(a),sess.run(b),sess.run(result)))


接下來該在tf.train.Saver()中傳遞字典參數了,其實實質上都一樣,只要記住文件中保存的是(key,value),key是變數名稱,而value是變數值,key也就是變數名稱是唯一的標識:
import tensorflow as tf
#定義和保存模型時候相同的計算圖結構,此時該了變數名a的變數名稱
a = tf.Variable(tf.constant(0.0,shape = [1]),name = "add_1")
b = tf.Variable(tf.constant(0.0,shape = [1]),name = "add_2")
result = a + b
#只載入a這個變數
saver = tf.train.Saver({"a":b,"b":a})
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, "./model/add_model.ckpt")
print("a = %d,b = %d, a + b = %d"%(sess.run(a),sess.run(b),sess.run(result)))
a = 2,b = 1, a + b = 3
注意:
- 字典中的key可不是當前計算圖上定義變數的變數名稱,字典中的key是保存時候的key值,也就是保存時候的變數名稱;

指定參數字典載入變數:
- 通過字典中的key找到文件中保存的變數名稱,通過字典中的value找到當前計算圖中變數名;
- 將保存文件中的key對應value值覆蓋通過字典中的value找到的當前計算圖中變數名對應的變數值。
- 僅載入模型中保存的變數
前面說了很多關於載入變數,下面說一說如何載入模型。如果不希望在載入模型的時候重複定義計算圖,可以直接載入已經持久化的圖。對於載入模型的操作TensorFlow也提供了很方便的函數調用,我們還記得保存模型時候將計算圖保存到.meta後綴的文件中。那此時只需要載入這個文件即可:
import tensorflow as tf
#直接載入持久化的圖
saver = tf.train.import_meta_graph("./model/add_model.ckpt.meta")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
print(sess.run(tf.get_default_graph().get_tensor_by_name("a:0")))
print(sess.run(tf.get_default_graph().get_tensor_by_name("b:0")))
[ 3.]
[ 1.]
[ 2.]
注意:
1.會發現此時居然也能列印出數值,是不是因為add_model.ckpt.data-00000-of-00001和add_model.ckpt.index兩個文件在起作用,其實不是,我們可以把add_model.ckpt.data-00000-of-00001和add_model.ckpt.index兩個文件刪除,會發現還是能夠繼續執行程序得到結果;
2.如果我們此時把sess.run(tf.global_variables_initializer())全局變數的初始化代碼刪除,會發現

3.我們可以簡單的看成是把在保存模型的時候的計算圖結構複製到當前的結構下,也就是說:
saver = tf.train.import_meta_graph("./model/add_model.ckpt.meta")
等價於==》
a = tf.Variable(tf.constant(1.0,shape = [1]),name = "a")
b = tf.Variable(tf.constant(2.0,shape = [1]),name = "b")
result = a + b
4.此時因為沒有顯示的變數,所以此時只能通過運算節點的名稱來獲取依附在計算圖上的值。
有人會說在[L1]TensorFlow模型持久化~模型保存中不是說add_model.ckpt.meta文件保存了TensorFlow計算圖的結構嗎?為什麼也能獲取數據,其實這個文件中記錄的不僅僅是計算圖這一個結構還有節點的信息以及運行計算圖中節點所需要的元數據。簡單來說,我們可以使用運算方法的名稱在TensorFlow計算圖元圖中找到該運算節點的具體信息,當然包括此時運算節點的值。
當然此時獲取的值和通過變數的那種方式還是有很大的區別的,載入計算圖獲得的變數僅僅是節點上的值,並不能實現一些更高級的功能,而且運算節點的名稱也是很複雜的。當然你也可以將載入計算圖結構和載入變數結合起來。
參考:
1.《TensorFlow實現Google深度學習框架》
推薦閱讀:
TAG:TensorFlow | 模型 | 深度學習(DeepLearning) |
