9.4 機器翻譯—Transformer
前言
本節將介紹當下人工智慧領域的基石與核心結構模型——Transformer,為什麼說它是基石,因為以ChatGPT為代表的聊天機器人以及各種有望通向AGI(通用人工智慧)的道路上均在採用的Transformer。
Transformer也是當下NLP任務的底座,包括後續的BERT和GPT,都是Transformer架構,BERT主要由Transformer的encoder構成,GPT則主要是decoder構成。
本節將會通讀Transformer原論文《Attention is all you need》,總結Transformer結構以及文章要點,然後採用pytorch代碼進行機器翻譯實驗,通過代碼進一步瞭解Transformer在應用過程中的步驟。
論文閱讀筆記
Transformer的論文是《Attention is all you need》(https://arxiv.org/abs/1706.03762),由Google團隊在2017提出的一種針對機器翻譯的模型結構,後續在各類NLP任務上均獲取SOTA。
Motivation:針對RNN存在的計算複雜、無法串列的問題,提出僅通過attention機制的簡潔結構——Transformer,在多個序列任務、機器翻譯中獲得SOTA,並有運用到其他資料模態的潛力。
模型結構
模型由encoder和decoder兩部分構成,分別採用了6個block堆疊,
- encoder的block有兩層,multi-head attention和FC層
- decoder的block有三層,處理自回歸的輸入masked multi-head attention,處理encoder的attention,FC層。
注意力機制:scale dot-production attention,採用QK矩陣乘法縮放後softmax充當權重,再與value進行乘法。
多頭注意力機制:實驗發現多個頭效果好,block的輸出,把多個頭向量進行concat,然後加一個FC層。因此每個頭的向量長度是總長度/頭數,例:512/8=64,每個頭是64維向量。
Transformer的三種注意力層:
-
- encoder:輸入來自上一層的全部輸出
- decoder-輸入:為避免模型未卜先知,只允許看見第i步之前的資訊,需要做mask操作,確保在生成序列的每個元素時,只考慮該元素之前的元素。這裡通過softmax位置設置負無窮來控制無效的連接。
- decoder-第二個attention:q來自上一步輸出,k和v來自encoer的輸出。這樣解碼器可以看到輸入的所有序列。
FFN層:attention之後接入兩個FC層,第一個FC層採用2048,第二個是512,第一個FC層採用max(0, x)作為啟動函數。
embedding層的縮放:在embedding層進行尺度縮放,乘以根號d_model,
位置編碼:採用正余弦函數構建位置向量,然後採用加法,融入embedding中。
實驗:兩個任務,450萬句子(英-德),3600萬句子(英-法),8*16GB顯卡,分別訓練12小時,84小時。10萬step時,採用4kstep預熱;採用標籤平滑0.1.
重點句子摘錄
- In the Transformer this is reduced to a constant number of operations, albeit at the cost of reduced effective resolution due to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as described in section 3.2.
由於注意力機制最後的加權平均可能會序列中各位置的細細微性捕捉不足,因此引入多頭注意力機制。 這裡是官方對多頭注意力機制引入的解釋。
- \2. We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients 4. To counteract this effect, we scale the dot products by 根號dk
在Q和K做點積時,若向量維度過長,會導致點積結果過大,再經softmax映射後,梯度會變小,不利於模型學習,因此需要進行縮放。縮放因數為除以根號dk。
- 多頭注意力機制的三種情況:
-
- encoder:輸入來自上一層的全部輸出
- decoder-輸入:為避免模型未卜先知,只允許看見第i步之前的資訊,需要做mask操作,確保在生成序列的每個元素時,只考慮該元素之前的元素。這裡通過softmax位置設置負無窮來控制無效的連接。
- decoder-第二個attention:q來自上一步輸出,k和v來自encoer的輸出。這樣解碼器可以看到輸入的所有序列。
首次閱讀遺留問題
- 位置編碼是否重新學習?
不重新學習,詳見 PositionalEncoding的 self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
- qkv具體實現過程
- 通過3個w獲得3個特徵向量:
- attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) # q.shape == (bs, n_head, len_seq, d_k/n_head) ,每個token用 一個向量表示,總向量長度是頭數*每個頭的向量長度。
- attn = self.dropout(F.softmax(attn, dim=-1))
- output = torch.matmul(attn, v)
- decoder輸入的處理細節
- 訓練階段,無特殊處理,一個樣本可以直接輸入,根據下三角mask避免未卜先知
- 推理階段,首先手動執行token的輸入,然後for迴圈直至最大長度,期間輸出的token拼接到輸出列表中,並作為下一步decoder的輸入。選擇輸出token時還採用了beam search來有效地平衡廣度和深度。
資料集構建
數據下載
本案例資料集Tatoeba下載自這裡。該專案是説明不同語言的人學習英語,因此是英語與其它幾十種語言的翻譯文本。
其中就包括本案例使用的英中文本,共計29668條(Mandarin Chinese - English cmn-eng.zip (29668))
資料以txt形式存儲,一行是一對翻譯文本,例如長這樣:
1.That mountain is easy to climb. 那座山很容易爬。
2.It's too soon. 太早了。
3.Japan is smaller than Canada. 日本比加拿大小。
資料集劃分
對於29668條資料進行8:2劃分為訓練、驗證,這裡採用配套代碼a_data_split.py進行劃分,即可在統計目錄下獲得train.txt和text.txt。
詞表構建
文本任務首要任務是為文本構建詞表,這裡採用與上節一樣的方法,首先對文本進行分詞,然後統計語料庫中所有的詞,最後根據最大上限、最小詞頻等約束,構建詞表。本部分配套代碼是b_gen_vocabulary.py
詞表的構建過程中,涉及兩個知識點:中文分詞和特殊token。
1. 中文分詞
對於英文,分詞可以直接採用空格。而對於中文,就需要用特定的分詞方法,這裡採用的是jieba分詞工具,以下是英文和中文的分詞代碼。
source.append(parts[0].split(' '))
target.append(list(jieba.cut(parts[1]))) # 分詞
Copy
2. 特殊token
由於seq2seq任務的特殊性,在解碼器部分,通常需要一個token告訴模型,現在是開始,同時還需要有個token讓模型輸出,以此告訴人類,模型輸出完畢,不要再繼續生成了。
因此相較于文本分類,還多了bos, eos,兩個特殊token,有的時候,開始token也會用start表示。
PAD_TAG = "<pad>" # 用PAD補全句子長度
BOS_TAG = "<bos>" # 用BOS表示開始
EOS_TAG = "<eos>" # 用EOS表示結束
UNK_TAG = "<unk>" # 用EOS表示結束
PAD = 0 # PAD字元對應的數位
BOS = 1 # BOS字元對應的數位
EOS = 2 # EOS字元對應的數位
UNK = 3 # UNK字元對應的數位
Copy
運行代碼後,詞表字典保存到了result目錄下,並得到如下輸出,表明英文中有2518個詞,中文有3365,但經過最大長度3000的截斷後,只剩下2996,另外4個是特殊token。
100%|██████████| 23635/23635 [00:00<00:00, 732978.24it/s]
原始詞表長度:2518,截斷後長度:2518
2522
保存詞頻統計圖:vocab_en.npy_word_freq.jpg
100%|██████████| 23635/23635 [00:00<00:00, 587040.62it/s]
保存統計圖:vocab_en.npy_length_freq.jpg
原始詞表長度:3365,截斷後長度:2996
3000
Copy
Dataset編寫
NMTDataset的編寫邏輯與上一小節的Dataset類似,首先在類初始化的時候載入原始資料,並進行分詞;在getitem反覆運算時,再進行token轉index操作,這裡會涉及增加結束符、填充符、未知符。
核心代碼如下:
def __init__(self, path_txt, vocab_path_en, vocab_path_fra, max_len=32):
self.path_txt = path_txt
self.vocab_path_en = vocab_path_en
self.vocab_path_fra = vocab_path_fra
self.max_len = max_len
self.word2index = WordToIndex()
self._init_vocab()
self._get_file_info()
def __getitem__(self, item):
# 獲取切分好的句子list,一個元素是一個詞
sentence_src, sentence_trg = self.source_list[item], self.target_list[item]
# 進行填充, 增加結束符,索引轉換
token_idx_src = self.word2index.encode(sentence_src, self.vocab_en, self.max_len)
token_idx_trg = self.word2index.encode(sentence_trg, self.vocab_fra, self.max_len)
str_len, trg_len = len(sentence_src) + 1, len(sentence_trg) + 1 # 有效長度, +1是填充的結束符 <eos>.
return np.array(token_idx_src, dtype=np.int64), str_len, np.array(token_idx_trg, dtype=np.int64), trg_len
def _get_file_info(self):
text_raw = read_data_nmt(self.path_txt)
text_clean = text_preprocess(text_raw)
self.source_list, self.target_list = text_split(text_clean)
Copy
模型構建
Transformer代碼梳理如下圖所示,大體可分為三個層級
- Transformer的構建,包含encoder、decoder兩個模組,以及兩個mask構建函數
- 兩個coder內部實現,包括位置編碼、堆疊的block
- block實現,包含多頭注意力、FFN,其中多頭注意力將softmax(QK)*V拆分為ScaledDotProductAttention類。
代碼整體與論文中保持一致,總結幾個不同之處:
- layernorm使用時機前置到attention層和FFN層之前
- position embedding 的序列長度預設採用了200,如果需要更長的序列,則要注意配置。
具體代碼實現不再贅述,把論文中圖2的結果熟悉,並掌握上面的代碼結構,可以快速理解各模組、元件的運算和操作步驟,如有疑問的點,再打開代碼觀察具體運算過程即可。
模型訓練
原論文進行兩個資料集的機器翻譯任務,採用的資料和超參數列舉如下,供參考。
英語-德語,450萬句子對,英語-法語,3600萬句子對。均進行base/big兩種尺寸訓練,分別進行10萬step和30萬step訓練,耗時12小時/84小時。10萬step時,採用4千step進行warmup。正則化方面採用了dropout=0.1的殘差連接,0.1的標籤平滑。
本實驗中有2.3萬句子對訓練,只能作為Transformer的學習,性能指標僅為參考,具體任務後續由BERT、GPT、T5來實現更為具體的項目。
採用配套代碼train_transformer.py,執行訓練即可:
python train_transformer.py -embs_share_weight -proj_share_weight -label_smoothing -b 256 -warmup 128000 -epoch 400
Copy
訓練完成後,在result資料夾下會得到日誌與模型,接下來採用配套代碼c_train_curve_plot.py對日誌資料進行繪圖視覺化,Loss和Accuracy分別如下,可以看出模型擬合能力非常強,性能還在提高,但受限於資料量過少,模型在200個epoch之後就已經出現了過擬合。
這裡面的評估指標用的acc,具體是直接複用github專案,也能作為模型性能的評估指標,就沒有去修改為BLUE。
模型推理
Transformer的推理過程與訓練時不同,值得仔細學習。
Transformer的推理輸出是典型的自回歸(Auto regressive),並且需要根據多個條件綜合判斷何時停止,因此推理部分的邏輯值得認真學習,具體步驟如下:
第一步:輸入序列經encoder,獲得特徵,每個token輸出一個向量,這個特徵會在decoder的每個step都用到,即decoder中各block的第二個multi-head attention。需要enc_output去計算k和v,用decoder上一層輸出特徵向量去計算q,以此進行decoder的第二個attention。
enc_output, *_ = self.model.encoder(src_seq, src_mask)
Copy
第二步:手動執行decoder第一步,輸入是這個token,輸出的是一個概率向量,由分類概率向量再決定第一個輸出token。
self.register_buffer('init_seq', torch.LongTensor([[trg_bos_idx]]))
dec_output = self._model_decode(self.init_seq, enc_output, src_mask)
Copy
第三步:迴圈max_length次執行decoder剩餘步。第i步時,將前序所有步輸出的token,組裝為一個序列,輸入到decoder。
在代碼中用一個gen_seq維護模型輸出的token,輸入給模型時,只需要gen_seq[:, :step]即可,很巧妙。
在每一步輸出時,都會判斷是否需要停止輸出字元。
for step in range(2, max_seq_len):
dec_output = self._model_decode(gen_seq[:, :step], enc_output, src_mask)
略
if (eos_locs.sum(1) > 0).sum(0).item() == beam_size:
_, ans_idx = scores.div(seq_lens.float() ** alpha).max(0)
ans_idx = ans_idx.item()
break
Copy
借助李宏毅老師2021年課件中一幅圖,配合代碼,可以很好的理解Transformer在推理時的流程。
- 輸入序列經encoder,獲得特徵(綠色、藍色、藍色、橙色)
- decoder輸入第一個序列(序列長度為1,token是),輸出第一個概率向量,並通過max得到“機”。
- decoder輸入第二個序列(序列長度為2,token是[BOS, 機]),輸出得到“器”
- decoder輸入第三個序列(序列長度為3,token是[BOS, 機,器]),輸出得到“學”
- decoder輸入第四個序列(序列長度為4,token是[BOS, 機,器,學]),輸出得到“習”
- ...以此類推
在推理過程中,通常還會使用Beam Search 來最終確定當前步應該輸出哪個token,此處不做展開。
運行配套代碼inference_transformer.py可以看到10條訓練集的推理結果。
從結果上看,基本像回事兒了。
src: tom's door is open .
trg: 湯姆 的 門開 著 。
pred: 湯姆 的 <unk> 了 。
src: <unk> is a <unk> country .
trg: <unk> 是 一個 <unk> 的 國家 。
pred: <unk> 是 一個 <unk> 的 城市 。
src: i can come at three .
trg: <unk> 可以 來 。
pred: 我 可以 在 那裡 。
Copy
小結
本節通過Transformer論文的學習,瞭解Transformer基礎架構,並通過機器翻譯案例,從代碼實現的角度深入剖析Transformer訓練和推理的過程。由於Transformer是目前人工智慧的核心與基石,因此需要認真、仔細的掌握其中細節。
本節內容值得注意的知識點:
- 多頭注意力機制的引入:官方解釋為,由於注意力機制最後的加權平均可能會序列中各位置的細細微性捕捉不足,因此引入多頭注意力機制
- 注意力計算時的縮放因數:QK乘法後需要縮放,是因為若向量維度過長,會導致點積結果過大,再經softmax映射後,梯度會變小,不利於模型學習,因此需要進行縮放。縮放因數為除以根號dk。
- 多頭注意力機制的三種情況:
- encoder:輸入來自上一層的全部輸出
- decoder-輸入:為避免模型未卜先知,只允許看見第i步之前的資訊,需要做mask操作,確保在生成序列的每個元素時,只考慮該元素之前的元素。這裡通過softmax位置設置負無窮來控制無效的連接。
- decoder-第二個attention:q來自上一步輸出,k和v來自encoer的輸出。這樣解碼器可以看到輸入的所有序列。
- 輸入序列的msk:代碼實現時,由於輸入資料是通過添加來組batch的,並且為了在運算時做並行運算,因此需要對src中是pad的token做mask,這一點在論文是不會提及的。
- decoder的mask:根據下三角mask避免未卜先知。
- 推理時會採用beam search進行搜索,確定輸出的token。
留言列表