6.5 模型參數視覺化

隨著神經網路越來越深,越來越複雜,手動計算模型中間的資料的shape變得困難。

本節將介紹torchinfo,可用一鍵實現模型參數量計算、各層特徵圖形狀計算和計算量計算等功能。

torchinfo的功能最早來自於TensorFlowKearassummary()函數,torchinfo是學習借鑒而來。而在torchinfo之前還有torchsummary工具,不過torchsummary已經停止更新,並且推薦使用torchinfo

torchsummayhttps://github.com/sksq96/pytorch-summary

torchinfohttps://github.com/TylerYep/torchinfo

torchinfo 主要提供了一個函數,即

def summary(
    model: nn.Module,
    input_size: Optional[INPUT_SIZE_TYPE] = None,
    input_data: Optional[INPUT_DATA_TYPE] = None,
    batch_dim: Optional[int] = None,
    cache_forward_pass: Optional[bool] = None,
    col_names: Optional[Iterable[str]] = None,
    col_width: int = 25,
    depth: int = 3,
    device: Optional[torch.device] = None,
    dtypes: Optional[List[torch.dtype]] = None,
    mode: str | None = None,
    row_settings: Optional[Iterable[str]] = None,
    verbose: int = 1,
    **kwargs: Any,
) -> ModelStatistics:

torchinfo 演示

運行代碼

    resnet_50 = models.resnet50(pretrained=False)
    batch_size = 1
    summary(resnet_50, input_size=(batch_size, 3, 224, 224))

可看到resnet50的以下資訊

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
        └─Conv2d: 3-1                  [1, 64, 56, 56]           4,096
        └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
        └─ReLU: 3-3                    [1, 64, 56, 56]           --
        └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
        └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
        └─ReLU: 3-6                    [1, 64, 56, 56]           --
        └─Conv2d: 3-7                  [1, 256, 56, 56]          16,384
        └─BatchNorm2d: 3-8             [1, 256, 56, 56]          512
        └─Sequential: 3-9              [1, 256, 56, 56]          16,896
        └─ReLU: 3-10                   [1, 256, 56, 56]          --
 
......
 
    └─Bottleneck: 2-16                  [1, 2048, 7, 7]           --
        └─Conv2d: 3-140                [1, 512, 7, 7]            1,048,576
        └─BatchNorm2d: 3-141           [1, 512, 7, 7]            1,024
        └─ReLU: 3-142                  [1, 512, 7, 7]            --
        └─Conv2d: 3-143                [1, 512, 7, 7]            2,359,296
        └─BatchNorm2d: 3-144           [1, 512, 7, 7]            1,024
        └─ReLU: 3-145                  [1, 512, 7, 7]            --
        └─Conv2d: 3-146                [1, 2048, 7, 7]           1,048,576
        └─BatchNorm2d: 3-147           [1, 2048, 7, 7]           4,096
        └─ReLU: 3-148                  [1, 2048, 7, 7]           --
├─AdaptiveAvgPool2d: 1-9                 [1, 2048, 1, 1]           --
├─Linear: 1-10                           [1, 1000]                 2,049,000
==========================================================================================
Total params: 25,557,032
Trainable params: 25,557,032
Non-trainable params: 0
Total mult-adds (G): 4.09
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 177.83
Params size (MB): 102.23
Estimated Total Size (MB): 280.66
==========================================================================================

其中包括各網路層名稱,以及層級關係,各網路層輸出形狀以及參數量。在最後還有模型的總結,包括總的參數量有25,557,032個,總的乘加(Mult-Adds)操作有4.09G4.09*10^9次方 浮點運算),輸入大小為0.60MB,參數占102.23MB

計算量1G表示10^9 次浮點運算 Giga Floating-point Operations Per Second,關於乘加運算,可參考知乎問題

存儲量:這裡的Input size (MB): 0.60,是通過資料精度計算得到,預設情況下採用float32位元存儲一個數,因此輸入為:3*224*224*32b = 4816896b = 602112B = 602.112 KB = 0.6 MB

同理,Params size (MB): 25557032 * 32b = 817,825,024 b = 102,228,128 B = 102.23 MB

介面詳解

summary提供了很多參數可以配置列印資訊,這裡介紹幾個常用參數。

col_names:可選擇列印的資訊內容,如 ("input_size","output_size","num_params","kernel_size","mult_adds","trainable",)

dtypes:可以設置資料類型,默認的為float32,單精確度。

mode:可設置模型在訓練還是測試狀態。

verbose: 可設置列印資訊的詳細程度。0是不列印,1是默認,2是將weightbias也打出來。

小結

本節介紹torchinfo的使用,並分析其參數的計算過程,這裡需要瞭解訓練參數數量、特徵圖參數數量和計算量。其中計算量還有一個好用的工具庫進行計算,這裡作為額外資料供大家學習——PyTorch-OpCounter

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 HCHUNGW 的頭像
    HCHUNGW

    HCHUNGW的部落格

    HCHUNGW 發表在 痞客邦 留言(0) 人氣()