7.4 模型訓練代碼範本

一個良好的訓練代碼,可以有助於分析和超參調優,本節將以torchvision提供的分類模型訓練代碼為基礎,編寫適合自己的訓練代碼框架。

torchvision還提供了分割、檢測、相似性學習和視頻分類的 訓練腳本,可以參考https://pytorch.org/vision/stable/training_references.html

在分類的train.py中,共計501行代碼,下面我們提煉出核心內容,在cifar10資料集上完成resnet-8的訓練。

提煉後,代碼核心內容包括:

  1. 參數設置部分採用argparse模組進行配置,便於伺服器上訓練,以及超參數記錄;
  2. 日誌模組,包括logging模組記錄文本資訊.log檔,以及tensorboard部分的視覺化內容;
  3. 訓練模組封裝為通用類——ModelTrainer
  4. 模型保存部分

一、參數設置

在伺服器上進行訓練時,通常採用命令列啟動,或時採用sh腳本批量訓練,這時候就需要從命令列傳入一些參數,用來調整模型超參。

例如學習率想從0.1改為0.01,按以往代碼,需要進入.py檔,修改代碼,保存代碼,運行代碼。

這樣操作明顯欠妥,因此通常會採用argparse模組,將經常需要調整的參數,可以從命令列中接收。

在代碼中,採用了函數get_args_parser()實現,有了args,還可以將它記錄到日誌中,便於複現以及查看模型的超參數設置,便於跟蹤。

二、日誌模組

模型訓練的日誌很重要,它用於指導下一次實驗的超參數如何調整。

代碼中採用借助logging模組構建一個logger,並且以時間戳記(年月日-時分秒)的形式創建資料夾,便於日誌管理。

logger中使用logger.info函數代替print函數,可以實現在終端展示資訊,還可以將其保存到日誌資料夾下的log.log文件,便於溯源。

三、訓練模組

訓練過程比較固定,因此會將其封裝成 train_one_epochevaluate的兩個函數,從這兩個函數中需要返回我們關心的指標,如lossaccuracy,混淆矩陣等。

四、指標統計模組

之前的代碼中,lossaccuracy需要手動記錄每個值,然後取平均,除了它們兩個,深度學習訓練中還有許多指標都需要類似的操作。

因此,可以抽象出一個AverageMeter類,用於記錄需要求取平均值的那些指標。

AverageMeter類的使用,使得代碼更簡潔,下面一同分析一下。

運行代碼當訓練完成後,可在輸出目錄下得到以時間戳記為資料夾的日誌目錄,裡面包括lossaccuracy、混淆矩陣視覺化圖,最優模型checkpoint

<<AI人工智慧 PyTorch自學>> 7.4 模型訓練代

 

小結

訓練模型的代碼結構可以千變萬化,每個人結合自己的風格進行編寫,本節代碼也是吸取了多個代碼的精華,當然還有不足之處,後續會慢慢補上,這裡提供一個整體思路,知道代碼中需要什麼。

建議參考以下訓練代碼結構:

PyTorch ImageNet example

(https://github.com/pytorch/examples/tree/master/imagenet)

NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples

(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

TIMM:

https://github.com/rwightman/pytorch-image-models/blob/master/train.py

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 HCHUNGW 的頭像
    HCHUNGW

    HCHUNGW的部落格

    HCHUNGW 發表在 痞客邦 留言(0) 人氣()