6.5 模型參數視覺化
隨著神經網路越來越深,越來越複雜,手動計算模型中間的資料的shape變得困難。
本節將介紹torchinfo,可用一鍵實現模型參數量計算、各層特徵圖形狀計算和計算量計算等功能。
torchinfo的功能最早來自於TensorFlow和Kearas的summary()函數,torchinfo是學習借鑒而來。而在torchinfo之前還有torchsummary工具,不過torchsummary已經停止更新,並且推薦使用torchinfo。
torchsummay:https://github.com/sksq96/pytorch-summary
torchinfo:https://github.com/TylerYep/torchinfo
torchinfo 主要提供了一個函數,即
def summary(
model: nn.Module,
input_size: Optional[INPUT_SIZE_TYPE] = None,
input_data: Optional[INPUT_DATA_TYPE] = None,
batch_dim: Optional[int] = None,
cache_forward_pass: Optional[bool] = None,
col_names: Optional[Iterable[str]] = None,
col_width: int = 25,
depth: int = 3,
device: Optional[torch.device] = None,
dtypes: Optional[List[torch.dtype]] = None,
mode: str | None = None,
row_settings: Optional[Iterable[str]] = None,
verbose: int = 1,
**kwargs: Any,
) -> ModelStatistics:
torchinfo 演示
運行代碼
resnet_50 = models.resnet50(pretrained=
False)
batch_size =
1
summary(resnet_50, input_size=(batch_size,
3,
224,
224))
可看到resnet50的以下資訊
==========================================================================================
Layer (type:depth-idx) Output Shape Param
#
==========================================================================================
ResNet [
1,
1000] --
├─
Conv2d:
1-1[
1,
64,
112,
112]
9,
408
├─
BatchNorm2d:
1-2[
1,
64,
112,
112]
128
├─
ReLU:
1-3[
1,
64,
112,
112] --
├─
MaxPool2d:
1-4[
1,
64,
56,
56] --
├─
Sequential:
1-5[
1,
256,
56,
56] --
│
└─
Bottleneck:
2-1[
1,
256,
56,
56] --
│
│
└─
Conv2d:
3-1[
1,
64,
56,
56]
4,
096
│
│
└─
BatchNorm2d:
3-2[
1,
64,
56,
56]
128
│
│
└─
ReLU:
3-3[
1,
64,
56,
56] --
│
│
└─
Conv2d:
3-4[
1,
64,
56,
56]
36,
864
│
│
└─
BatchNorm2d:
3-5[
1,
64,
56,
56]
128
│
│
└─
ReLU:
3-6[
1,
64,
56,
56] --
│
│
└─
Conv2d:
3-7[
1,
256,
56,
56]
16,
384
│
│
└─
BatchNorm2d:
3-8[
1,
256,
56,
56]
512
│
│
└─
Sequential:
3-9[
1,
256,
56,
56]
16,
896
│
│
└─
ReLU:
3-10[
1,
256,
56,
56] --
......
│
└─
Bottleneck:
2-16[
1,
2048,
7,
7] --
│
│
└─
Conv2d:
3-140[
1,
512,
7,
7]
1,
048,
576
│
│
└─
BatchNorm2d:
3-141[
1,
512,
7,
7]
1,
024
│
│
└─
ReLU:
3-142[
1,
512,
7,
7] --
│
│
└─
Conv2d:
3-143[
1,
512,
7,
7]
2,
359,
296
│
│
└─
BatchNorm2d:
3-144[
1,
512,
7,
7]
1,
024
│
│
└─
ReLU:
3-145[
1,
512,
7,
7] --
│
│
└─
Conv2d:
3-146[
1,
2048,
7,
7]
1,
048,
576
│
│
└─
BatchNorm2d:
3-147[
1,
2048,
7,
7]
4,
096
│
│
└─
ReLU:
3-148[
1,
2048,
7,
7] --
├─
AdaptiveAvgPool2d:
1-9[
1,
2048,
1,
1] --
├─
Linear:
1-10[
1,
1000]
2,
049,
000
==========================================================================================
Total params:
25,
557,
032
Trainable params:
25,
557,
032
Non-trainable params:
0
Total mult-adds (G):
4.09
==========================================================================================
Input size (MB):
0.60
Forward/backward
passsize (MB):
177.83
Params size (MB):
102.23
Estimated Total Size (MB):
280.66
==========================================================================================
其中包括各網路層名稱,以及層級關係,各網路層輸出形狀以及參數量。在最後還有模型的總結,包括總的參數量有25,557,032個,總的乘加(Mult-Adds)操作有4.09G(4.09*10^9次方 浮點運算),輸入大小為0.60MB,參數占102.23MB。
計算量:1G表示10^9 次浮點運算 (Giga Floating-point Operations Per Second),關於乘加運算,可參考知乎問題
存儲量:這裡的Input size (MB): 0.60,是通過資料精度計算得到,預設情況下採用float32位元存儲一個數,因此輸入為:3*224*224*32b = 4816896b = 602112B = 602.112 KB = 0.6 MB
同理,Params size (MB): 25557032 * 32b = 817,825,024 b = 102,228,128 B = 102.23 MB
介面詳解
summary提供了很多參數可以配置列印資訊,這裡介紹幾個常用參數。
col_names:可選擇列印的資訊內容,如 ("input_size","output_size","num_params","kernel_size","mult_adds","trainable",)
dtypes:可以設置資料類型,默認的為float32,單精確度。
mode:可設置模型在訓練還是測試狀態。
verbose: 可設置列印資訊的詳細程度。0是不列印,1是默認,2是將weight和bias也打出來。
小結
本節介紹torchinfo的使用,並分析其參數的計算過程,這裡需要瞭解訓練參數數量、特徵圖參數數量和計算量。其中計算量還有一個好用的工具庫進行計算,這裡作為額外資料供大家學習——PyTorch-OpCounter