close

7.5 torchmetrics 模型評估指標庫

模型訓練時是通過loss進行好壞的評估,因為我們採用的是loss進行方向傳播。對於人類評判好壞,往往不是通過loss值,而是採用某種評判指標。

在圖像分類任務中常用的有Accuracy(準確率)、Recall(召回率)和Precision(精確度),圖像分割中常用mIoUDice係數,目標檢測中常用mAP,由此可見不同任務的評價指標大多不一樣。

常用的指標多達幾十種,本節將介紹torchmetrics工具,它目前提供超過80種評價指標的函數,並且使用起來非常方便,值得學習。

TorchMetrics簡介與安裝

TorchMetrics Github

TorchMetrics is a collection of 80+ PyTorch metrics implementations and an easy-to-use API to create custom metrics. It offers:

  • A standardized interface to increase reproducibility
  • Reduces Boilerplate
  • Distributed-training compatible
  • Rigorously tested
  • Automatic accumulation over batches
  • Automatic synchronization between multiple devices

安裝:

pip install torchmetrics

 

conda install -c conda-forge torchmetrics

Copy

TorchMetrics 快速上手

torchmetrics 的使用與本章第四節課中介紹的AverageMeter類似,它能夠記錄每一次的資訊,並通過.compute()函數進行匯總計算。

下面通過一個accuracy的例子,剖析torchmetrics的體系結構。

from my_utils import setup_seed

setup_seed(40)

import torch

import torchmetrics

 

metric = torchmetrics.Accuracy()

n_batches = 3

for i in range(n_batches):

    preds = torch.randn(10, 5).softmax(dim=-1)

    target = torch.randint(5, (10,))

    acc = metric(preds, target)  # 單次計算,並記錄本次資訊。通過維護tp, tn, fp, fn來記錄所有資料

    print(f"Accuracy on batch {i}: {acc}")

 

acc_avg = metric.compute()

print(f"Accuracy on all data: {acc_avg}")

tp, tn, fp, fn = metric.tp, metric.tn, metric.fp, metric.fn

print(tp, tn, fp, fn, sum([tp, tn, fp, fn]))

metric.reset()

Copy

Accuracy on batch 0: 0.30000001192092896

Accuracy on batch 1: 0.10000000149011612

Accuracy on batch 2: 0.20000000298023224

Accuracy on all data: 0.20000000298023224

tensor(6) tensor(96) tensor(24) tensor(24) tensor(150)

Copy

torchmetrics的使用可以分以下三步:

​ 1.創建指標評價器

​ 2.反覆運算中進行"update"forwardupdateforward均可記錄每次資料資訊

​ 3.計算所有資料指標

TorchMetrics代碼結構

這裡提到forward,正是第四章中nn.Moduleforward TorchMetrics所有指標均繼承了nn.Module,因此可以看到這樣的用法。

acc = metric(preds, target)

Copy

下面進入 torchmetrics\classification\accuracy.py 中觀察 Accuracy到底是什麼。

可以看到Accuracy類只有3個函數,分別是__init__, update, compute,其作用就如上文所述。

再看繼承關係,Accuracy --> StatScores --> Metric --> nn.Module + ABC

Metric正如文檔所說“The base Metric class is an abstract base class that are used as the building block for all other Module metrics.”,是torchmetrics所有類的基類,它實現forward函數,因此才有像這樣的調用: acc = metric(preds, target)

Accuracy 更新邏輯

torchmetrics的使用與上一節課中的AverageMeter+Accuracy函數類似,不過在資料更新維護方面略有不同,並且torchmetrics還有點難理解。

AverageMeter+Accuracy時,是通過self.val, self.sum, self.count, self.avg進行維護。

torchmetrics.Accuracy中,並沒有這些屬性,而是通過tp, tn, fp, fn進行維護。

但是有個問題來了,請仔細觀察代碼,iteration迴圈是3次,每一次batch的數量是10,按道理tp+tn+fp+fn= 30,總共30個樣本,為什麼會是150

因為,這是多類別分類的統計,不是二分類。因此需要為每一個類,單獨計算tp, tn, fp, fn。又因為有5個類別,因此是30*5=150

關於多類別的tp, tn, fp, fn,可參考stackoverflow

<<AI人工智慧 PyTorch自學>> 7.5 torch

還有個好例子,請看混淆矩陣:

真實\預測            0      1      2

 

0                   2      0      0

 

1                   1      0      1

 

2                   0      2      0

Copy

對於類別0 FP=1 TP=2 FN=0 TN=3

對於類別1 FP=2 TP=0 FN=2 TN=2

對於類別2 FP=1 TP=0 FN=2 TN=3

自訂metrics

瞭解了Accuracy使用邏輯,就可以觸類旁通,使用其它80多個Metrics

但總有不滿足業務需求的時候,這時候就需要自訂metrics

自訂metrics非常簡單,它就像自訂Module一樣,提供必備的函數即可。

自訂metrics只需要繼承Metric,然後實現以下三個函數即可:

  • init(): Each state variable should be called using self.add_state(...).
  • update(): Any code needed to update the state given any inputs to the metric.
  • compute(): Computes a final value from the state of the metric.

舉例:

class MyAccuracy(Metric):

    full_state_update: bool = False

 

    def __init__(self):

        super().__init__()

        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")

        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

 

    def update(self, preds: torch.Tensor, target: torch.Tensor):

        batch_size = target.size(0)

        _, pred = preds.topk(1, 1, True, True)

        pred = pred.t()

        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        self.correct += torch.sum(correct)

        self.total += batch_size

 

    def compute(self):

        return self.correct.float() / self.total

Copy

這裡需要注意的是:

  • init函數中需要通過add_state進行屬性初始化;
  • update中需要處理接收的資料,並可自訂管理機制,如這裡採用correcttotal來管理總的資料
  • compute中需清晰知道返回的是總數據的Accuracy

小結

torchmetrics是一個簡單易用的指標評估庫,裡面提供了80多種指標,建議採用torchmetrics進行指標評估,避免重複造輪子。

下面請看支持的指標:

Auido 任務指標

  • Perceptual Evaluation of Speech Quality (PESQ)
  • Permutation Invariant Training (PIT)
  • Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)
  • Scale-Invariant Signal-to-Noise Ratio (SI-SNR)
  • Short-Time Objective Intelligibility (STOI)
  • Signal to Distortion Ratio (SDR)
  • Signal-to-Noise Ratio (SNR)

分類 任務指標

  • Accuracy
  • AUC
  • AUROC
  • Average Precision
  • Binned Average Precision
  • Binned Precision Recall Curve
  • Binned Recall At Fixed Precision
  • Calibration Error
  • Cohen Kappa
  • Confusion Matrix
  • Coverage Error
  • Dice Score
  • F1 Score
  • FBeta Score
  • Hamming Distance
  • Hinge Loss
  • Jaccard Index
  • KL Divergence
  • Label Ranking Average Precision
  • Label Ranking Loss
  • Matthews Corr. Coef.
  • Precision
  • Precision Recall
  • Precision Recall Curve
  • Recall
  • ROC
  • Specificity
  • Stat Scores

圖像 任務指標

  • Error Relative Global Dim. Synthesis (ERGAS)
  • Frechet Inception Distance (FID)
  • Image Gradients
  • Inception Score
  • Kernel Inception Distance
  • Learned Perceptual Image Patch Similarity (LPIPS)
  • Multi-Scale SSIM
  • Peak Signal-to-Noise Ratio (PSNR)
  • Spectral Angle Mapper
  • Spectral Distortion Index
  • Structural Similarity Index Measure (SSIM)
  • Universal Image Quality Index

檢測 任務指標

  • Mean-Average-Precision (mAP)

Pairwise 任務指標

  • Cosine Similarity
  • Euclidean Distance
  • Linear Similarity
  • Manhattan Distance

Regression 任務指標

  • Cosine Similarity
  • Explained Variance
  • Mean Absolute Error (MAE)
  • Mean Absolute Percentage Error (MAPE)
  • Mean Squared Error (MSE)
  • Mean Squared Log Error (MSLE)
  • Pearson Corr. Coef.
  • R2 Score
  • Spearman Corr. Coef.
  • Symmetric Mean Absolute Percentage Error (SMAPE)
  • Tweedie Deviance Score
  • Weighted MAPE

Retrieval 任務指標

  • Retrieval Fall-Out
  • Retrieval Hit Rate
  • Retrieval Mean Average Precision (MAP)
  • Retrieval Mean Reciprocal Rank (MRR)
  • Retrieval Normalized DCG
  • Retrieval Precision
  • Retrieval R-Precision
  • Retrieval Recall

Text 任務指標

  • BERT Score
  • BLEU Score
  • Char Error Rate
  • ChrF Score
  • Extended Edit Distance
  • Match Error Rate
  • ROUGE Score
  • Sacre BLEU Score
  • SQuAD
  • Translation Edit Rate (TER)
  • Word Error Rate
  • Word Info. LostWord Info. Preserved

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 HCHUNGW 的頭像
    HCHUNGW

    HCHUNGW的部落格

    HCHUNGW 發表在 痞客邦 留言(0) 人氣()