3.4 transforms

本節分為兩部分,首先介紹pytorch的圖像資料增強函式程式庫——transforms,深入分析它的工作機制,並闡述常用的方法。

transforms簡介

資料增強(Data augmentation)已經成為深度學習時代的常規做法,資料增強目的是為了增加訓練資料的豐富度,讓模型接觸多樣性的資料以增加模型的泛化能力。更多關於資料增強的概念,推薦大家閱讀《動手學》的image-augmentation章節

通常,資料增強可分為線上(online)離線(offline)兩種方式,離線方式指的是在訓練開始之前將資料進行變換,變換後的圖片保存到硬碟當中,線上方式則是在訓練過程中,每一次載入訓練資料時對資料進行變換,以實現讓模型看到的圖片都是增強之後的。實際上,這兩種方法理論上是等價的,一般的框架都採用線上方式的資料增強,pytorchtransforms就是線上方式。後續不做特別說明,資料增強專指線上資料增強。

transforms是廣泛使用的圖像變換庫,包含二十多種基礎方法以及多種組合功能,通常可以用Compose把各方法串聯在一起使用。大多數的transforms類都有對應的 functional transforms ,可供用戶自訂調整。transforms提供的主要是PIL格式和Tensor的變換,並且對於圖像的通道也做了規定,預設情況下一個batch的資料是(B, C, H, W) 形狀的張量。

transforms庫中包含二十多種變換方法,那麼多的方法裡應該如何挑選,以及如何設置參數呢? 這是值得大家仔細思考的地方,資料增強的方向一定是測試資料集中可能存在的情況。

舉個例子,做人臉檢測可以用水準翻轉(如前置相機的鏡像就是水準翻轉),但不宜採用垂直翻轉(這裡指一般業務場景,特殊業務場景有垂直翻轉的人臉就另說)。因為真實應用場景不存在倒轉(垂直翻轉)的人臉,因此在訓練過程選擇資料增強時就不應包含垂直翻轉。

運行機制

在正式介紹transforms的系列方法前,先來瞭解pytorch對資料增強的運行機制,我們繼續通過debug模式在dataloader部分進行調試,觀察一張圖片是如何進行資料增強的。

同樣的,我們回顧2.2小結的COVID-19代碼,在dataloader中設置中斷點,進行debug。這裡有一個小技巧,我們可以到datasetgetitem函數裡設置一個中斷點,因為我們前面知道了圖像的讀取及處理是在datasetgetitem裡,因此可以直接進入dataset,不必在dataloader裡繞圈。當然,前提是需要大家熟悉dataloader的運行機制。

在第48img = self.transform(img)設置中斷點,可以看到self.transform是一個Compose物件,繼續進入self.transform(img)

<<AI人工智慧 PyTorch自學>> 3.4 trans

來到 transforms.py Compose類的 __call__函數:這個函數的邏輯是依次調用compose物件裡的變換方法,從此處也可看出資料是串聯的,上一個方法的輸出是下一個方法輸入,這就要求各個方法之間傳輸的資料物件要一致。繼續單步運行,進入第一個t(img) 第一個tResize

<<AI人工智慧 PyTorch自學>> 3.4 trans

 

來到D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torch\nn\modules\module.pyModule類的_call_impl函數:Module類是pytorch模型、網路層的核心,這個類有1854行代碼,下一章將詳細介紹模型模組以及Module。在這裡我們只需要瞭解Resize這個變換方法是一個Module類,它實際的調用在1102行,進入1102行會來到Resize類的forward方法。

來到 D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\transforms\transforms.pyResize類的forward函數:可以看到此函數僅一行代碼F.resize(img, self.size, self.interpolation, self.max_size, self.antialias),繼續進入它。

來到D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torch\nn\functional.py resize函數:functional模組是對一系列操作的封裝,這裡看到419行,resize功能的實現。繼續進入419行。

來到 D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\transforms\functional_pil.pyresize函數:這裡終於進入到最核心的Resize方法實現了,這個函數裡需要時間縮放的w,h,這裡的計算代碼非常值得大家學習,同時函數進來之後對參數的一系列判斷,也值得借鑒。從此函數可以看到它利用了PIL庫的resize函數對PIL圖像進行resize。最終對圖像resize是這265行代碼:return img.resize(size[::-1], interpolation)

然後依次返回,回到transforms.pyCompose類的call函數,此時 img = t(img)完成了1次對圖像的變換。接著繼續執行for迴圈,把compose中的變換執行完畢,就對圖像做完了變換、增強。

總結一下,一開始採用transforms.Compose把變換的方法包裝起來,放到dataset中;在dataloader依次讀數據時,調用datasetgetitem,每個sample讀取時,會根據compose裡的方法依次地對資料進行變換,以此完成線上資料增強。而具體的transforms方法通常包裝成一個Module類,具體實現會在各functional中。

熟悉此運行機制,有助於大家今後自己編寫資料增強方法,嵌入到自己的工程中。

系列API

通過單步debug,瞭解了transforms運行機制,下面看看transforms庫提供的一系列方法及使用。更全面的方法介紹請直接看官方文檔,官方文檔配備了一個圖解transforms的教程

這裡不再一一展開各方法介紹,只挑選幾個代表性的方法展開講解,其餘方法可以到第一版中閱讀transforms的二十二個方法

在這裡,結合COVID-2019 X光分類場景進行系列API的使用介紹。主要內容包括:

  • 具體變換方法使用:resizeNormalizetotensorFiveCropTenCrop
  • 特殊方法使用:RandomChoiceRandomOrderLambda
  • 自動資料增強:AutoAugmentPolicyAutoAugmentRandAugment

具體變換方法使用

Compose

此類用於包裝一系列的transforms方法,在其內部會通過for迴圈依次調用各個方法。這個在上面的代碼調試過程中已經分析清楚了。

Resize

Resize(size, interpolation=, max_size=None, antialias=None)

功能:支援對PILTensor物件的縮放,關於size的設置有些講究,請結合代碼嘗試int方式與tuple方式的差異。int方式是會根據長寬比等比例的縮放圖像,這個在AlexNet論文中提到先等比例縮放再裁剪出224*224的正方形區域。

ToTensor

功能:將PIL對象或nd.array物件轉換成tensor,並且對數值縮放到[0, 1]之間,並且對通道進行右移。具體地,來看原始程式碼 ...\Lib\site-packages\torchvision\transforms\functional.py 下的to_tensor函數

···python

img = img.permute((2, 0, 1)).contiguous()

if isinstance(img, torch.ByteTensor):

​ return img.to(dtype=default_float_dtype).div(255)

PIL物件的通道進行右移,由原來的(H x W x C)變為了(C x H x W) 接著對數值進行除以255,若是正常的圖像圖元,那麼數值被縮放到了[0, 1]之間。

Normalize

Normalize(mean, std, inplace=False)

功能:對tensor物件進行逐通道的標準化,具體操作為減均值再除以標準差,一般使用imagenet128萬資料R\G\B三通道統計得到的meanstdmean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]。相信大家今後再看到這一組資料就明白它們到底怎麼來的了。

FiveCrop&TenCrop

這兩個方法是AlexNet論文中提及,是一個漲點神器,具體使用方式是一張圖片經過多區域裁剪得到5/10張圖片,同時放到模型進行推理,得到5/10個概率向量,然後取它們的平均/最大/最小得到這一張圖片的概率。

FiveCrop表示對圖片進行上下左右以及中心裁剪,獲得 5 張圖片,並返回一個list,這導致我們需要額外處理它們,使得他們符合其它transforms方法的形式——3D-tensor

TenCrop同理,在FiveCrop的基礎上增加水準鏡像,獲得 10 張圖片,並返回一個 list

它們的使用與普通的transforms有一點區別,需要代碼層面的一些改變,下面就通過具體例子講解它們的注意事項。

代碼

授人以漁:其餘的二十多個不在一一介紹,只需要到官方文檔上查看,並到配套代碼中運行,觀察效果即可。

特殊方法使用

PyTorch 不僅可設置對資料的操作,還可以對這些操作進行隨機選擇、組合,讓資料增強更加靈活。

具體有以下4個方法:

  • Lambda
  • RandomChoice
  • RandomOrder
  • RandomApply

Lambda

功能:可進行自訂的操作,例如上文的FiveCrop中利用lambda很好的處理了上下游transforms資料維度不一致的問題。transforms.Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops]))

RandomChoice

功能:以一定的概率從中選擇一個變換方法執行。

RandomOrder

功能:隨機打亂一串變換方法。

RandomApply

功能:以一定的概率執行這一串變換方法。這與RandomChoice的區別僅在於它將一組變換看成一個選擇單位,RandomChoice是一次選一個,RandomApply是一次選一組(list

具體使用可配合配套代碼

自動資料增強

transforms豐富的變換方法以及靈活的組合函數可以知道,資料增強的策略可以千變萬化,怎樣的策略會更好?Google Brain團隊就針對這個問題,利用它們的鈔能力進行研究,採用RNN網路自動搜索組合策略,尋找較好的資料增強策略,詳細可以看這篇文章AutoAugment: Learning Augmentation Strategies from Data。文章中利用RNN搜索出來的策略,可以在ImagenetCifar-10SVHN三個資料集上達到當時的SOTApytorch中也提供了基於AutoAugment論文的三個資料集的自動資料增強策略,下面一起來學習它們。

AutoAugmentPolicy

通過論文AutoAugment: Learning Augmentation Strategies from Data我們知道它研究出了針對三個資料集的資料增強策略,在pytorch中同樣的提供對應的策略,並設計了AutoAugmentPolicy來指示,直接看原始程式碼,一目了然envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\transforms\autoaugment.py

class AutoAugmentPolicy(Enum):

    """AutoAugment policies learned on different datasets.

    Available policies are IMAGENET, CIFAR10 and SVHN.

    """

    IMAGENET = "imagenet"

    CIFAR10 = "cifar10"

    SVHN = "svhn"

Copy

AutoAugment

torchvision.transforms.AutoAugment(policy: torchvision.transforms.autoaugment.AutoAugmentPolicy = , interpolation: torchvision.transforms.functional.InterpolationMode = , fill: Optional[List[float]] = None)

功能:自動資料增強方法的封裝,支援三種資料增強策略,分別是IMAGENETCIFAR10 SVHN

參數:

policy :需要是AutoAugmentPolicy

interpolation:設置插值方法

fill :設置填充圖元的圖元值,預設為0,黑色。

AutoAugment也是一個Module類,具體的變換操作在forward()函數中體現,建議大家看看原始程式碼,pytorch_1.10_gpu\Lib\site-packages\torchvision\transforms\autoaugment.py

裡面有詳細的三組資料增強策略的順序與參數

例如ImageNet的資料增強策略總共有25組變換,共50個變換:

            return [

                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),

                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),

                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),

                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),

                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),

                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),

                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),

                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),

                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),

                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),

                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),

                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),

                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),

                (("Invert", 0.6, None), ("Equalize", 1.0, None)),

                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),

                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),

                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),

                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),

                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),

                (("Color", 0.4, 0), ("Equalize", 0.6, None)),

                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),

                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),

                (("Invert", 0.6, None), ("Equalize", 1.0, None)),

                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),

                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),

            ]

Copy

特別說明:這裡反復提到的自動資料增強在實際應用中它們是固定的一組變換策略,這是獲得這一群組原則的過程是通過強化學習自動搜素的,所以稱之為自動資料增強策略。

RandAugment

RandAugment是進行N次(num_ops )變換,變換方法從策略池中隨機挑選。pytorch官方文檔對於RandAugment給了較高的評價——“RandAugment is a simple high-performing Data Augmentation technique which improves the accuracy of Image Classification models.”

參數:

num_ops :執行多少次變換

magnitude :每個變換的強度,

num_magnitude_bins:與變化強度的採樣分佈有關

如果對autoaugmentation不熟悉的話,理解RandAugment的參數可能有點困難,這裡結合代碼看一看就知道了。

RandAugment仍舊是一個Module類,來看它的forward()

    def forward(self, img: Tensor) -> Tensor:

        """

            img (PIL Image or Tensor): Image to be transformed.

 

        Returns:

            PIL Image or Tensor: Transformed image.

        """

        fill = self.fill

        if isinstance(img, Tensor):

            if isinstance(fill, (int, float)):

                fill = [float(fill)] * F.get_image_num_channels(img)

            elif fill is not None:

                fill = [float(f) for f in fill]

 

        for _ in range(self.num_ops):

            op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))

            op_index = int(torch.randint(len(op_meta), (1,)).item())

            op_name = list(op_meta.keys())[op_index]

            magnitudes, signed = op_meta[op_name]

            magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0

            if signed and torch.randint(2, (1,)):

                magnitude *= -1.0

            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

 

        return img

Copy

前面的程式碼片段主要是根據規則獲取需要進行的變換方法名稱:op_name;變換的強度:magnitude,從

op_index = int(torch.randint(len(op_meta), (1,)).item())

op_name = list(op_meta.keys())[op_index]

Copy

這兩行代碼可以看到,每次採用的變換是隨機的選擇。

而變換強度magnitude則是根據一個區間裡選擇,不同變換方法的強度區間在這裡:

    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:

        return {

            # op_name: (magnitudes, signed)

            "Identity": (torch.tensor(0.0), False),

            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),

            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),

            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),

            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),

            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),

            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),

            "Color": (torch.linspace(0.0, 0.9, num_bins), True),

            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),

            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),

            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),

            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),

            "AutoContrast": (torch.tensor(0.0), False),

            "Equalize": (torch.tensor(0.0), False),

        }

Copy

TrivialAugmentWide

TrivialAugment是採用NAS技術搜索得到的一組資料增強策略,推薦閱讀原文TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation

使用方法也非常簡單,直接看代碼即可。

想瞭解細節,請查看D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\transforms\autoaugment.py

TrivialAugment核心

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:

        return {

            # op_name: (magnitudes, signed)

            "Identity": (torch.tensor(0.0), False),

            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),

            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),

            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),

            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),

            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),

            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),

            "Color": (torch.linspace(0.0, 0.99, num_bins), True),

            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),

            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),

            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),

            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),

            "AutoContrast": (torch.tensor(0.0), False),

            "Equalize": (torch.tensor(0.0), False),

        }

Copy

小結

本小節詳細剖析transforms運行機制,熟悉內部工作原理,大家可自行編寫變換方法嵌入模型訓練中。同時教授大家學習使用transforms的二十多種方法的方法——授人以漁,最後探討了自動資料增強策略的原理及代碼實踐。

希望大家利用好資料增強,給自己的模型提升性能,切記資料增強的方向是朝著測試集(真實應用場景情況下)的資料分佈、資料情況去變換,切勿盲目應用。

本章節介紹albumentations,但由於本章未涉及圖像分割、目標檢測,以及本章內容也不少了,因此將albumentations放到後續章節,適時進行講解。

預告:原計劃在本章節介紹albumentations,但由於本章未涉及圖像分割、目標檢測,以及本章內容也不少了,因此將albumentations放到後續章節,適時進行講解。

為什麼要用albumentationspytorchtransforms有什麼不足呢? 當然有不足的, pytorchtransforms在處理圖像分割與目標檢測這一類需要圖像與標籤同時變換的時候不太方便。

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 HCHUNGW 的頭像
    HCHUNGW

    HCHUNGW的部落格

    HCHUNGW 發表在 痞客邦 留言(0) 人氣()