4.5 hook函數

注:本小節主要參考PyTorch模型訓練實用教程》(第一版),主要更新了PyTorch新版本的函數——torch.nn.Module.register_full_backward_hook

-------------------------------------------------分割線---------------------------------------------------------------

本小節將介紹Module中的三個Hook函數以及Tensor的一個Hook函數

  • torch.Tensor.register_hook
  • torch.nn.Module.register_forward_hook
  • torch.nn.Module.register_forward_pre_hook
  • torch.nn.Module.register_full_backward_hook

同時使用hook函數優雅地實現Grad-CAM,效果如下圖所示:

​ <<AI人工智慧 PyTorch自學>> 4.5 hook函

Grad-CAMCAM(class activation map,類啟動圖)的改進,可對任意結構的CNN進行類啟動視覺化,不需要修改網路結構或者重新訓練,詳細理論請參見Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization

什麼是hook

Hook函數在多門程式設計語言中均有出現,是一個經典的程式設計方式。hook意為鉤、掛鉤、魚鉤。 引用知乎用戶馬索萌hook的解釋:“(hook)相當於外掛程式。可以實現一些額外的功能,而又不用修改主體代碼。把這些額外功能實現了掛在主代碼上,所以叫鉤子,很形象。

簡單講,就是不修改主體,而實現額外功能。對應到在pytorch中,主體就是forwardbackward,而額外的功能就是對模型的變數進行操作,如提取特徵圖,提取非葉子張量的梯度,修改張量梯度等等。

hook的出現與pytorch運算機制有關,pytorch在每一次運算結束後,會將中間變數釋放,以節省記憶體空間,這些會被釋放的變數包括非葉子張量的梯度,中間層的特徵圖等。

但有時候,想視覺化中間層的特徵圖,又不能改動模型主體代碼,該怎麼辦呢?這時候就要用到hook了。 舉個例子演示hook提取非葉子張量的梯度:

import torch

def grad_hook(grad):

    y_grad.append(grad)

y_grad = list()

x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)

y = x+1

y.register_hook(grad_hook)

z = torch.mean(y*y)

z.backward()

print("type(y): ", type(y))

print("y.grad: ", y.grad)

print("y_grad[0]: ", y_grad[0])

 

>>> ('type(y): ', <class 'torch.Tensor'>)

>>> ('y.grad: ', None)

>>> ('y_grad[0]: ', tensor([[1.0000, 1.5000],

        [2.0000, 2.5000]]))

Copy

可以看到y.grad的值為None,這是因為y是非葉子結點張量,在z.backward()完成之後,y的梯度被釋放掉以節省記憶體,但可以通過torch.Tensor的類方法register_hooky的梯度提取出來。

torch.Tensor.register_hook

torch.Tensor.register_hook (Python method, in torch.Tensor.register_hook)

功能:註冊一個反向傳播hook函數,這個函數是Tensor類裡的,當計算tensor的梯度時自動執行。

為什麼是backward?因為這個hook是針對tensor的,tensor中的什麼東西會在計算結束後釋放? 那就是gradient,所以是backward hook.

形式 hook(grad) -> Tensor or None ,其中grad就是這個tensor的梯度。

返回值a handle that can be used to remove the added hook by calling handle.remove()

應用場景舉例:在hook函數中可對梯度grad進行in-place操作,即可修改tensorgrad值。 這是一個很酷的功能,例如當淺層的梯度消失時,可以對淺層的梯度乘以一定的倍數,用來增大梯度; 還可以對梯度做截斷,限制梯度在某一區間,防止過大的梯度對權值參數進行修改。 下面舉兩個例子,例1是如何獲取中間變數y的梯度,例2是利用hook函數將變數x的梯度擴大2倍。

1

import torch

y_grad = list()

def grad_hook(grad):

    y_grad.append(grad)

x = torch.tensor([2., 2., 2., 2.], requires_grad=True)

y = torch.pow(x, 2)

z = torch.mean(y)

h = y.register_hook(grad_hook)

z.backward()

print("y.grad: ", y.grad)

print("y_grad[0]: ", y_grad[0])

h.remove()    # removes the hook

 

>>> ('y.grad: ', None)

>>> ('y_grad[0]: ', tensor([0.2500, 0.2500, 0.2500, 0.2500]))

Copy

可以看到當z.backward()結束後,張量y中的gradNone,因為y是非葉子節點張量,在梯度反傳結束之後,被釋放。 在對張量yhook函數(grad_hook)中,將y的梯度保存到了y_grad列表中,因此可以在z.backward()結束後,仍舊可以在y_grad[0]中讀到y的梯度為tensor([0.2500, 0.2500, 0.2500, 0.2500])

2

import torch

def grad_hook(grad):

    grad *= 2

x = torch.tensor([2., 2., 2., 2.], requires_grad=True)

y = torch.pow(x, 2)

z = torch.mean(y)

h = x.register_hook(grad_hook)

z.backward()

print(x.grad)

h.remove()    # removes the hook

 

>>> tensor([2., 2., 2., 2.])

Copy

x的梯度為tensor([1., 1., 1., 1.]),經grad_hook操作後,梯度為tensor([2., 2., 2., 2.])

torch.nn.Module.register_forward_hook

功能Module前向傳播中的hook,module在前向傳播後,自動調用hook函數。 形式:hook(module, input, output) -> None or modified output 。注意不能修改inputoutput

返回值a handle that can be used to remove the added hook by calling handle.remove()

舉例:假設網路由卷積層conv1和池化層pool1構成,輸入一張4*4的圖片,現採用forward_hook獲取module——conv1之後的feature maps,示意圖如下:

​ <<AI人工智慧 PyTorch自學>> 4.5 hook函

import torch

import torch.nn as nn

class Net(nn.Module):

    def __init__(self):

        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(1, 2, 3)

        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):

        x = self.conv1(x)

        x = self.pool1(x)

        return x

def farward_hook(module, data_input, data_output):

    fmap_block.append(data_output)

    input_block.append(data_input)

if __name__ == "__main__":

    # 初始化網路

    net = Net()

    net.conv1.weight[0].fill_(1)

    net.conv1.weight[1].fill_(2)

    net.conv1.bias.data.zero_()

    # 註冊hook

    fmap_block = list()

    input_block = list()

    net.conv1.register_forward_hook(farward_hook)

    # inference

    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W

    output = net(fake_img)

    # 觀察

    print("output shape: {}\noutput value: {}\n".format(output.shape, output))

    print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))

    print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

Copy

首先初始化一個網路,卷積層有兩個卷積核,權重分別為全1和全2bias設置為0,池化層採用2*2的最大池化。

在進行forward之前對module——conv1註冊了forward_hook函數,然後執行前向傳播(output=net(fake_img)),當前向傳播完成後, fmap_block清單中的第一個元素就是conv1層輸出的特徵圖了。

這裡注意觀察forward_hook函數有data_inputdata_output兩個變數,特徵圖是data_output這個變數,而data_inputconv1層的輸入資料, conv1層的輸入是一個tuple的形式。

hook函式呼叫邏輯

下面剖析一下module是怎麼樣調用hook函數的呢?

  1. output = net(fakeimg) net是一個module類,對module執行 module(input)是會調用module._call
  2. module.call :會進入_call_impl,回顧Module那一小節,call_impl是有很多其它代碼,這就是對hook函數的處理,可以看到,讓註冊了hook函數,模型的forward不再是4.1小節裡分析的1102行代碼進行,而是分別執行對應的hook函數。1109行是執行每個forward_pre_hook的,1120行是執行forward的,1123行是執行forward_hook的, 1144行是執行full_backward_hook的。

def _call_impl(self, *input, **kwargs):

    forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)

    # If we don't have any hooks, we want to skip the rest of the logic in

    # this function, and just call forward.

    if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks

            or _global_forward_hooks or _global_forward_pre_hooks):

        return forward_call(*input, **kwargs)

    # Do not call functions when jit is used

    full_backward_hooks, non_full_backward_hooks = [], []

    if self._backward_hooks or _global_backward_hooks:

        full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()

    if _global_forward_pre_hooks or self._forward_pre_hooks:

        for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):

            result = hook(self, input)

            if result is not None:

                if not isinstance(result, tuple):

                    result = (result,)

                input = result

 

    bw_hook = None

    if full_backward_hooks:

        bw_hook = hooks.BackwardHook(self, full_backward_hooks)

        input = bw_hook.setup_input_hook(input)

 

    result = forward_call(*input, **kwargs)

    if _global_forward_hooks or self._forward_hooks:

        for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

            hook_result = hook(self, input, result)

            if hook_result is not None:

                result = hook_result

 

    if bw_hook:

        result = bw_hook.setup_output_hook(result)

 

    # Handle the non-full backward hooks

    if non_full_backward_hooks:

        var = result

        while not isinstance(var, torch.Tensor):

            if isinstance(var, dict):

                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))

            else:

                var = var[0]

        grad_fn = var.grad_fn

        if grad_fn is not None:

            for hook in non_full_backward_hooks:

                wrapper = functools.partial(hook, self)

                functools.update_wrapper(wrapper, hook)

                grad_fn.register_hook(wrapper)

            self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

 

    return result

Copy

這裡需要注意兩點:

  1. hook_result = hook(self, input, result)中的inputresult不可以修改。這裡的input對應forward_hook函數中的data_inputresult對應forward_hook函數中的data_output,在conv1中,input就是該層的輸入資料,result就是經過conv1層操作之後的輸出特徵圖。雖然可以通過hook來對這些資料操作,但是不能修改這些值,否則會破壞模型的計算。
  2. 註冊的hook函數是不能帶返回值的,否則拋出異常,這個可以從代碼中看到 if hook_result is not None: raise RuntimeError

總結一下調用流程: net(fake_img) --> net.call : result = self.forward(input, *kwargs) --> net.forward: x = self.conv1(x) --> conv1.call:hook_result = hook(self, input, result) hook就是註冊了的forward_hook函數。

torch.nn.Module.register_forward_pre_hook

功能:執行forward()之前調用hook函數。 形式:hook(module, input) -> None or modified input

應用場景register_forward_pre_hookforward_hook一樣,是在module.call中註冊的,與forward_hook不同的是,其在module執行forward之前就運行了,具體可看module.call中的代碼。

torch.nn.Module.register_full_backward_hook

功能Module反向傳播中的hook,每次計算module的梯度後,自動調用hook函數。 形式:hook(module, grad_input, grad_output) -> tuple(Tensor) or None

注意事項

  • module有多個輸入或輸出時,grad_inputgrad_output是一個tuple
  • register_full_backward_hook 是修改過的版本,舊版本為register_backward_hook,不過官方已經建議棄用,不需要再瞭解。

返回值a handle that can be used to remove the added hook by calling handle.remove()

應用場景舉例:提取特徵圖的梯度

Grad-CAM 實現

採用register_full_backward_hook實現特徵圖梯度的提取,並結合Grad-CAM(基於類梯度的類啟動圖視覺化)方法對卷積神經網路的學習模式進行視覺化。

關於Grad-CAM請看論文:《Grad-CAM Visual Explanations from Deep Networks via Gradient-based Localization 簡單介紹Grad-CAM的操作,Grad-CAM通過對最後一層特徵圖進行加權求和得到heatmap,整個CAM系列的主要研究就在於這個加權求和中的權值從那裡來。

Grad-CAM是對特徵圖求梯度,將每一張特徵圖的梯度求平均得到權值(特徵圖的梯度是element-wise的)。求梯度時並不採用網路的輸出,而是採用類向量,即one-hot向量。

下圖是ResNetGrad-CAM示意圖,上圖類向量採用的是貓的標籤,下圖採用的是狗的標籤,可以看到在上圖模型更關注貓(紅色部分),下圖判別為狗的主要依據是狗的頭部。

​ 

<<AI人工智慧 PyTorch自學>> 4.5 hook函

 

下面採用一個LeNet-5演示backward_hookGrad-CAM中的應用。 簡述代碼過程:

  1. 創建網路net
  2. 註冊forward_hook函數用於提取最後一層特徵圖;
  3. 註冊backward_hook函數用於提取類向量(one-hot)關於特徵圖的梯度
  4. 對特徵圖的梯度進行求均值,並對特徵圖進行加權;
  5. 視覺化heatmap

注意:需要注意的是在backward_hook函數中,grad_out是一個tuple類型的,要取得特徵圖的梯度需要這樣grad_block.append(grad_out[0].detach())

思考

這裡對3張飛機的圖片進行觀察heatmap,如下圖所示,第一行是原圖,第二行是疊加了heatmap的圖片。

這裡發現一個有意思的現象,模型將圖片判為飛機的依據是藍天,而不是飛機(圖1-3)。 那麼我們喂給模型一張純天藍色的圖片,模型會判為什麼呢?如圖4所示,發現模型判為了飛機

從這裡發現,雖然能將飛機正確分類,但是它學到的卻不是飛機的特徵! 這導致模型的泛化性能大打折扣,從這裡我們可以考慮採用trick讓模型強制的學習到飛機而不是常與飛機一同出現的藍天,或者是調整資料。

​ <<AI人工智慧 PyTorch自學>> 4.5 hook函

對於圖4疑問:heatmap藍色區域是否對圖像完全不起作用呢?是否僅僅通過紅色區域就可以對圖像進行判別呢? 接下來將一輛正確分類的汽車圖片(圖5)疊加到圖4藍色回應區域(即模型並不關注的區域),結果如圖6所示,汽車部分的回應值很小,模型仍通過天藍色區域將圖片判為了飛機。 接著又將汽車疊加到圖4紅色回應區域(圖的右下角),結果如圖7所示,仍將圖片判為了飛機。 有意思的是將汽車疊加到圖7的紅色回應區域,模型把圖片判為了船,而且紅色回應區域是藍色區域的下部分,這個與船在大海中的位置很接近。

​ <<AI人工智慧 PyTorch自學>> 4.5 hook函

通過以上代碼學習full_backward_hook的使用及其在Grad-CAM中的應用,並通過Grad-CAM能診斷模型是否學習到關鍵特徵。 關於CAM( class activation maping,類啟動響應圖)是一個很有趣的研究,有興趣的朋友可以對CAMGrad-CAMGrad-CAM++進行研究。

小結

本小節介紹了程式設計語言中經典的思想——Hook函數,並講解了pytorch中如何使用它們,最後還採用full_backward_hook實現有趣的Grad-CAM視覺化,本節代碼較多,建議對著配套代碼單步調試進行學習,掌握hook函數的妙用,在今後使用pytorch進行模型分析、魔改的時候更遊刃有餘。

下一小結會把本章所學的Module相關容器、網路層的知識點串起來使用,通過剖析torchvision中經典模型的原始程式碼,瞭解所學習的知識點是如何使用的。

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 HCHUNGW 的頭像
    HCHUNGW

    HCHUNGW的部落格

    HCHUNGW 發表在 痞客邦 留言(0) 人氣()