close

8.6 Diffusion Model——DDPM

前言

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

2020年,DDPM橫空出世,將擴散模型概念迅速擴散到深度學習各個領域,並在2022年隨著Stable Diffusion的提出及開源,該技術實現了破圈,走入了大眾的視野,而不再是科技工作者才瞭解的概念。

為此,對擴散模型原理進行簡要介紹,並通過代碼實現DDPM模型,同時介紹Stable Diffusion 模型背後的原理。

本文主要內容包括

  1. Diffusion Model 概念介紹
  2. DDPM 模型原理及代碼實現,訓練、推理
  3. Guided Diffusion:引導條件的擴散模型介紹,包括classifier-base classifier-free 兩大主流模型
  4. Stable Diffusion:讓技術出圈的模型
  5. Latent Diffusion ModelLDM):Stable Diffusion背後的核心技術

Diffusion Model 簡介

擴散模型(Diffusion Model)發展至今已成為一個大的概念、思想,擴散是借鑒物理學中的擴散過程(Diffusion Process)概念。

物理學中擴散是一種物質分子在非均勻環境中的運動,物質從高濃度區域向低濃度區域傳輸,最終實現濃度均衡。

在深度學習中,則是將雜訊加入到原始圖像中進行擴散,最終使圖片變為雜訊,然後利用深度學習模型學習從雜訊變到圖像的過程,最後可以隨機生成雜訊,並利用模型將雜訊生成圖片的過程。

深度學習中擴散的概念自2015就有了,並在2019年發表於論文《Generative Modeling by Estimating Gradients of the Data Distribution》,

最終在2020年的《Denoising Diffusion Probabilistic Models》中被大眾熟知,隨後就開啟了擴散模型的學術界擴散,如DALL-E 2imagen Stable Diffusion等強大應用。

DDPM 實現雜訊到圖片步驟

此處借鑒李宏毅教授2023年春季期ML課程中課件進行講解。

DDPM模型推理過程是將一個標準正態分佈中採樣的雜訊圖片(與原圖同尺寸),經過T步(1000步)的去噪(Denoising),生成高品質圖像的過程。

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

DDPM模型推理過程,可以看似將雜訊逐步的去除,先獲得圖像大體輪廓,逐步精雕細琢,獲得清晰的圖像。

這就像雕像製作過程,工匠常說:雕像本身就在石頭裡,我只是把多餘的部分剔除掉,雕刻雕像的過程就像雜訊變到高品質圖像的過程,一開始它們都是可以生成萬物的本源。

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

如何對Denoise模組進行數學建模,使得雜訊逐步變清晰?

可以這麼做,設計一個神經網路,它接收雜訊圖以及當前步數,輸出一個雜訊,然後與原圖相減,獲得更清晰的圖片。

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

如何訓練這樣的神經網路模型?訓練資料如何構建?

前面提到雜訊如何變圖像,現在反過來看看,圖片如何變雜訊的過程

對於原圖,經過T步的逐漸加高斯雜訊,使圖像逐步模糊,最終趨近于標準高斯分佈。

這其中就可以構建Noise Predicter的訓練資料,例如藍色框中為輸入,紅色框雜訊則是反向過程時期望預測的標籤。

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

對於具體模型,DDPM中採用了Unet架構的神經網路實現資料預測。

到這裡,DDPM實現從雜訊生成圖像的步驟就清晰了:

  1. 前向過程:將原圖逐步添加雜訊, 1000
  2. 反向過程:利用神經網路學習加噪圖像到雜訊的變換,使得模型可以去噪
  3. 推理使用:隨機採樣,得到高斯雜訊,然後逐步去噪,經過1000步去噪,得到清晰圖像。

DDPM 公式理解

根據上述步驟,可以將DDPM訓練、推理過程採用數學形式表達,如下圖所示

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

訓練過程:

  • q(x0) 表示原始圖像資料集(分佈),x0表示一張原始圖像
  • t 看成是11000的均勻分佈採樣
  • ε 表示從標準正態分佈中採樣得到的雜訊圖像
  • εθ 表示模型需要學習到的雜訊圖像,該圖像是利用unet生成的,unet接收上一步去噪圖與當前步數t,預測出一個雜訊圖像,並且期望它與高斯雜訊越接近越好。即ε - εθ 趨於0
  • αt_bar:均值係數,可由重參數方法訓練而來,或是固定值。固定值如0.0001 0.02線性插值。

推理過程:

  • xT:從正態分佈中隨機採樣的雜訊
  • z:從正態分佈中隨機採樣的雜訊
  • xt-1:主要是由Xt減去模型生成的雜訊圖像,並且以一定的權重加權後,加上標準差乘以隨機雜訊。至於原因需要看原文及進行公式推導理解了

更多公式推導,推薦閱讀

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

DDPM 模型結構

下面通過代碼進行DDPM模型結構的剖析,官方代碼為TF,在這裡採用非官方的PyTorch

論文採用TPU v3-8(相當於8V100 GPU),在cifar10上花了10.6小時,由此可見,要想在256x256的圖片上訓練,會非常耗時。

為了快速使用DDPM,這裡採用cifar10進行學習。

通過代碼分析,DDPM模型結構如下圖所示,是在unet結構上進行了一些改進,包括加入時間步tembedding,卷積中採用了ResBlock,並且採用了Self-Attention機制。

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

如圖例所示,模型整體有7個元件,彩色箭頭是一組操作,包含2-3個網路層的堆疊,通常最後一個網路層才會改變圖像解析度。

第一個,時間步的embedding,它會輸入到除了headtail的其它網路層當中,並且是add的形式添加的(h += self.temb_proj(temb)[:, :, None, None]

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

第二個,head模組,是一個3x3卷積,主要是為了改變通道,沒有特殊的地方。

第三個,down block,是下採樣的核心,一個block2ResBlock與一個下採樣層構成。ResBlock內部如圖所示:

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

第四個,middle block,由兩個ResBlock構成

self.middleblocks = nn.ModuleList([

    ResBlock(now_ch, now_ch, tdim, dropout, attn=True),

    ResBlock(now_ch, now_ch, tdim, dropout, attn=False),])

Copy

第五個,Up block,由3ResBlock+1個上採樣層,

第六個,tail,由GN + swish + conv構成,輸出最終圖像

self.tail = nn.Sequential(

    nn.GroupNorm(32, now_ch),

    Swish(),

    nn.Conv2d(now_ch, 3, 3, stride=1, padding=1))

Copy

第七個,concat,是unetlow-level特徵融合到high-level特徵當中

總的來說,特色在於時間步tembedding是加入了每一個ResBlock中進行使用,並且ResBlock採用了self-attention機制。

DDPM——訓練Cifar-10

接下來使用配套代碼中的Main.py進行訓練,並且使用Main.py進行推理,訓練和推理需要調整"state": "eval"

  • 數據準備:在Main.py同級目錄下創建cifar資料夾,並且將cifar-10-python.tar.gz放到裡邊。
  • 運行訓練:python Main.py1080ti上訓練,約16小時,訓練完畢,在Checkpoints文件下有ckpt_last_.pt

下面觀察cifar-10的訓練細節,在配套代碼中可以看到,資料採用torchvision提供的cifar10 dataset介面,模型為Unet,優化器為AdamW,學習率為warmup+consine

在主迴圈中,只用了imageslabels是沒有使用到的

模型的反覆運算,封裝在了GaussianDiffusionTrainer類的forward函數,這也是核心代碼之一,下面詳細看看forward函數。

5行:進行時間步的採樣,即為每一個樣本配一個時間步,並不需要為每個樣本採樣1000個時間步進行訓練,這是因為公式推導的時候,xt可以用x0直接表示的,不需要依賴xt-1

6行:對需要加入的雜訊進行採樣,這裡為標準正態分佈

11行:根據公式計算得到x_tx_tx0noise加權得到,細節可看公式理解部分,這裡的兩個權重之和不等於一,但是接近1

14行:模型接收x_tt,去預測雜訊圖像,並且通過mse_loss進行損失計算。

def forward(self, x_0):

    """

    Algorithm 1.

    """

    t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)  # 不需要1000步都訓練,隨機batchsize

    noise = torch.randn_like(x_0)  # 標準正態分佈

    # 基於x0,獲得xt 隨後得到訓練資料[(xt, t, noise), ]

    # x_t.shape     [bs, 3, 32, 32]

    # noise.shape   [bs, 3, 32, 32]

    # t.shape       (bs,)

    x_t = (extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +

           extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)

 

    loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')

    return loss

Copy

代碼與上文的原理介紹是一致的,時間步step2與加噪後的圖像是輸入資料,標籤ground truthnoise

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

DDPM——推理Cifar-10

訓練完畢後,先看看推理效果。

首先在Main.py中修改 "state": "eval", # train or eval,然後運行python Main.py,即可在"sampled_dir": "./SampledImgs/ 資料夾下獲得如下圖片,看上去還像個樣子,畢竟資料量、算力、時間擺在這裡。

這裡提供一個訓練好的模型參數,ckptlast.pt,下載後放到Checkpoints資料夾。連結:https://pan.baidu.com/s/17X_L9oH4lmrGwnD-V9D5HQ 提取碼:w4ki

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

推理過程的代碼理解相對有一點繞,主要還是參照論文中的sampling過程,如紅框所示,首先獲得均值(為什麼叫均值?可能要去推一下公式了),然後加上時間步對應的標準差乘以隨機雜訊。

其中,均值主要是由Xt減去模型生成的雜訊圖像,並且以一定的權重加權後得到。

<<AI人工智慧 PyTorch自學>> 8.6 Diffu

核心代碼涉及3個函數,

  1. forward()為主函數。
  2. p_mean_variance()為調用Unet模型,獲得meanvar
  3. predict_xt_prev_mean_from_eps()是進行上圖中紅色框運算的過程。

26行,獲得模型預測的雜訊圖片

27行,獲得mean,即上圖中的紅色框

15行,加上標準差乘以隨機雜訊,獲得t時刻的輸出,反復反覆運算1000次,得到最終輸出圖像。

def forward(self, x_T):

    """

    Algorithm 2.

    """

    x_t = x_T

    for time_step in reversed(range(self.T)):

        print(time_step)

        t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step

        mean, var = self.p_mean_variance(x_t=x_t, t=t)  # mean xt 減去 雜訊圖

        # no noise when t == 0

        if time_step > 0:

            noise = torch.randn_like(x_t)

        else:

            noise = 0

        x_t = mean + torch.sqrt(var) * noise

        assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."

    x_0 = x_t

    return torch.clip(x_0, -1, 1)  

 

def p_mean_variance(self, x_t, t):

    # below: only log_variance is used in the KL computations

    # posterior_var: betas計算得到,betas=[0.0001 to 0.02]

    var = torch.cat([self.posterior_var[1:2], self.betas[1:]])  # betas=[0.0001 to 0.02]

    var = extract(var, t, x_t.shape)

 

    eps = self.model(x_t, t)  # epsunet輸出的圖像

    xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)  # 加權減法,xt 減去 雜訊圖

 

    return xt_prev_mean, var

 

def predict_xt_prev_mean_from_eps(self, x_t, t, eps):

    assert x_t.shape == eps.shape

    return (

        extract(self.coeff1, t, x_t.shape) * x_t -

        extract(self.coeff2, t, x_t.shape) * eps

    )

Copy

Diffusion Model 拓展 —— Guided Diffusion

guided diffusion是加入了引導資訊,讓生成的圖像變為我們想要的形式,而不是隨機的圖片。

引導式的擴散模型從有無分類器,可以分為兩種,classifier-baseclassifier-freeclassifier-free由於不需要分類器,引導資訊直接embedding到模型中,所以應用更為廣泛。

classifier-base ——Diffusion Models Beat GANs on Image Synthesis

DDPM提出後,其實效果並未驚豔大家,在DDPM發表後的幾個月,Diffusion Models Beat GANs on Image Synthesis的發表, github,把擴散模型帶入了高潮,因為它效果比GAN更好,並且針對DDPM,引入了classifier-guidance思想,可以在生成時加入條件約束,可控制生成特定類別的圖像。

具體公式詳見原文。在使用時,採用unet估計mean時,需要額外加上分類器的分類結果的梯度,詳見openaigithub:https://github.com/openai/guided-diffusion

4行:均值除了unet的,還需要加入分類器得到的梯度

7行:分類器推理,計算梯度過程,這裡有個重要參數是args.classifier_scale

# guided_diffusion/gaussian_diffusion.py

def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):

    gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)

    new_mean = (p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float())

    return new_mean

# scripts/classifier_sample.py

def cond_fn(x, t, y=None):

    assert y is not None

    with th.enable_grad():

        x_in = x.detach().requires_grad_(True)

        logits = classifier(x_in, t)

        log_probs = F.log_softmax(logits, dim=-1)

        selected = log_probs[range(len(logits)), y.view(-1)]

        return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

Copy

classifier-free —— classifier free diffusion guidance

由於classifier-base需要訓練分類器,並且在推理時有超參數args.classifier_scale的影響,以及引導條件的加入過於單一,沒有辦法通用性的加入各類條件。

為此,穀歌大腦的兩位工程師提出了classifier free的方式,文中將分類資訊通過embedding的方式加入到模型中訓練,這裡類似時間步tembedding

訓練時會結合有條件與無條件進行訓練,無條件則將分類標籤embedding全部設置為0,具體細節可參見論文。

由於論文中沒有提供代碼,所以找到的代碼是這個DDPM,其中的condition模式就是classifier-free

2行:訓練時,有10%的是無條件的,90%是有條件的

9行:標籤資訊與時間步一樣,通過embedding匯入模型中,稱為引導資訊

# DiffusionFreeGuidence/TrainCondition.py

if np.random.rand() < 0.1:

    labels = torch.zeros_like(labels).to(device)

 

# DiffusionFreeGuidence/ModelCondition.py

def forward(self, x, t, labels):

    # Timestep embedding

    temb = self.time_embedding(t)

    cemb = self.cond_embedding(labels)

    # Downsampling

    h = self.head(x)

    hs = [h]

    for layer in self.downblocks:

        h = layer(h, temb, cemb)

        hs.append(h)

    ...

Copy

classifier free diffusion是打開了一扇大門,既然類別標籤可以embedding,那麼文本資訊也可以通過該方式注入模型中進行引導,火爆的Stable Diffusion就是這麼做的。

Diffusion Model 拓展 —— Stable Diffusion

Stable Diffusion 2022年火爆全球的文圖生成(text-to-image)擴散模型,由於它開源,並且效果相當炸裂,因此已經被大多數人使用。

Stable Diffusion 背後的技術是LDMlatent diffusion model,之所以叫Stable Diffusion,或許與其背後的公司由Stability AI有關。

Stable Diffusion 是由CompVisStability AILAION三家公司共同創建,CompVis提供的技術LDMlatent diffusion model)源自論文High-Resolution Image Synthesis with Latent Diffusion Models,對應的githubLAION公司是一家致力於推動人工智慧和資料科學發展的科技公司,其從互聯網上抓取的 58 億「圖像-文本」資料,並開源了 LAION-5B資料集。而Stability AI的貢獻,或許是出錢出力出人吧。

Stable Diffusion 的開原始程式碼https://github.com/CompVis/stable-diffusion  LDMlatent diffusion model)的開原始程式碼:https://github.com/CompVis/latent-diffusion都在CompVis下,代碼幾乎一樣。

下面簡要介紹Stable Diffusion用到的latent diffusion model技術。

LDM之前,擴散模型在圖元域進行擴散與去噪,這樣的計算量過大。因此,考慮將擴散過程放到隱空間(latent space),即將資料經過encoder,來到特徵空間,在特徵空間上進行擴散和去噪。

這樣一來,有以下好處:

  1. 計算量減小,訓練和推理速度變快
  2. 可以加入更多引導資訊,例如文本資訊。

LDM論文中有一幅圖很好的解釋了LDM的思想:首先在pixel space,需要有encoderdecoder,在latent space採用了多頭注意力機制,並且除了時間步資訊,加入了conditioning模組,其中的引導資訊可以是文本、圖片、表徵向量等等一切內容,然後為引導資訊配一個embedding模組,就可以將引導資訊加入模型中。

&lt;&lt;AI人工智慧 PyTorch自學&gt;&gt; 8.6 Diffu

這裡配上李巨集毅老師的結構示意圖,可知道LDM的核心在於2當中,處理的不再是圖元空間,而是一個特徵空間

&lt;&lt;AI人工智慧 PyTorch自學&gt;&gt; 8.6 Diffu

stable diffusion 的使用與安裝,網上有太多教程,這裡不進行介紹,主要瞭解LDM的架構。推薦閱讀:文生圖模型之Stable Diffusion

Stable Diffusion同一時期,叫得上名稱的文圖生成模型還有MidjourneyDALL-E 2,不過它們都是不開源的。

小結

本案例借助DDPM的代碼剖析,瞭解擴散模型實現去噪,從而生成圖像的過程和原理,並且對Guided Diffusion Model進行介紹,模型要能根據我們的指示生成特定的圖像,這樣的模型才有更大的應用價值。

Guided Diffusion Model中,包含classifier-base classifier-freeclassifier-free是後來的主流。

classifier-free的代表即出圈的Stable DiffusionStable Diffusion是完全開源的,因此得到了全球的使用與關注。

在擴散模型中,LDMlatent diffusion model思想值得仔細研究,它將一切資訊都放到隱空間(特徵空間)進行處理,使得圖片處理起來更小,還可以進行多模態處理。

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 HCHUNGW 的頭像
    HCHUNGW

    HCHUNGW的部落格

    HCHUNGW 發表在 痞客邦 留言(0) 人氣()