8.6 Diffusion Model——DDPM
前言
2020年,DDPM橫空出世,將擴散模型概念迅速擴散到深度學習各個領域,並在2022年隨著Stable Diffusion的提出及開源,該技術實現了破圈,走入了大眾的視野,而不再是科技工作者才瞭解的概念。
為此,對擴散模型原理進行簡要介紹,並通過代碼實現DDPM模型,同時介紹Stable Diffusion 模型背後的原理。
本文主要內容包括
- Diffusion Model 概念介紹
- DDPM 模型原理及代碼實現,訓練、推理
- Guided Diffusion:引導條件的擴散模型介紹,包括classifier-base 和 classifier-free 兩大主流模型
- Stable Diffusion:讓技術出圈的模型
- Latent Diffusion Model(LDM):Stable Diffusion背後的核心技術
Diffusion Model 簡介
擴散模型(Diffusion Model)發展至今已成為一個大的概念、思想,擴散是借鑒物理學中的擴散過程(Diffusion Process)概念。
物理學中擴散是一種物質分子在非均勻環境中的運動,物質從高濃度區域向低濃度區域傳輸,最終實現濃度均衡。
在深度學習中,則是將雜訊加入到原始圖像中進行擴散,最終使圖片變為雜訊,然後利用深度學習模型學習從雜訊變到圖像的過程,最後可以隨機生成雜訊,並利用模型將雜訊生成圖片的過程。
深度學習中擴散的概念自2015就有了,並在2019年發表於論文《Generative Modeling by Estimating Gradients of the Data Distribution》,
最終在2020年的《Denoising Diffusion Probabilistic Models》中被大眾熟知,隨後就開啟了擴散模型的學術界擴散,如DALL-E 2,imagen, Stable Diffusion等強大應用。
DDPM 實現雜訊到圖片步驟
此處借鑒李宏毅教授2023年春季期ML課程中課件進行講解。
DDPM模型推理過程是將一個標準正態分佈中採樣的雜訊圖片(與原圖同尺寸),經過T步(1000步)的去噪(Denoising),生成高品質圖像的過程。
DDPM模型推理過程,可以看似將雜訊逐步的去除,先獲得圖像大體輪廓,逐步精雕細琢,獲得清晰的圖像。
這就像雕像製作過程,工匠常說:“雕像本身就在石頭裡,我只是把多餘的部分剔除掉”,雕刻雕像的過程就像雜訊變到高品質圖像的過程,一開始它們都是可以生成“萬物”的本源。
如何對Denoise模組進行數學建模,使得雜訊逐步變清晰?
可以這麼做,設計一個神經網路,它接收雜訊圖以及當前步數,輸出一個雜訊,然後與原圖相減,獲得更清晰的圖片。
如何訓練這樣的神經網路模型?訓練資料如何構建?
前面提到雜訊如何變圖像,現在反過來看看,圖片如何變雜訊的過程。
對於原圖,經過T步的逐漸加高斯雜訊,使圖像逐步模糊,最終趨近于標準高斯分佈。
這其中就可以構建Noise Predicter的訓練資料,例如藍色框中為輸入,紅色框雜訊則是反向過程時期望預測的標籤。
對於具體模型,DDPM中採用了Unet架構的神經網路實現資料預測。
到這裡,DDPM實現從雜訊生成圖像的步驟就清晰了:
- 前向過程:將原圖逐步添加雜訊, 共1000步
- 反向過程:利用神經網路學習加噪圖像到雜訊的變換,使得模型可以去噪
- 推理使用:隨機採樣,得到高斯雜訊,然後逐步去噪,經過1000步去噪,得到清晰圖像。
DDPM 公式理解
根據上述步驟,可以將DDPM訓練、推理過程採用數學形式表達,如下圖所示
訓練過程:
- q(x0) 表示原始圖像資料集(分佈),x0表示一張原始圖像
- t 看成是1到1000的均勻分佈採樣
- ε 表示從標準正態分佈中採樣得到的雜訊圖像
- εθ 表示模型需要學習到的雜訊圖像,該圖像是利用unet生成的,unet接收上一步去噪圖與當前步數t,預測出一個雜訊圖像,並且期望它與高斯雜訊越接近越好。即ε - εθ 趨於0。
- αt_bar:均值係數,可由重參數方法訓練而來,或是固定值。固定值如0.0001 到0.02線性插值。
推理過程:
- xT:從正態分佈中隨機採樣的雜訊
- z:從正態分佈中隨機採樣的雜訊
- xt-1:主要是由Xt減去模型生成的雜訊圖像,並且以一定的權重加權後,加上標準差乘以隨機雜訊。至於原因需要看原文及進行公式推導理解了
更多公式推導,推薦閱讀
DDPM 模型結構
下面通過代碼進行DDPM模型結構的剖析,官方代碼為TF版,在這裡採用非官方的PyTorch版。
論文採用TPU v3-8(相當於8張V100 GPU),在cifar10上花了10.6小時,由此可見,要想在256x256的圖片上訓練,會非常耗時。
為了快速使用DDPM,這裡採用cifar10進行學習。
通過代碼分析,DDPM模型結構如下圖所示,是在unet結構上進行了一些改進,包括加入時間步t的embedding,卷積中採用了ResBlock,並且採用了Self-Attention機制。
如圖例所示,模型整體有7個元件,彩色箭頭是一組操作,包含2-3個網路層的堆疊,通常最後一個網路層才會改變圖像解析度。
第一個,時間步的embedding,它會輸入到除了head,tail的其它網路層當中,並且是add的形式添加的(h += self.temb_proj(temb)[:, :, None, None])
第二個,head模組,是一個3x3卷積,主要是為了改變通道,沒有特殊的地方。
第三個,down block,是下採樣的核心,一個block由2個ResBlock與一個下採樣層構成。ResBlock內部如圖所示:
第四個,middle block,由兩個ResBlock構成
self.middleblocks = nn.ModuleList([
ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
ResBlock(now_ch, now_ch, tdim, dropout, attn=False),])
Copy
第五個,Up block,由3個ResBlock+1個上採樣層,
第六個,tail,由GN + swish + conv構成,輸出最終圖像
self.tail = nn.Sequential(
nn.GroupNorm(32, now_ch),
Swish(),
nn.Conv2d(now_ch, 3, 3, stride=1, padding=1))
Copy
第七個,concat,是unet的low-level特徵融合到high-level特徵當中
總的來說,特色在於時間步t的embedding是加入了每一個ResBlock中進行使用,並且ResBlock採用了self-attention機制。
DDPM——訓練Cifar-10
接下來使用配套代碼中的Main.py進行訓練,並且使用Main.py進行推理,訓練和推理需要調整"state": "eval"。
- 數據準備:在Main.py同級目錄下創建cifar資料夾,並且將cifar-10-python.tar.gz放到裡邊。
- 運行訓練:python Main.py,1080ti上訓練,約16小時,訓練完畢,在Checkpoints文件下有ckpt_last_.pt。
下面觀察cifar-10的訓練細節,在配套代碼中可以看到,資料採用torchvision提供的cifar10 dataset介面,模型為Unet,優化器為AdamW,學習率為warmup+consine。
在主迴圈中,只用了images,labels是沒有使用到的。
模型的反覆運算,封裝在了GaussianDiffusionTrainer類的forward函數,這也是核心代碼之一,下面詳細看看forward函數。
第5行:進行時間步的採樣,即為每一個樣本配一個時間步,並不需要為每個樣本採樣1000個時間步進行訓練,這是因為公式推導的時候,xt可以用x0直接表示的,不需要依賴xt-1。
第6行:對需要加入的雜訊進行採樣,這裡為標準正態分佈
第11行:根據公式計算得到x_t,x_t由x0與noise加權得到,細節可看公式理解部分,這裡的兩個權重之和不等於一,但是接近1。
第14行:模型接收x_t與t,去預測雜訊圖像,並且通過mse_loss進行損失計算。
def forward(self, x_0):
"""
Algorithm 1.
"""
t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device) # 不需要1000步都訓練,隨機batchsize個
noise = torch.randn_like(x_0) # 標準正態分佈
# 基於x0,獲得xt, 隨後得到訓練資料[(xt, t, noise), ]
# x_t.shape [bs, 3, 32, 32]
# noise.shape [bs, 3, 32, 32]
# t.shape (bs,)
x_t = (extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
return loss
Copy
代碼與上文的原理介紹是一致的,時間步step2與加噪後的圖像是輸入資料,標籤ground truth是noise。
DDPM——推理Cifar-10
訓練完畢後,先看看推理效果。
首先在Main.py中修改 "state": "eval", # train or eval,然後運行python Main.py,即可在"sampled_dir": "./SampledImgs/ 資料夾下獲得如下圖片,看上去還像個樣子,畢竟資料量、算力、時間擺在這裡。
這裡提供一個訓練好的模型參數,ckptlast.pt,下載後放到Checkpoints資料夾。連結:https://pan.baidu.com/s/17X_L9oH4lmrGwnD-V9D5HQ 提取碼:w4ki
推理過程的代碼理解相對有一點繞,主要還是參照論文中的sampling過程,如紅框所示,首先獲得均值(為什麼叫均值?可能要去推一下公式了),然後加上時間步對應的標準差乘以隨機雜訊。
其中,均值主要是由Xt減去模型生成的雜訊圖像,並且以一定的權重加權後得到。
核心代碼涉及3個函數,
- forward()為主函數。
- p_mean_variance()為調用Unet模型,獲得mean和var。
- predict_xt_prev_mean_from_eps()是進行上圖中紅色框運算的過程。
第26行,獲得模型預測的雜訊圖片
第27行,獲得mean,即上圖中的紅色框
第15行,加上標準差乘以隨機雜訊,獲得t時刻的輸出,反復反覆運算1000次,得到最終輸出圖像。
def forward(self, x_T):
"""
Algorithm 2.
"""
x_t = x_T
for time_step in reversed(range(self.T)):
print(time_step)
t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
mean, var = self.p_mean_variance(x_t=x_t, t=t) # mean是 xt圖 減去 雜訊圖
# no noise when t == 0
if time_step > 0:
noise = torch.randn_like(x_t)
else:
noise = 0
x_t = mean + torch.sqrt(var) * noise
assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
x_0 = x_t
return torch.clip(x_0, -1, 1)
def p_mean_variance(self, x_t, t):
# below: only log_variance is used in the KL computations
# posterior_var: 由 betas計算得到,betas=[0.0001 to 0.02]
var = torch.cat([self.posterior_var[1:2], self.betas[1:]]) # betas=[0.0001 to 0.02]
var = extract(var, t, x_t.shape)
eps = self.model(x_t, t) # eps是unet輸出的圖像
xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps) # 加權減法,xt圖 減去 雜訊圖
return xt_prev_mean, var
def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
extract(self.coeff1, t, x_t.shape) * x_t -
extract(self.coeff2, t, x_t.shape) * eps
)
Copy
Diffusion Model 拓展 —— Guided Diffusion
guided diffusion是加入了引導資訊,讓生成的圖像變為我們想要的形式,而不是隨機的圖片。
引導式的擴散模型從有無分類器,可以分為兩種,classifier-base和classifier-free,classifier-free由於不需要分類器,引導資訊直接embedding到模型中,所以應用更為廣泛。
classifier-base ——《Diffusion Models Beat GANs on Image Synthesis》
DDPM提出後,其實效果並未驚豔大家,在DDPM發表後的幾個月,《Diffusion Models Beat GANs on Image Synthesis》的發表, 其github,把擴散模型帶入了高潮,因為它效果比GAN更好,並且針對DDPM,引入了classifier-guidance思想,可以在生成時加入條件約束,可控制生成特定類別的圖像。
具體公式詳見原文。在使用時,採用unet估計mean時,需要額外加上分類器的分類結果的梯度,詳見openai的github:https://github.com/openai/guided-diffusion
第4行:均值除了unet的,還需要加入分類器得到的梯度
第7行:分類器推理,計算梯度過程,這裡有個重要參數是args.classifier_scale
# guided_diffusion/gaussian_diffusion.py
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
new_mean = (p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float())
return new_mean
# scripts/classifier_sample.py
def cond_fn(x, t, y=None):
assert y is not None
with th.enable_grad():
x_in = x.detach().requires_grad_(True)
logits = classifier(x_in, t)
log_probs = F.log_softmax(logits, dim=-1)
selected = log_probs[range(len(logits)), y.view(-1)]
return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
Copy
classifier-free —— 《classifier free diffusion guidance》
由於classifier-base需要訓練分類器,並且在推理時有超參數args.classifier_scale的影響,以及引導條件的加入過於單一,沒有辦法通用性的加入各類條件。
為此,穀歌大腦的兩位工程師提出了classifier free的方式,文中將分類資訊通過embedding的方式加入到模型中訓練,這裡類似時間步t的embedding。
訓練時會結合有條件與無條件進行訓練,無條件則將分類標籤embedding全部設置為0,具體細節可參見論文。
由於論文中沒有提供代碼,所以找到的代碼是這個DDPM,其中的condition模式就是classifier-free。
第2行:訓練時,有10%的是無條件的,90%是有條件的
第9行:標籤資訊與時間步一樣,通過embedding匯入模型中,稱為引導資訊。
# DiffusionFreeGuidence/TrainCondition.py
if np.random.rand() < 0.1:
labels = torch.zeros_like(labels).to(device)
# DiffusionFreeGuidence/ModelCondition.py
def forward(self, x, t, labels):
# Timestep embedding
temb = self.time_embedding(t)
cemb = self.cond_embedding(labels)
# Downsampling
h = self.head(x)
hs = [h]
for layer in self.downblocks:
h = layer(h, temb, cemb)
hs.append(h)
...略
Copy
classifier free diffusion是打開了一扇大門,既然類別標籤可以embedding,那麼文本資訊也可以通過該方式注入模型中進行引導,火爆的Stable Diffusion就是這麼做的。
Diffusion Model 拓展 —— Stable Diffusion
Stable Diffusion 是2022年火爆全球的文圖生成(text-to-image)擴散模型,由於它開源,並且效果相當炸裂,因此已經被大多數人使用。
Stable Diffusion 背後的技術是LDM(latent diffusion model),之所以叫Stable Diffusion,或許與其背後的公司由Stability AI有關。
Stable Diffusion 是由CompVis、Stability AI和LAION三家公司共同創建,CompVis提供的技術LDM(latent diffusion model)源自論文《High-Resolution Image Synthesis with Latent Diffusion Models》,對應的github。LAION公司是一家致力於推動人工智慧和資料科學發展的科技公司,其從互聯網上抓取的 58 億「圖像-文本」資料,並開源了 LAION-5B資料集。而Stability AI的貢獻,或許是出錢出力出人吧。
Stable Diffusion 的開原始程式碼: https://github.com/CompVis/stable-diffusion 與 LDM(latent diffusion model)的開原始程式碼:https://github.com/CompVis/latent-diffusion都在CompVis下,代碼幾乎一樣。
下面簡要介紹Stable Diffusion用到的latent diffusion model技術。
LDM之前,擴散模型在圖元域進行擴散與去噪,這樣的計算量過大。因此,考慮將擴散過程放到隱空間(latent space),即將資料經過encoder,來到特徵空間,在特徵空間上進行擴散和去噪。
這樣一來,有以下好處:
- 計算量減小,訓練和推理速度變快
- 可以加入更多引導資訊,例如文本資訊。
LDM論文中有一幅圖很好的解釋了LDM的思想:首先在pixel space,需要有encoder和decoder,在latent space採用了多頭注意力機制,並且除了時間步資訊,加入了conditioning模組,其中的引導資訊可以是文本、圖片、表徵向量等等一切內容,然後為引導資訊配一個embedding模組,就可以將引導資訊加入模型中。
這裡配上李巨集毅老師的結構示意圖,可知道LDM的核心在於2當中,處理的不再是圖元空間,而是一個特徵空間
stable diffusion 的使用與安裝,網上有太多教程,這裡不進行介紹,主要瞭解LDM的架構。推薦閱讀:文生圖模型之Stable Diffusion
與Stable Diffusion同一時期,叫得上名稱的文圖生成模型還有Midjourney、DALL-E 2,不過它們都是不開源的。
小結
本案例借助DDPM的代碼剖析,瞭解擴散模型實現去噪,從而生成圖像的過程和原理,並且對Guided Diffusion Model進行介紹,模型要能根據我們的“指示”生成特定的圖像,這樣的模型才有更大的應用價值。
在Guided Diffusion Model中,包含classifier-base 和 classifier-free,classifier-free是後來的主流。
classifier-free的代表即出圈的Stable Diffusion,Stable Diffusion是完全開源的,因此得到了全球的使用與關注。
在擴散模型中,LDM(latent diffusion model)思想值得仔細研究,它將一切資訊都放到隱空間(特徵空間)進行處理,使得圖片處理起來更小,還可以進行多模態處理。