close


第七章 PyTorch 小技巧匯總

第七章簡介

本章介紹開發過程中常用的程式碼片段、工具模組和技巧等,初步設計有模型保存與載入、模型Finetune技巧、GPU使用技巧、訓練代碼框架等。

本章小結會續更新,將工作中遇到的小技巧分享出來。

 

7.1 模型保存與載入

保存與載入的概念(序列化與反序列化)

模型訓練完畢之後,肯定想要把它保存下來,供以後使用,不需要再次去訓練。

那麼在pytorch中如何把訓練好的模型,保存,保存之後又如何載入呢?

這就用需要序列化與反序列化,序列化與反序列化的概念如下圖所示:

<<AI人工智慧 PyTorch自學>> 第七章 PyTor

 

因為在記憶體中的資料,運行結束會進行釋放,所以我們需要將資料保存到硬碟中,以二進位序列的形式進行長久存儲,便於日後使用。

序列化即把物件轉換為位元組序列的過程,反序列化則把位元組序列恢復為物件。

pytorch中,物件就是模型,所以我們常常聽到序列化和反序列化,就是將訓練好的模型從記憶體中保存到硬碟裡,當要使用的時候,再從硬碟中載入。

torch.save / torch.load

pytorch提供的序列化與反序列化函數分別是

1.

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

Copy

功能:保存物件到硬碟中

主要參數:

  • obj- 對象
  • f - 檔路徑

2.

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

Copy

功能:載入硬碟中物件

主要參數:

  • f - 檔路徑
  • map_location - 指定存儲位置,如map_location='cpu' map_location={'cuda:1':'cuda:0'}

這裡的map_location大有文章,經常需要手動設置,否者會報錯。具體可參考以下形式:

GPU->CPUtorch.load(model_path, map_location='cpu')

CPU->GPUtorch.load(model_path, map_location=lambda storage, loc: storage)

兩種保存方式

pytorch保存模型有兩種方式

  1. 保存整個模型
  2. 保存模型參數

我們通過示意圖來區分兩者之間的差異

<<AI人工智慧 PyTorch自學>> 第七章 PyTor

 

從上圖左邊知道法1保存整個nn.Module 而法2只保存模型的參數資訊。

我們知道一個module當中包含了很多資訊,不僅僅是模型的參數 parameters,還包含了buffers, hooksmodules等一系列資訊。

對於模型應用,最重要的是模型的parameters,其餘的資訊是可以通過model 類再去構建的,所以模型保存就有兩種方式

  1. 所有內容都保存;
  2. 僅保存模型的parameters

通常,我們只需要保存模型的參數,在使用的時候再通過load_state_dict方法載入參數。

由於第一種方法不常用,並且在載入過程中還需要指定的類方法,因此不做演示也不推薦。

對於第二種方法的代碼十分簡單,請看示例:

net_state_dict = net.state_dict()

torch.save(net_state_dict, "my_model.pth")

Copy

常用的程式碼片段

在模型開發過程中,往往不是一次就能訓練好模型,經常需要反復訓練,因此需要保存訓練的狀態資訊,以便於基於某個狀態繼續訓練,這就是常說的resume,可以理解為中斷點續訓練。

在整個訓練階段,除了模型參數需要保存,還有優化器的參數、學習率調整器的參數和反覆運算次數等資訊也需要保存,因此推薦在訓練時,採用以下程式碼片段進行模型保存。以下代碼來自torchvision訓練腳本

checkpoint = {

    "model": model_without_ddp.state_dict(),

    "optimizer": optimizer.state_dict(),

    "lr_scheduler": lr_scheduler.state_dict(),

    "epoch": epoch,

}

path_save = "model_{}.pth".format(epoch)

torch.save(checkpoint, path_save

 

# =================== resume ===============

# resume

checkpoint = torch.load(path_save, map_location="cpu")

model.load_state_dict(checkpoint["model"])

optimizer.load_state_dict(checkpoint["optimizer"])

lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

start_epoch = checkpoint["epoch"] + 1

Copy

小結

模型保存與載入比較簡單,需要注意的有兩點:

  1. torch.load的時候注意map_location的設置;
  2. 理解checkpoint resume的概念,以及訓練過程是需要模型、優化器、學習率調整器和已反覆運算次數的共同配合。

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 HCHUNGW 的頭像
    HCHUNGW

    HCHUNGW的部落格

    HCHUNGW 發表在 痞客邦 留言(0) 人氣()