標籤:

pytorch1.0 模型打包指南(部署)

pytorch1.0 模型打包指南(部署)

來自專欄 pytorch填坑系列3 人贊了文章

pytorch1.0預覽版已於2018-09-20發布,基本api和0.4保持一致,主要的變化在部署模型這一塊

在官網教程中已經有了詳細的說明,這裡簡要提一下如何保存和載入。

Loading a PyTorch Model in C++?

pytorch.org

保存

參考官網教程,在訓練完之後,會保存一個模型的權重,接下來,需要重新將權重載入到網路中,並進行保存,示例代碼如下:

import torchfrom models.crnn import CRNNdef save(net, input, save_path): net.eval() traced_script_module = torch.jit.trace(net, input) traced_script_module.save(save_path)def load(model_path): return torch.jit.load(model_path)if __name__ == __main__: input = torch.Tensor(10, 3, 32, 320) model_path = ./model.pth net = CRNN(32, 3, 10, 256) net.load_state_dict(torch.load(model_path)) save(net, input, ./model.pt)

此處以crnn為例,可以看到,在保存之前,我們需要先將模型設置為測試模式,不設置的話,dropout和bn這些和訓練狀態相關的層會報錯(算是一個小坑吧)。

保存分為兩步

  1. 使用torch.jit.trace()pytorch模型轉換為Torch Script ,該方法會返回一ScriptModule
  2. 調用返回的ScriptModulesave()方法將模型保存下來。

載入

載入模型也非常簡單,只需要一句話

net = torch.jit.load(model_path)

之前使用torch.save也可以將計算圖和權重保存在一起。但是那樣保存的計算圖依賴於 原始文件,文件發生改變就不能載入了。使用1.0的方式保存的模型不依賴於原始文件,可以任意移動。

列印一下trace前後的net

之前

trace之前

之後

trace之後

可以看到relupool 都被移除了,至於為什麼目前還沒有研究清楚。等待大神的詳細解答。

推薦閱讀:

PyTorch入門代碼——訓練一個圖像分類模型(Level1)
記一次pytorch安裝過程
pytorch Module模塊學習

TAG:PyTorch |