6.3 混淆矩陣與訓練曲線視覺化

在分類任務中,通過混淆矩陣可以看出模型的偏好,而且對每一個類別的分類情況都瞭若指掌,為模型的優化提供很大説明。本節將介紹混淆矩陣概念及其視覺化。

為了演示混淆矩陣與訓練曲線,本節代碼採用cifar10資料集進行訓練,模型採用resnet系列。

數據cifar-10-python.tar.gz 可從 "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 下載,放到指定資料夾節課,無需解壓,代碼會自動解壓。

<<AI人工智慧 PyTorch自學>> 6.3 混淆矩陣與

混淆矩陣概念

混淆矩陣(Confusion Matrix)常用來觀察分類結果,其是一個N*N的方陣,N表示類別數。

混淆矩陣的行表示真實類別,列表示預測類別。例如,貓狗的二分類問題,有貓的圖像10張,狗的圖像30張,模型對這40張圖片進行預測,得到的混淆矩陣為

 

阿貓

阿狗

阿貓

7

3

阿狗

10

20

從第一行中可知道,10張貓的圖像中,7張預測為貓,3張預測為狗,貓的召回率(Recall)7/10 = 70%

從第二行中可知道,30張狗的圖像中,8張預測為貓,22張預測為狗,狗的召回率為20/30 = 66.7%

從第一列中可知道,預測為貓的17張圖像中,有7張是真正的貓,貓的精確度(Precision)7 / 17 = 41.17%

從第二列中可知道,預測為狗的23張圖像中,有20張是真正的狗,狗的精確度(Precision)20 / 23 = 86.96%

模型的準確率(Accuracy) 7+20 / 40 = 67.5%

可以發現通過混淆矩陣可以清晰的看出網路模型的分類情況,若再結合上顏色視覺化,可方便的看出模型的分類偏好。

本小節將介紹,混淆矩陣的統計及其視覺化。

混淆矩陣的統計

混淆矩陣的繪製將借助matplotlibimshow功能,在imshow中可對矩陣進行上色,colorbar可自行調整,如本例中採用的黑白調,也可以選擇其他的colorbar

在模型訓練中,通常以一個epoch為單位,進行混淆矩陣的統計,然後繪製,代碼思路如下:

第一步:創建混淆矩陣

獲取類別數,創建N*N的零矩陣

conf_mat = np.zeros([cls_num, cls_num])

第二步:獲取真實標籤和預測標籤

labels 為真實標籤,通常為一個batch的標籤

predicted為預測類別,與labels同長度

第三步:依據標籤為混淆矩陣計數

        for j in range(len(labels)):

            cate_i = labels[j].cpu().numpy()

            pre_i = predicted[j].cpu().numpy()

            conf_mat[cate_i, pre_i] += 1.

Copy

混淆矩陣視覺化

混淆矩陣視覺化已經封裝成一個函數show_conf_mat,函數位於 配套代碼

show_conf_mat(confusion_mat, classes, set_name, out_dir, epoch=999, verbose=False, perc=False)

參數:

    """

    混淆矩陣繪製並保存圖片

    :param confusion_mat:  nd.array

    :param classes: list or tuple, 類別名稱

    :param set_name: str, 資料集名稱 train or valid or test?

    :param out_dir:  str, 圖片要保存的資料夾

    :param epoch:  int, 第幾個epoch

    :param verbose: bool, 是否列印精度資訊

    :param perc: bool, 是否採用百分比,圖像分割時用,因分類數目過大

    :return:

    """

Copy

show_conf_mat函數內部原理就不再詳細展開,都是matplotlib的基礎知識。下圖為最終效果圖:

<<AI人工智慧 PyTorch自學>> 6.3 混淆矩陣與

show_conf_mat函數提供png的保存,不便於觀察整個訓練過程的變化,這裡借助tensorboardadd_figure功能,將每個epoch的混淆矩陣保存到tensorboard中,然後可拖拽的形式觀察模型精度的變化情況。

效果如下圖:

<<AI人工智慧 PyTorch自學>> 6.3 混淆矩陣與

從上述變化可以發現模型在反覆運算過程中的偏好,前後對比圖可很好的説明工程師分析模型的偏好。

global_step比較多的時候,toolbar無法展示每一個step,這需要在啟動tensorboard的時候設置一下參數即可

tensorboard --logdir=./Result --samples_per_plugin images=200

Copy

除了手動繪製之外,sklearn庫也提供了混淆矩陣繪製(from sklearn.metrics import confusion_matrix),這裡不再拓展。

訓練曲線繪製

除了混淆矩陣,在模型訓練過程中最重要的是觀察loss曲線的變化,loss曲線變化趨勢直接決定訓練是否需要停止,並指引我們進行參數的調整。

loss曲線是需要將訓練與驗證放在一起看的,單獨看一條曲線是不夠的,這一點需要大家瞭解模型評估中的方差與偏差的概念。

通過訓練loss看偏差,通過訓練loss與驗證loss看方差。

偏差看的是模型擬合能力是否足夠,方差是看模型泛化性能是否足夠,是否存在過擬合。

<<AI人工智慧 PyTorch自學>> 6.3 混淆矩陣與

將兩條曲線繪製到一個坐標系裡,可以借助tensorboardadd_scalars函數,具體請看代碼

在訓練集的反覆運算之後記錄訓練集的loss

writer.add_scalars('Loss_group', {'train_loss': loss_avg}, epoch)

Copy

在驗證集的反覆運算之後記錄訓練集的loss

writer.add_scalars('Loss_group', {'valid_loss': loss_avg}, epoch)

Copy

在這裡可以發現,SummaryWriter類的函數是以tag變數進行區分不同的坐標系,以上例子看出,雖然在兩個地方執行代碼,但是通過tag="Loss_group",仍舊可以把它們繪製在一個坐標系裡。

小結

以上就是在訓練過程中記錄必要的訓練資訊,用於監控模型訓練狀態。

下一節將介紹有趣的CAM視覺化實現,以及nn.Module模組中的系列hook函數使用。

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 HCHUNGW 的頭像
    HCHUNGW

    HCHUNGW的部落格

    HCHUNGW 發表在 痞客邦 留言(0) 人氣()