7.4 模型訓練代碼範本
一個良好的訓練代碼,可以有助於分析和超參調優,本節將以torchvision提供的分類模型訓練代碼為基礎,編寫適合自己的訓練代碼框架。
torchvision還提供了分割、檢測、相似性學習和視頻分類的 訓練腳本,可以參考https://pytorch.org/vision/stable/training_references.html。
在分類的train.py中,共計501行代碼,下面我們提煉出核心內容,在cifar10資料集上完成resnet-8的訓練。
提煉後,代碼核心內容包括:
- 參數設置部分採用argparse模組進行配置,便於伺服器上訓練,以及超參數記錄;
- 日誌模組,包括logging模組記錄文本資訊.log檔,以及tensorboard部分的視覺化內容;
- 訓練模組封裝為通用類——ModelTrainer
- 模型保存部分
一、參數設置
在伺服器上進行訓練時,通常採用命令列啟動,或時採用sh腳本批量訓練,這時候就需要從命令列傳入一些參數,用來調整模型超參。
例如學習率想從0.1改為0.01,按以往代碼,需要進入.py檔,修改代碼,保存代碼,運行代碼。
這樣操作明顯欠妥,因此通常會採用argparse模組,將經常需要調整的參數,可以從命令列中接收。
在代碼中,採用了函數get_args_parser()實現,有了args,還可以將它記錄到日誌中,便於複現以及查看模型的超參數設置,便於跟蹤。
二、日誌模組
模型訓練的日誌很重要,它用於指導下一次實驗的超參數如何調整。
代碼中採用借助logging模組構建一個logger,並且以時間戳記(年月日-時分秒)的形式創建資料夾,便於日誌管理。
在logger中使用logger.info函數代替print函數,可以實現在終端展示資訊,還可以將其保存到日誌資料夾下的log.log文件,便於溯源。
三、訓練模組
訓練過程比較固定,因此會將其封裝成 train_one_epoch和evaluate的兩個函數,從這兩個函數中需要返回我們關心的指標,如loss,accuracy,混淆矩陣等。
四、指標統計模組
之前的代碼中,loss和accuracy需要手動記錄每個值,然後取平均,除了它們兩個,深度學習訓練中還有許多指標都需要類似的操作。
因此,可以抽象出一個AverageMeter類,用於記錄需要求取平均值的那些指標。
AverageMeter類的使用,使得代碼更簡潔,下面一同分析一下。
運行代碼當訓練完成後,可在輸出目錄下得到以時間戳記為資料夾的日誌目錄,裡面包括loss、accuracy、混淆矩陣視覺化圖,最優模型checkpoint。
小結
訓練模型的代碼結構可以千變萬化,每個人結合自己的風格進行編寫,本節代碼也是吸取了多個代碼的精華,當然還有不足之處,後續會慢慢補上,這裡提供一個整體思路,知道代碼中需要什麼。
建議參考以下訓練代碼結構:
PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)
NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
TIMM:
https://github.com/rwightman/pytorch-image-models/blob/master/train.py