pytorch1.0 模型打包指南(部署)
來自專欄 pytorch填坑系列
3 人贊了文章
pytorch1.0預覽版已於2018-09-20發布,基本api和0.4保持一致,主要的變化在部署模型這一塊
在官網教程中已經有了詳細的說明,這裡簡要提一下如何保存和載入。
Loading a PyTorch Model in C++保存
參考官網教程,在訓練完之後,會保存一個模型的權重,接下來,需要重新將權重載入到網路中,並進行保存,示例代碼如下:
import torchfrom models.crnn import CRNNdef save(net, input, save_path): net.eval() traced_script_module = torch.jit.trace(net, input) traced_script_module.save(save_path)def load(model_path): return torch.jit.load(model_path)if __name__ == __main__: input = torch.Tensor(10, 3, 32, 320) model_path = ./model.pth net = CRNN(32, 3, 10, 256) net.load_state_dict(torch.load(model_path)) save(net, input, ./model.pt)
此處以crnn為例,可以看到,在保存之前,我們需要先將模型設置為測試模式,不設置的話,dropout和bn這些和訓練狀態相關的層會報錯(算是一個小坑吧)。
保存分為兩步
- 使用
torch.jit.trace()將pytorch模型轉換為Torch Script,該方法會返回一ScriptModule。 - 調用返回的
ScriptModule的save()方法將模型保存下來。
載入
載入模型也非常簡單,只需要一句話
net = torch.jit.load(model_path)
之前使用torch.save也可以將計算圖和權重保存在一起。但是那樣保存的計算圖依賴於 原始文件,文件發生改變就不能載入了。使用1.0的方式保存的模型不依賴於原始文件,可以任意移動。
列印一下trace前後的net
之前

之後
可以看到relu和pool 都被移除了,至於為什麼目前還沒有研究清楚。等待大神的詳細解答。
推薦閱讀:
※PyTorch入門代碼——訓練一個圖像分類模型(Level1)
※記一次pytorch安裝過程
※pytorch Module模塊學習
TAG:PyTorch |

