第七章 PyTorch 小技巧匯總
第七章簡介
本章介紹開發過程中常用的程式碼片段、工具模組和技巧等,初步設計有模型保存與載入、模型Finetune技巧、GPU使用技巧、訓練代碼框架等。
本章小結會續更新,將工作中遇到的小技巧分享出來。
7.1 模型保存與載入
保存與載入的概念(序列化與反序列化)
模型訓練完畢之後,肯定想要把它保存下來,供以後使用,不需要再次去訓練。
那麼在pytorch中如何把訓練好的模型,保存,保存之後又如何載入呢?
這就用需要序列化與反序列化,序列化與反序列化的概念如下圖所示:
因為在記憶體中的資料,運行結束會進行釋放,所以我們需要將資料保存到硬碟中,以二進位序列的形式進行長久存儲,便於日後使用。
序列化即把物件轉換為位元組序列的過程,反序列化則把位元組序列恢復為物件。
在pytorch中,物件就是模型,所以我們常常聽到序列化和反序列化,就是將訓練好的模型從記憶體中保存到硬碟裡,當要使用的時候,再從硬碟中載入。
torch.save / torch.load
pytorch提供的序列化與反序列化函數分別是
1.
torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)
Copy
功能:保存物件到硬碟中
主要參數:
- obj- 對象
- f - 檔路徑
2.
torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
Copy
功能:載入硬碟中物件
主要參數:
- f - 檔路徑
- map_location - 指定存儲位置,如map_location='cpu', map_location={'cuda:1':'cuda:0'}
這裡的map_location大有文章,經常需要手動設置,否者會報錯。具體可參考以下形式:
GPU->CPU:torch.load(model_path, map_location='cpu')
CPU->GPU:torch.load(model_path, map_location=lambda storage, loc: storage)
兩種保存方式
pytorch保存模型有兩種方式
- 保存整個模型
- 保存模型參數
我們通過示意圖來區分兩者之間的差異
從上圖左邊知道法1保存整個nn.Module, 而法2只保存模型的參數資訊。
我們知道一個module當中包含了很多資訊,不僅僅是模型的參數 parameters,還包含了buffers, hooks和modules等一系列資訊。
對於模型應用,最重要的是模型的parameters,其餘的資訊是可以通過model 類再去構建的,所以模型保存就有兩種方式
- 所有內容都保存;
- 僅保存模型的parameters。
通常,我們只需要保存模型的參數,在使用的時候再通過load_state_dict方法載入參數。
由於第一種方法不常用,並且在載入過程中還需要指定的類方法,因此不做演示也不推薦。
對於第二種方法的代碼十分簡單,請看示例:
net_state_dict = net.state_dict()
torch.save(net_state_dict, "my_model.pth")
Copy
常用的程式碼片段
在模型開發過程中,往往不是一次就能訓練好模型,經常需要反復訓練,因此需要保存訓練的“狀態資訊”,以便於基於某個狀態繼續訓練,這就是常說的resume,可以理解為中斷點續訓練。
在整個訓練階段,除了模型參數需要保存,還有優化器的參數、學習率調整器的參數和反覆運算次數等資訊也需要保存,因此推薦在訓練時,採用以下程式碼片段進行模型保存。以下代碼來自torchvision的訓練腳本。
checkpoint = {
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"epoch": epoch,
}
path_save = "model_{}.pth".format(epoch)
torch.save(checkpoint, path_save
# =================== resume ===============
# resume
checkpoint = torch.load(path_save, map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
start_epoch = checkpoint["epoch"] + 1
Copy
小結
模型保存與載入比較簡單,需要注意的有兩點:
- torch.load的時候注意map_location的設置;
- 理解checkpoint resume的概念,以及訓練過程是需要模型、優化器、學習率調整器和已反覆運算次數的共同配合。