4.4 Module常用函數
本小節匯總介紹Module常用的方法,由於文檔中是按首字母排序展示所有方法,未按用途進行歸類,不便於理解各函數之間的關係。在這裡,特別將具有相似功能的相關函數歸納整理,供大家參考學習。
常用方法包括:
- 設置模型訓練、評估模式
- eval
- train
- 設置模型存放在cpu/gpu/xpu
- cpu
- cuda
- to
- xpu
- 獲取模型參數、載入權重參數
- load_state_dict
- state_dict
- 管理模型的modules, parameters, sub_module
- parameters
- children
- modules
- named_children
- named_modules
- named_parameters
- get_parameter
- get_submodule
- add_module
- 設置模型的參數精度,可選半精度、單精確度、雙精度等
- bfloat16
- half
- float
- double
- 對子模組執行特定功能
- apply
- zero_grad
以上是不完全的列舉,有些非高頻使用的函數請到文檔中查閱。下面通過簡介和配套代碼的形式學習上述函數的使用。
設置模型訓練、評估模式
eval:設置模型為評估模式,這一點與上一小節介紹的BN,Dropout息息相關,評估模式下模型的某些層執行的操作與訓練狀態下是不同的。
train:設置模型為訓練模式,如BN層需要統計running_var這些統計資料,Dropout層需要執行隨機失活等。
使用方法過於簡單,無需代碼展示。
設置模型存放在cpu/gpu
對於gpu的使用會在後面設置單獨小節詳細介紹,由於這裡是基礎學習,暫時可不考慮運算速度問題。這裡既然遇到了相關的概念,就簡單說一下。
pytorch可以利用gpu進行加速運算,早期只支援NVIDIA公司的GPU,現在也逐步開始支持AMD的GPU。使用gpu進行運算的方法很簡單,就是把需要運算的資料放到gpu即可。方法就是 xxx.cuda(),若想回到cpu運算,那就需要xxx.cpu()即可。但有一個更好的方法是to(),to方法可將物件放到指定的設備中去,如to.("cpu") 、 to.("cuda)、to("cuda:0")等。
cpu:將Module放到cpu上。
cuda:將Module放到cuda上。為什麼是cuda不是gpu呢?因為CUDA(Compute Unified Device Architecture)是NVIDIA推出的運算平臺,資料是放到那上面進行運算,而gpu可以有很多個品牌,因此用cuda更合理一些。
to:將Module放到指定的設備上。
關於to通常會配備torch.cuda.is_available()使用,請看配套代碼學習。
獲取模型參數、載入權重參數
模型訓練完畢後,我們需要保存的核心內容是模型參數,這樣可以供下次使用,或者是給別人進行finetune。相信大家都用ImageNet上的預訓練模型,而使用方法就是官方訓練完畢後保存模型的參數,供我們下載,然後載入到自己的模型中。在這裡就涉及兩個重要操作:保存模型參數與載入模型參數,分別要用到以下兩個函數。
state_dict:返回參數字典。key是告訴你這個權重參數是放到哪個網路層。
load_state_dict:將參數字典中的參數複製到當前模型中。這裡的複製要求key要一一對應,若key對不上,自然模型不知道要把這個參數放到哪裡去。絕大多數開發者都會在load_state_dict這裡遇到過報錯,如
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: xxxxxxxx
Unexpected key(s) in state_dict: xxxxxxxxxx
Copy
這通常是拿到的參數字典與模型當前的結構不匹配。
對於load_state_dict函數,還有兩個參數可以設置,請看原型:
參數:
- state_dict (dict) – a dict containing parameters and persistent buffers.
- strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
返回項
- missing_keys is a list of str containing the missing keys
- unexpected_keys is a list of str containing the unexpected keys
上述兩個方法具體的使用請看配套代碼。
管理模型的modules, parameters, sub_module
模型中需要管理的主要是parameter與module,每個物件都有兩種方式讀取,分別是帶名字和不帶名字的。針對module還有一個稱為children的方法,它與modules方法最大的不同在於modules會返回module本身。具體差異通過配套代碼一看便明瞭。
parameters:返回一個反覆運算器,反覆運算器可拋出Module的所有parameter物件
named_parameters:作用同上,不僅可得到parameter物件,還會給出它的名稱
modules:返回一個反覆運算器,反覆運算器可以拋出Module的所有Module物件,注意:模型本身也是module,所以也會獲得自己。
named_modules:作用同上,不僅可得到Module物件,還會給出它的名稱
children:作用同modules,但不會返回Module自己。
named_children:作用同named_modules,但不會返回Module自己。
獲取某個參數或submodule
當想查看某個部分資料時,可以通過get_xxx方法獲取模型特定位置的資料,可獲取parameter、submodule,使用方法也很簡單,只需要傳入對應的name即可。
get_parameter
get_submodule
設置模型的參數精度,可選半精度、單精確度、雙精度等
為了調整模型占存儲空間的大小,可以設置參數的資料類型,預設情況是float32位元(單精確度),在一些場景可採用半精度、雙精度等,以此改變模型的大小或精度。Module提供了幾個轉換權重參數精度的方法,分別如下:
- half:半精度
- float:單精確度
- double:雙精度
- bfloat16:Brain Floating Point 是Google開發的一種資料格式,詳細參見wikipedia
對子模組執行特定功能
zero_grad:將所有參數的梯度設置為0,或者None
apply:對所有子Module執行指定fn(函數),常見於參數初始化。這個可以參見配套代碼。
小結
本節對Module的常用API函數進行了介紹,包括模型兩種狀態,模型存儲於何種設備,模型獲取參數,載入參數,管理模型的modules,設置模型參數的精度,對模型子模組執行特定功能。
由於Module是核心模組,其涉及的API非常多,短時間不好消化,建議大家結合代碼用例,把這些方法都過一遍,留個印象,待日後專案開發需要的時候知道有這些函數可以使用即可。
下一小節將介紹Module中的Hook函數。
留言列表