7.5 torchmetrics 模型評估指標庫
模型訓練時是通過loss進行好壞的評估,因為我們採用的是loss進行方向傳播。對於人類評判好壞,往往不是通過loss值,而是採用某種評判指標。
在圖像分類任務中常用的有Accuracy(準確率)、Recall(召回率)和Precision(精確度),圖像分割中常用mIoU和Dice係數,目標檢測中常用mAP,由此可見不同任務的評價指標大多不一樣。
常用的指標多達幾十種,本節將介紹torchmetrics工具,它目前提供超過80種評價指標的函數,並且使用起來非常方便,值得學習。
TorchMetrics簡介與安裝
TorchMetrics is a collection of 80+ PyTorch metrics implementations and an easy-to-use API to create custom metrics. It offers:
- A standardized interface to increase reproducibility
- Reduces Boilerplate
- Distributed-training compatible
- Rigorously tested
- Automatic accumulation over batches
- Automatic synchronization between multiple devices
安裝:
pip install torchmetrics
conda install -c conda-forge torchmetrics
Copy
TorchMetrics 快速上手
torchmetrics 的使用與本章第四節課中介紹的AverageMeter類似,它能夠記錄每一次的資訊,並通過.compute()函數進行匯總計算。
下面通過一個accuracy的例子,剖析torchmetrics的體系結構。
from my_utils import setup_seed
setup_seed(40)
import torch
import torchmetrics
metric = torchmetrics.Accuracy()
n_batches = 3
for i in range(n_batches):
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))
acc = metric(preds, target) # 單次計算,並記錄本次資訊。通過維護tp, tn, fp, fn來記錄所有資料
print(f"Accuracy on batch {i}: {acc}")
acc_avg = metric.compute()
print(f"Accuracy on all data: {acc_avg}")
tp, tn, fp, fn = metric.tp, metric.tn, metric.fp, metric.fn
print(tp, tn, fp, fn, sum([tp, tn, fp, fn]))
metric.reset()
Copy
Accuracy on batch 0: 0.30000001192092896
Accuracy on batch 1: 0.10000000149011612
Accuracy on batch 2: 0.20000000298023224
Accuracy on all data: 0.20000000298023224
tensor(6) tensor(96) tensor(24) tensor(24) tensor(150)
Copy
torchmetrics的使用可以分以下三步:
1.創建指標評價器
2.反覆運算中進行"update"或forward,update和forward均可記錄每次資料資訊
3.計算所有資料指標
TorchMetrics代碼結構
這裡提到forward,正是第四章中nn.Module的forward。 TorchMetrics所有指標均繼承了nn.Module,因此可以看到這樣的用法。
acc = metric(preds, target)
Copy
下面進入 torchmetrics\classification\accuracy.py 中觀察 Accuracy到底是什麼。
可以看到Accuracy類只有3個函數,分別是__init__, update, compute,其作用就如上文所述。
再看繼承關係,Accuracy --> StatScores --> Metric --> nn.Module + ABC。
Metric類正如文檔所說“The base Metric class is an abstract base class that are used as the building block for all other Module metrics.”,是torchmetrics所有類的基類,它實現forward函數,因此才有像這樣的調用: acc = metric(preds, target)
Accuracy 更新邏輯
torchmetrics的使用與上一節課中的AverageMeter+Accuracy函數類似,不過在資料更新維護方面略有不同,並且torchmetrics還有點難理解。
AverageMeter+Accuracy時,是通過self.val, self.sum, self.count, self.avg進行維護。
在torchmetrics.Accuracy中,並沒有這些屬性,而是通過tp, tn, fp, fn進行維護。
但是有個問題來了,請仔細觀察代碼,iteration迴圈是3次,每一次batch的數量是10,按道理tp+tn+fp+fn= 30,總共30個樣本,為什麼會是150?
因為,這是多類別分類的統計,不是二分類。因此需要為每一個類,單獨計算tp, tn, fp, fn。又因為有5個類別,因此是30*5=150。
關於多類別的tp, tn, fp, fn,可參考stackoverflow
還有個好例子,請看混淆矩陣:
真實\預測 0 1 2
0 2 0 0
1 1 0 1
2 0 2 0
Copy
對於類別0的 FP=1 TP=2 FN=0 TN=3
對於類別1的 FP=2 TP=0 FN=2 TN=2
對於類別2的 FP=1 TP=0 FN=2 TN=3
自訂metrics
瞭解了Accuracy使用邏輯,就可以觸類旁通,使用其它80多個Metrics。
但總有不滿足業務需求的時候,這時候就需要自訂metrics。
自訂metrics非常簡單,它就像自訂Module一樣,提供必備的函數即可。
自訂metrics只需要繼承Metric,然後實現以下三個函數即可:
- init(): Each state variable should be called using self.add_state(...).
- update(): Any code needed to update the state given any inputs to the metric.
- compute(): Computes a final value from the state of the metric.
舉例:
class MyAccuracy(Metric):
full_state_update: bool = False
def __init__(self):
super().__init__()
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
batch_size = target.size(0)
_, pred = preds.topk(1, 1, True, True)
pred = pred.t()
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
self.correct += torch.sum(correct)
self.total += batch_size
def compute(self):
return self.correct.float() / self.total
Copy
這裡需要注意的是:
- 在init函數中需要通過add_state進行屬性初始化;
- 在update中需要處理接收的資料,並可自訂管理機制,如這裡採用correct與total來管理總的資料
- 在compute中需清晰知道返回的是總數據的Accuracy
小結
torchmetrics是一個簡單易用的指標評估庫,裡面提供了80多種指標,建議採用torchmetrics進行指標評估,避免重複造輪子。
下面請看支持的指標:
Auido 任務指標
- Perceptual Evaluation of Speech Quality (PESQ)
- Permutation Invariant Training (PIT)
- Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)
- Scale-Invariant Signal-to-Noise Ratio (SI-SNR)
- Short-Time Objective Intelligibility (STOI)
- Signal to Distortion Ratio (SDR)
- Signal-to-Noise Ratio (SNR)
分類 任務指標
- Accuracy
- AUC
- AUROC
- Average Precision
- Binned Average Precision
- Binned Precision Recall Curve
- Binned Recall At Fixed Precision
- Calibration Error
- Cohen Kappa
- Confusion Matrix
- Coverage Error
- Dice Score
- F1 Score
- FBeta Score
- Hamming Distance
- Hinge Loss
- Jaccard Index
- KL Divergence
- Label Ranking Average Precision
- Label Ranking Loss
- Matthews Corr. Coef.
- Precision
- Precision Recall
- Precision Recall Curve
- Recall
- ROC
- Specificity
- Stat Scores
圖像 任務指標
- Error Relative Global Dim. Synthesis (ERGAS)
- Frechet Inception Distance (FID)
- Image Gradients
- Inception Score
- Kernel Inception Distance
- Learned Perceptual Image Patch Similarity (LPIPS)
- Multi-Scale SSIM
- Peak Signal-to-Noise Ratio (PSNR)
- Spectral Angle Mapper
- Spectral Distortion Index
- Structural Similarity Index Measure (SSIM)
- Universal Image Quality Index
檢測 任務指標
- Mean-Average-Precision (mAP)
Pairwise 任務指標
- Cosine Similarity
- Euclidean Distance
- Linear Similarity
- Manhattan Distance
Regression 任務指標
- Cosine Similarity
- Explained Variance
- Mean Absolute Error (MAE)
- Mean Absolute Percentage Error (MAPE)
- Mean Squared Error (MSE)
- Mean Squared Log Error (MSLE)
- Pearson Corr. Coef.
- R2 Score
- Spearman Corr. Coef.
- Symmetric Mean Absolute Percentage Error (SMAPE)
- Tweedie Deviance Score
- Weighted MAPE
Retrieval 任務指標
- Retrieval Fall-Out
- Retrieval Hit Rate
- Retrieval Mean Average Precision (MAP)
- Retrieval Mean Reciprocal Rank (MRR)
- Retrieval Normalized DCG
- Retrieval Precision
- Retrieval R-Precision
- Retrieval Recall
Text 任務指標
- BERT Score
- BLEU Score
- Char Error Rate
- ChrF Score
- Extended Edit Distance
- Match Error Rate
- ROUGE Score
- Sacre BLEU Score
- SQuAD
- Translation Edit Rate (TER)
- Word Error Rate
- Word Info. LostWord Info. Preserved
留言列表