close

8.5 生成對抗網路——CycleGAN

簡介

本小節將介紹GAN模型中有趣的模型CycleGAN

CycleGAN是一種雙向迴圈的GAN模型,可實現X域與Y域之間的相互轉換,並且是基於unpaired data(即不需要標注,只需要收集圖片)。相較於此前的pix2pixcyclegan適用性更廣,畢竟unpaired datapaired data更容易獲取。

例如論文中展示的,照片與莫内風格畫之間的互相轉換,斑馬與馬之間的轉換,夏天與冬天之間的轉換。

<<AI人工智慧 PyTorch自學>> 8.5 生成對抗網

本節先介紹GANCycleGAN的結構,再通過代碼詳細介紹CycleGAN的訓練、推理。

GAN簡介

GANGenerative Adversarial Nets,生成對抗網路)由 Ian J Goodfellow2014發表於《Generative Adversarial Nets》,可謂是推開了生成模型的一扇大門。

GAN是一種從隨機雜訊生成特定分佈資料的模型,例如生成人臉資料,手寫體資料,自訂資料集等。

GAN當中有GeneratorDiscriminator兩個模型,G負責學習從雜訊到資料的映射,D負責充當損失函數,判斷G生成得是否足夠好,GD交替訓練,形成對抗,同步提升,最終使得G生成的資料越來越像人臉。

根據模型的結構,GAN模型延伸出一系列變體,如本文要介紹的CycleGAN,還有DCGANConditional GANsPix2PixSRGAN等。

GAN的設計十分巧妙,從神經網路訓練的角度考慮,GAN是將損失函數替換為,神經網路的輸出,具體如下圖所示:

<<AI人工智慧 PyTorch自學>> 8.5 生成對抗網

傳統模型訓練,需要用loss_fun(output, label)得到loss值,然後求梯度優化。

GAN中,巧妙了利用一個判別器模型,D_net D_net(output) 趨向0 D_net(training_data)趨向1,依次獲得loss值。

CycleGAN簡介

CycleGAN是一種無監督學習方法,由Jun-Yan Zhu等人於2017年提出。

它的主要思想是通過兩個生成器和兩個判別器來實現兩個不同域之間的圖像轉換。

與其他的GAN模型不同的是,CycleGAN不需要成對的圖像進行訓練,而是只需要兩個域中的任意數量的圖像即可。

下面介紹CycleGAN模型結構損失函數

CycleGAN模型結構

CycleGAN模型由兩個生成器,兩個判別器構成。

生成器G,將X域圖像變換到Y

生成器F,將Y域圖像變換到X

判別器Dx,判別圖像來自X則為1 圖像來自Y則為0

判別器Dy,判別圖像來自X則為0 圖像來自Y則為1

<<AI人工智慧 PyTorch自學>> 8.5 生成對抗網

生成器採用3層卷積+一系列殘差塊構成;

判別器採用PatchGANs,其特點在於對一張圖片不是輸出一個向量,而是輸出NxNxC的張量,NxN分別對應原圖中的70x70區域,即對一張圖,劃分多個70x70patch,每個patch來判別它是0類,還是1類。

CycleGAN損失

cyclegan最大的特點在於損失函數的設計,除了GAN的常規兩個損失函數之外,論文中增加了cycle consistency loss(迴圈一致性損失),用來避免模式坍塌,以及更好的讓GAN模型生成合理的圖像,同時,在官方代碼中還增加了一項identity loss,用來增加GAN模型對於本域圖像資訊的學習。

因此,整體loss8項,分別是'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'

生成器的損失

loss1 :判別器的輸出接近1

對於G,目標是讓對應的判別器D,認為假圖像是真圖像,即輸入是假圖像,標籤是1,目標是欺騙DG就是訓練好了。

<<AI人工智慧 PyTorch自學>> 8.5 生成對抗網

self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)

 

target_tensor = self.get_target_tensor(prediction, target_is_real)   # 根據self.fake_B,生成對應的標籤,即N*N的標籤,為patchGAN的輸出匹配

loss = self.loss(prediction, target_tensor)  # MSELoss() 而非BCE

Copy

loss2F(G(x)) x一致

除了常規Loss,還有cycle consistency loss(迴圈一致性損失),目的是經過G得到的圖片,返回去再經過F,應當是可以恢復得到X域的圖像x_hut,並且xx_hat應當是逐圖元一模一樣的。

這樣的GF才是合理的。

<<AI人工智慧 PyTorch自學>> 8.5 生成對抗網

self.criterionCycle = torch.nn.L1Loss()

self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A   # lambda為縮放係數

Copy

loss3:恒等映射損失

該損失在代碼中才出現,論文中並沒有提到。恒等映射損失的思想是,生成器G_A接收A域圖像,生成B域圖像;若接收B域圖像,應該生成恒等的B域圖像,即B域圖像一模一樣,不能變。

G_A should be identity if real_B is fed: ||G_A(B) - B||

self.idt_A = self.netG_A(self.real_B)

self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt

 

# self.criterionIdt = torch.nn.L1Loss()

Copy

因此,對於生成器G而言,要求它:

  1. 生成的假圖像,要讓判別器預測為1 D(G(x)) 逼近1
  2. G生成的圖像,再經過F生成的圖像,應當等於原圖,此為迴圈一致性損失
  3. 已經是B域的圖像,經過G_A,應當得到B域的原圖。

判別器的loss

判別器損失較為簡單,對於真圖像,需要預測為1,對於假圖像,需要預測為0

其中,假圖像不是基於當前batch的,而是記錄過往的一批假圖像,從假圖像池中抽取。

fake_B = self.fake_B_pool.query(self.fake_B)

self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

 

 

def backward_D_basic(self, netD, real, fake):

    # Real

    pred_real = netD(real)

    loss_D_real = self.criterionGAN(pred_real, True)

    # Fake

    pred_fake = netD(fake.detach())

    loss_D_fake = self.criterionGAN(pred_fake, False)

    # Combined loss and calculate gradients

    loss_D = (loss_D_real + loss_D_fake) * 0.5

    loss_D.backward()

    return loss_D

Copy

訓練注意事項

論文中超參數:batch size =1epoch =200lr100epoch,固定0.0002,後100epoch,線性下降至0

其它注意事項:

  • 兩個域圖像是否有一致性: 舉個例子:

蘋果 <=> 橘子: 都是球形, OK!

 

蘋果 <=> 香蕉: Mode Collapse!

Copy

  • 訓練CycleGAN要有耐心
  • 學習率別太高
  • 對抗損失權重不要太高,迴圈一致性損失權重為1的時候,對抗損失一般設置為0.1
  • 判別器優化頻率高於生成器
  • 使用最小二乘損失(MSE
  • cycleGANloss不能準確反應訓練的好壞,不代表著訓練進度,甚至不能代表結果優劣。所以還是要輸出樣張看效果,或許可以借鑒WGAN的思想
  • 由於 minimax 優化的性質,許多 GAN 損失不會收斂(例外:WGANWGAN-GP 等)。對於 DCGAN LSGAN 目標,G D 損失上下波動是很正常的。只要不爆炸應該沒問題。

CycleGAN代碼實現

接下來,通過pytorch訓練一個可以將圖片轉換為莫内風格圖像的CycleGANgithub已經19.5K star了,可見深受大家喜愛。

資料集準備

由於不需要標籤,僅需要準備圖像,所以在根目錄下,存放trainA, trainB, testA, testB即可,分別存放A域的圖像,B域的圖像。

這裡下載官方提供的monet2photo資料集,可以通過sh腳本下載,也可以手動下載(推薦)

# 方法一:bash

bash ./datasets/download_cyclegan_dataset.sh monet2photo

 

# 方法二:手動下載

# apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo

http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/$FILE.zip

# 例如莫内資料下載

http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/monet2photo.zip

Copy

數據載入

整個資料模組代碼設計如下圖所示:

&lt;&lt;AI人工智慧 PyTorch自學&gt;&gt; 8.5 生成對抗網

該項目適配pix2pix cyclegan,因此提供了多種dataset,所有的dataset都繼承於BaseDataset,針對cyclegan的是unaligned_dataset.py中的UnalignedDataset

對於dataloader,提供了一個類 CustomDatasetDataLoader,並且實現了反覆運算協議iter,因此"dataloader"是自訂的一個可反覆運算物件。

在主代碼01_train.py中,通過33行代碼:dataloader = create_dataset(opt) ,實現dataloader的創建,所有的配置資訊存放在opt中。


接下來關注UnalignedDataset,它內部實現了transformtransformopt的參數決定

  • 第一步:縮放變換,有resize,或者基於width縮放的方式;默認基於resize_and_crop resize的尺寸是opt.load_size
  • 第二步:crop的尺寸是 opt.crop_size
  • 第三步:Normalize

這份代碼中有一個值得借鑒的是,通過參數配置,來選擇調用具體的類。實現方法是通過,importlib.import_module實現通過字串形式import工具庫。

def find_dataset_using_name(dataset_name):

    """Import the module "data/[dataset_name]_dataset.py".

 

    In the file, the class called DatasetNameDataset() will

    be instantiated. It has to be a subclass of BaseDataset,

    and it is case-insensitive.

    """

    dataset_filename = "data." + dataset_name + "_dataset"

    datasetlib = importlib.import_module(dataset_filename)  # 這裡的dataetlib,等同於一個庫,例如。import cv2cv2, import torchtorch

 

    dataset = None

    target_dataset_name = dataset_name.replace('_', '') + 'dataset'

    for name, cls in datasetlib.__dict__.items():

        if name.lower() == target_dataset_name.lower() \

           and issubclass(cls, BaseDataset):

            dataset = cls

 

    if dataset is None:

        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

 

    return dataset

Copy

模型構建

原始程式碼中將模型(nn.Module), 損失函數,優化器一併放到了CycleGANModel類當中,對外提供set_input()optimize_parameters(),實現前向傳播、損失計算、反向傳播,這樣可以讓主代碼更簡潔。

模型部分代碼設計如下圖所示

&lt;&lt;AI人工智慧 PyTorch自學&gt;&gt; 8.5 生成對抗網

 

對於 netG_A/Bcyclegan中是resnetblock構成的生成器,詳細可見models/networks.pyResnetGenerator類, 主要由resnetblock的下採樣和TransConv的上採樣構成,最後加入tanh()啟動函數。

對於netD_A/B 是一個patchGAN,全部由卷積層構成的全卷積網路,詳見 models/networks.pyNLayerDiscriminator

整個模型構建與優化核心代碼如下:

model = create_model(opt)      # create a model given opt.model and other options

model.setup(opt)               # regular setup: load and print networks; create schedulers

----------------------------------

model.set_input(data)         # unpack data from dataset and apply preprocessing

model.optimize_parameters()   # calculate loss functions, get gradients, update network weights

Copy

模型訓練

資料、模型、損失函數與優化器準備完畢,可以進行反覆運算訓練。

如果在windows下訓練,需要開啟 visdom

python -m visdom.server

Copy

由於此處使用了visdom進行視覺化,需要先行開啟visdom,否則會報錯:

requests.exceptions.ConnectionError: HTTPConnectionPool(host='localhost', port=8097): Max retries exceeded with url: /env/main (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x000001D60D9110D0>: Failed to establish a new connection: [WinError 10061] 由於目的電腦積極拒絕,無法連接。'))

[WinError 10061] 由於目的電腦積極拒絕,無法連接。

Copy

訓練指令:

python 01_train.py --n_epochs 200 --dataroot path/to/your/datasets/monet2photo --name monet2photo_cyclegan --model cycle_gan

Copy

日誌資訊、模型資訊將存儲於checkpoints\monet2photo_cyclegan

訓練結果

cycleGANloss不能準確反應訓練的好壞,不代表著訓練進度,甚至不能代表結果優劣,整體趨勢是,cycle loss逐漸下降,如下圖所示

&lt;&lt;AI人工智慧 PyTorch自學&gt;&gt; 8.5 生成對抗網

推理測試

預訓練模型可以從這裡下載:連結:https://pan.baidu.com/s/1bEPNBbAeqMumpM2pqKwb4w 提取碼:q159

python 02_inference.py --dataroot G:\deep_learning_data\cyclegan\monet2photo\testB --name monet2photo_cyclegan  --model test --no_dropout --model_suffix _B180

Copy

model_suffix 格式說明:模型模型檔案名保存為latest_net_G_A.pthlatest_net_G_B.pth。在代碼中會自動拼接:

"latest_net_G{}.pth".format(model_suffix)

Copy

最後在F:\pytorch-tutorial-2nd\code\chapter-8\cyclegan\results\monet2photo_cyclegan\test_latest下就有對應的結果圖片

這裡展示了120 180 200epoch時的展示效果,感覺180時的效果最好。

&lt;&lt;AI人工智慧 PyTorch自學&gt;&gt; 8.5 生成對抗網

小結

本小結先介紹了GANcycleGAN的模型結構,GAN是一個巧妙的利用神經網路進行損失計算的設計,CycleGAN是巧妙的利用了兩個GAN相互轉換,並提出迴圈一致性loss,最終CycleGAN的損失共8個,分別是'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'

然後介紹CycleGAN原始程式碼使用及設計,其將Dataset, DataLoader, model, loss, optim進行了高度封裝,使主代碼很簡潔。

從此可見,無論多複雜、難理解的pytorch模型訓練代碼,都離不開Dataset, DataLoader, nn.Moduleloss, optim,只要瞭解訓練的步驟,這些複雜的代碼都可以梳理出來。

2014GAN提出以來,往後的5年間提出了各式各樣的GAN變體,也有了非常多有趣的應用,感興趣的朋友可以進一步瞭解。

2020年之前,在圖像生成領域,GAN是當之無愧的主流,但2020年《Denoising Diffusion Probabilistic Models》(Diffusion)提出後,基於擴散模型(diffusion model的圖像生成稱為了學術界的寵兒,包括OpenAI提出的DALL-E系列,stability.ai提出的Stable-Diffusion

下一節將介紹擴散模型(diffusion model及代碼實現

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 HCHUNGW 的頭像
    HCHUNGW

    HCHUNGW的部落格

    HCHUNGW 發表在 痞客邦 留言(0) 人氣()