4.5 hook函數
注:本小節主要參考《PyTorch模型訓練實用教程》(第一版),主要更新了PyTorch新版本的函數——torch.nn.Module.register_full_backward_hook。
-------------------------------------------------分割線---------------------------------------------------------------
本小節將介紹Module中的三個Hook函數以及Tensor的一個Hook函數
- torch.Tensor.register_hook
- torch.nn.Module.register_forward_hook
- torch.nn.Module.register_forward_pre_hook
- torch.nn.Module.register_full_backward_hook
同時使用hook函數優雅地實現Grad-CAM,效果如下圖所示:
Grad-CAM是CAM(class activation map,類啟動圖)的改進,可對任意結構的CNN進行類啟動視覺化,不需要修改網路結構或者重新訓練,詳細理論請參見Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
什麼是hook?
Hook函數在多門程式設計語言中均有出現,是一個經典的程式設計方式。hook意為鉤、掛鉤、魚鉤。 引用知乎用戶“馬索萌”對hook的解釋:“(hook)相當於外掛程式。可以實現一些額外的功能,而又不用修改主體代碼。把這些額外功能實現了掛在主代碼上,所以叫鉤子,很形象。”
簡單講,就是不修改主體,而實現額外功能。對應到在pytorch中,主體就是forward和backward,而額外的功能就是對模型的變數進行操作,如“提取”特徵圖,“提取”非葉子張量的梯度,修改張量梯度等等。
hook的出現與pytorch運算機制有關,pytorch在每一次運算結束後,會將中間變數釋放,以節省記憶體空間,這些會被釋放的變數包括非葉子張量的梯度,中間層的特徵圖等。
但有時候,想視覺化中間層的特徵圖,又不能改動模型主體代碼,該怎麼辦呢?這時候就要用到hook了。 舉個例子演示hook提取非葉子張量的梯度:
import torch
def grad_hook(grad):
y_grad.append(grad)
y_grad = list()
x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
y = x+1
y.register_hook(grad_hook)
z = torch.mean(y*y)
z.backward()
print("type(y): ", type(y))
print("y.grad: ", y.grad)
print("y_grad[0]: ", y_grad[0])
>>> ('type(y): ', <class 'torch.Tensor'>)
>>> ('y.grad: ', None)
>>> ('y_grad[0]: ', tensor([[1.0000, 1.5000],
[2.0000, 2.5000]]))
Copy
可以看到y.grad的值為None,這是因為y是非葉子結點張量,在z.backward()完成之後,y的梯度被釋放掉以節省記憶體,但可以通過torch.Tensor的類方法register_hook將y的梯度提取出來。
torch.Tensor.register_hook
torch.Tensor.register_hook (Python method, in torch.Tensor.register_hook)
功能:註冊一個反向傳播hook函數,這個函數是Tensor類裡的,當計算tensor的梯度時自動執行。
為什麼是backward?因為這個hook是針對tensor的,tensor中的什麼東西會在計算結束後釋放? 那就是gradient,所以是backward hook.
形式: hook(grad) -> Tensor or None ,其中grad就是這個tensor的梯度。
返回值:a handle that can be used to remove the added hook by calling handle.remove()
應用場景舉例:在hook函數中可對梯度grad進行in-place操作,即可修改tensor的grad值。 這是一個很酷的功能,例如當淺層的梯度消失時,可以對淺層的梯度乘以一定的倍數,用來增大梯度; 還可以對梯度做截斷,限制梯度在某一區間,防止過大的梯度對權值參數進行修改。 下面舉兩個例子,例1是如何獲取中間變數y的梯度,例2是利用hook函數將變數x的梯度擴大2倍。
例1:
import torch
y_grad = list()
def grad_hook(grad):
y_grad.append(grad)
x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
y = torch.pow(x, 2)
z = torch.mean(y)
h = y.register_hook(grad_hook)
z.backward()
print("y.grad: ", y.grad)
print("y_grad[0]: ", y_grad[0])
h.remove() # removes the hook
>>> ('y.grad: ', None)
>>> ('y_grad[0]: ', tensor([0.2500, 0.2500, 0.2500, 0.2500]))
Copy
可以看到當z.backward()結束後,張量y中的grad為None,因為y是非葉子節點張量,在梯度反傳結束之後,被釋放。 在對張量y的hook函數(grad_hook)中,將y的梯度保存到了y_grad列表中,因此可以在z.backward()結束後,仍舊可以在y_grad[0]中讀到y的梯度為tensor([0.2500, 0.2500, 0.2500, 0.2500])
例2:
import torch
def grad_hook(grad):
grad *= 2
x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
y = torch.pow(x, 2)
z = torch.mean(y)
h = x.register_hook(grad_hook)
z.backward()
print(x.grad)
h.remove() # removes the hook
>>> tensor([2., 2., 2., 2.])
Copy
原x的梯度為tensor([1., 1., 1., 1.]),經grad_hook操作後,梯度為tensor([2., 2., 2., 2.])。
torch.nn.Module.register_forward_hook
功能:Module前向傳播中的hook,module在前向傳播後,自動調用hook函數。 形式:hook(module, input, output) -> None or modified output 。注意不能修改input和output
返回值:a handle that can be used to remove the added hook by calling handle.remove()
舉例:假設網路由卷積層conv1和池化層pool1構成,輸入一張4*4的圖片,現採用forward_hook獲取module——conv1之後的feature maps,示意圖如下:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
def farward_hook(module, data_input, data_output):
fmap_block.append(data_output)
input_block.append(data_input)
if __name__ == "__main__":
# 初始化網路
net = Net()
net.conv1.weight[0].fill_(1)
net.conv1.weight[1].fill_(2)
net.conv1.bias.data.zero_()
# 註冊hook
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(farward_hook)
# inference
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
output = net(fake_img)
# 觀察
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
Copy
首先初始化一個網路,卷積層有兩個卷積核,權重分別為全1和全2,bias設置為0,池化層採用2*2的最大池化。
在進行forward之前對module——conv1註冊了forward_hook函數,然後執行前向傳播(output=net(fake_img)),當前向傳播完成後, fmap_block清單中的第一個元素就是conv1層輸出的特徵圖了。
這裡注意觀察forward_hook函數有data_input和data_output兩個變數,特徵圖是data_output這個變數,而data_input是conv1層的輸入資料, conv1層的輸入是一個tuple的形式。
hook函式呼叫邏輯
下面剖析一下module是怎麼樣調用hook函數的呢?
- output = net(fakeimg) net是一個module類,對module執行 module(input)是會調用module._call
- module.call :會進入_call_impl,回顧Module那一小節,call_impl是有很多其它代碼,這就是對hook函數的處理,可以看到,讓註冊了hook函數,模型的forward不再是4.1小節裡分析的1102行代碼進行,而是分別執行對應的hook函數。1109行是執行每個forward_pre_hook的,1120行是執行forward的,1123行是執行forward_hook的, 1144行是執行full_backward_hook的。
def _call_impl(self, *input, **kwargs):
forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
# If we don't have any hooks, we want to skip the rest of the logic in
# this function, and just call forward.
if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
or _global_forward_hooks or _global_forward_pre_hooks):
return forward_call(*input, **kwargs)
# Do not call functions when jit is used
full_backward_hooks, non_full_backward_hooks = [], []
if self._backward_hooks or _global_backward_hooks:
full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
if _global_forward_pre_hooks or self._forward_pre_hooks:
for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
bw_hook = None
if full_backward_hooks:
bw_hook = hooks.BackwardHook(self, full_backward_hooks)
input = bw_hook.setup_input_hook(input)
result = forward_call(*input, **kwargs)
if _global_forward_hooks or self._forward_hooks:
for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
if bw_hook:
result = bw_hook.setup_output_hook(result)
# Handle the non-full backward hooks
if non_full_backward_hooks:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in non_full_backward_hooks:
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
return result
Copy
這裡需要注意兩點:
- hook_result = hook(self, input, result)中的input和result不可以修改。這裡的input對應forward_hook函數中的data_input,result對應forward_hook函數中的data_output,在conv1中,input就是該層的輸入資料,result就是經過conv1層操作之後的輸出特徵圖。雖然可以通過hook來對這些資料操作,但是不能修改這些值,否則會破壞模型的計算。
- 註冊的hook函數是不能帶返回值的,否則拋出異常,這個可以從代碼中看到 if hook_result is not None: raise RuntimeError
總結一下調用流程: net(fake_img) --> net.call : result = self.forward(input, *kwargs) --> net.forward: x = self.conv1(x) --> conv1.call:hook_result = hook(self, input, result) hook就是註冊了的forward_hook函數。
torch.nn.Module.register_forward_pre_hook
功能:執行forward()之前調用hook函數。 形式:hook(module, input) -> None or modified input
應用場景:register_forward_pre_hook與forward_hook一樣,是在module.call中註冊的,與forward_hook不同的是,其在module執行forward之前就運行了,具體可看module.call中的代碼。
torch.nn.Module.register_full_backward_hook
功能:Module反向傳播中的hook,每次計算module的梯度後,自動調用hook函數。 形式:hook(module, grad_input, grad_output) -> tuple(Tensor) or None
注意事項:
- 當module有多個輸入或輸出時,grad_input和grad_output是一個tuple。
- register_full_backward_hook 是修改過的版本,舊版本為register_backward_hook,不過官方已經建議棄用,不需要再瞭解。
返回值:a handle that can be used to remove the added hook by calling handle.remove()
應用場景舉例:提取特徵圖的梯度
Grad-CAM 實現
採用register_full_backward_hook實現特徵圖梯度的提取,並結合Grad-CAM(基於類梯度的類啟動圖視覺化)方法對卷積神經網路的學習模式進行視覺化。
關於Grad-CAM請看論文:《Grad-CAM Visual Explanations from Deep Networks via Gradient-based Localization》 簡單介紹Grad-CAM的操作,Grad-CAM通過對最後一層特徵圖進行加權求和得到heatmap,整個CAM系列的主要研究就在於這個加權求和中的權值從那裡來。
Grad-CAM是對特徵圖求梯度,將每一張特徵圖的梯度求平均得到權值(特徵圖的梯度是element-wise的)。求梯度時並不採用網路的輸出,而是採用類向量,即one-hot向量。
下圖是ResNet的Grad-CAM示意圖,上圖類向量採用的是貓的標籤,下圖採用的是狗的標籤,可以看到在上圖模型更關注貓(紅色部分),下圖判別為狗的主要依據是狗的頭部。
下面採用一個LeNet-5演示backward_hook在Grad-CAM中的應用。 簡述代碼過程:
- 創建網路net;
- 註冊forward_hook函數用於提取最後一層特徵圖;
- 註冊backward_hook函數用於提取類向量(one-hot)關於特徵圖的梯度
- 對特徵圖的梯度進行求均值,並對特徵圖進行加權;
- 視覺化heatmap。
注意:需要注意的是在backward_hook函數中,grad_out是一個tuple類型的,要取得特徵圖的梯度需要這樣grad_block.append(grad_out[0].detach())
思考
這裡對3張飛機的圖片進行觀察heatmap,如下圖所示,第一行是原圖,第二行是疊加了heatmap的圖片。
這裡發現一個有意思的現象,模型將圖片判為飛機的依據是藍天,而不是飛機(圖1-3)。 那麼我們喂給模型一張純天藍色的圖片,模型會判為什麼呢?如圖4所示,發現模型判為了飛機
從這裡發現,雖然能將飛機正確分類,但是它學到的卻不是飛機的特徵! 這導致模型的泛化性能大打折扣,從這裡我們可以考慮採用trick讓模型強制的學習到飛機而不是常與飛機一同出現的藍天,或者是調整資料。
對於圖4疑問:heatmap藍色區域是否對圖像完全不起作用呢?是否僅僅通過紅色區域就可以對圖像進行判別呢? 接下來將一輛正確分類的汽車圖片(圖5)疊加到圖4藍色回應區域(即模型並不關注的區域),結果如圖6所示,汽車部分的回應值很小,模型仍通過天藍色區域將圖片判為了飛機。 接著又將汽車疊加到圖4紅色回應區域(圖的右下角),結果如圖7所示,仍將圖片判為了飛機。 有意思的是將汽車疊加到圖7的紅色回應區域,模型把圖片判為了船,而且紅色回應區域是藍色區域的下部分,這個與船在大海中的位置很接近。
通過以上代碼學習full_backward_hook的使用及其在Grad-CAM中的應用,並通過Grad-CAM能診斷模型是否學習到關鍵特徵。 關於CAM( class activation maping,類啟動響應圖)是一個很有趣的研究,有興趣的朋友可以對CAM、Grad-CAM和Grad-CAM++進行研究。
小結
本小節介紹了程式設計語言中經典的思想——Hook函數,並講解了pytorch中如何使用它們,最後還採用full_backward_hook實現有趣的Grad-CAM視覺化,本節代碼較多,建議對著配套代碼單步調試進行學習,掌握hook函數的妙用,在今後使用pytorch進行模型分析、魔改的時候更遊刃有餘。
下一小結會把本章所學的Module相關容器、網路層的知識點串起來使用,通過剖析torchvision中經典模型的原始程式碼,瞭解所學習的知識點是如何使用的。
留言列表