6.3 混淆矩陣與訓練曲線視覺化
在分類任務中,通過混淆矩陣可以看出模型的偏好,而且對每一個類別的分類情況都瞭若指掌,為模型的優化提供很大説明。本節將介紹混淆矩陣概念及其視覺化。
為了演示混淆矩陣與訓練曲線,本節代碼採用cifar10資料集進行訓練,模型採用resnet系列。
數據cifar-10-python.tar.gz 可從 "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 下載,放到指定資料夾節課,無需解壓,代碼會自動解壓。
混淆矩陣概念
混淆矩陣(Confusion Matrix)常用來觀察分類結果,其是一個N*N的方陣,N表示類別數。
混淆矩陣的行表示真實類別,列表示預測類別。例如,貓狗的二分類問題,有貓的圖像10張,狗的圖像30張,模型對這40張圖片進行預測,得到的混淆矩陣為
阿貓 |
阿狗 |
|
---|---|---|
阿貓 |
7 |
3 |
阿狗 |
10 |
20 |
從第一行中可知道,10張貓的圖像中,7張預測為貓,3張預測為狗,貓的召回率(Recall)為7/10 = 70%,
從第二行中可知道,30張狗的圖像中,8張預測為貓,22張預測為狗,狗的召回率為20/30 = 66.7%,
從第一列中可知道,預測為貓的17張圖像中,有7張是真正的貓,貓的精確度(Precision)為7 / 17 = 41.17%
從第二列中可知道,預測為狗的23張圖像中,有20張是真正的狗,狗的精確度(Precision)為20 / 23 = 86.96%
模型的準確率(Accuracy)為 7+20 / 40 = 67.5%
可以發現通過混淆矩陣可以清晰的看出網路模型的分類情況,若再結合上顏色視覺化,可方便的看出模型的分類偏好。
本小節將介紹,混淆矩陣的統計及其視覺化。
混淆矩陣的統計
混淆矩陣的繪製將借助matplotlib的imshow功能,在imshow中可對矩陣進行上色,colorbar可自行調整,如本例中採用的黑白調,也可以選擇其他的colorbar。
在模型訓練中,通常以一個epoch為單位,進行混淆矩陣的統計,然後繪製,代碼思路如下:
第一步:創建混淆矩陣
獲取類別數,創建N*N的零矩陣
conf_mat = np.zeros([cls_num, cls_num])
第二步:獲取真實標籤和預測標籤
labels 為真實標籤,通常為一個batch的標籤
predicted為預測類別,與labels同長度
第三步:依據標籤為混淆矩陣計數
for j in range(len(labels)):
cate_i = labels[j].cpu().numpy()
pre_i = predicted[j].cpu().numpy()
conf_mat[cate_i, pre_i] += 1.
Copy
混淆矩陣視覺化
混淆矩陣視覺化已經封裝成一個函數show_conf_mat,函數位於 配套代碼
show_conf_mat(confusion_mat, classes, set_name, out_dir, epoch=999, verbose=False, perc=False)
參數:
"""
混淆矩陣繪製並保存圖片
:param confusion_mat: nd.array
:param classes: list or tuple, 類別名稱
:param set_name: str, 資料集名稱 train or valid or test?
:param out_dir: str, 圖片要保存的資料夾
:param epoch: int, 第幾個epoch
:param verbose: bool, 是否列印精度資訊
:param perc: bool, 是否採用百分比,圖像分割時用,因分類數目過大
:return:
"""
Copy
show_conf_mat函數內部原理就不再詳細展開,都是matplotlib的基礎知識。下圖為最終效果圖:
show_conf_mat函數提供png的保存,不便於觀察整個訓練過程的變化,這裡借助tensorboard的add_figure功能,將每個epoch的混淆矩陣保存到tensorboard中,然後可拖拽的形式觀察模型精度的變化情況。
效果如下圖:
從上述變化可以發現模型在反覆運算過程中的偏好,前後對比圖可很好的説明工程師分析模型的偏好。
當global_step比較多的時候,toolbar無法展示每一個step,這需要在啟動tensorboard的時候設置一下參數即可
tensorboard --logdir=./Result --samples_per_plugin images=200
Copy
除了手動繪製之外,sklearn庫也提供了混淆矩陣繪製(from sklearn.metrics import confusion_matrix),這裡不再拓展。
訓練曲線繪製
除了混淆矩陣,在模型訓練過程中最重要的是觀察loss曲線的變化,loss曲線變化趨勢直接決定訓練是否需要停止,並指引我們進行參數的調整。
loss曲線是需要將訓練與驗證放在一起看的,單獨看一條曲線是不夠的,這一點需要大家瞭解模型評估中的方差與偏差的概念。
通過訓練loss看偏差,通過訓練loss與驗證loss看方差。
偏差看的是模型擬合能力是否足夠,方差是看模型泛化性能是否足夠,是否存在過擬合。
將兩條曲線繪製到一個坐標系裡,可以借助tensorboard的add_scalars函數,具體請看代碼
在訓練集的反覆運算之後記錄訓練集的loss
writer.add_scalars('Loss_group', {'train_loss': loss_avg}, epoch)
Copy
在驗證集的反覆運算之後記錄訓練集的loss
writer.add_scalars('Loss_group', {'valid_loss': loss_avg}, epoch)
Copy
在這裡可以發現,SummaryWriter類的函數是以tag變數進行區分不同的坐標系,以上例子看出,雖然在兩個地方執行代碼,但是通過tag="Loss_group",仍舊可以把它們繪製在一個坐標系裡。
小結
以上就是在訓練過程中記錄必要的訓練資訊,用於監控模型訓練狀態。
下一節將介紹有趣的CAM視覺化實現,以及nn.Module模組中的系列hook函數使用。