千行代碼搞定Transformer?這份高效PaddlePaddle官方實(shí)現(xiàn)請(qǐng)收下(千行代碼bug率多少算合格)
目前,無(wú)論是從性能、結(jié)構(gòu)還是業(yè)界應(yīng)用上,Transformer 都有很多無(wú)可比擬的優(yōu)勢(shì)。本文將介紹 PaddlePaddle 的 Transformer 項(xiàng)目,我們從項(xiàng)目使用到源碼解析帶你玩一玩 NMT。只需千行模型代碼,Transformer 實(shí)現(xiàn)帶回家。
其實(shí) PyTorch、TensorFlow 等主流框架都有 Transformer 的實(shí)現(xiàn),但如果我們需要將它們應(yīng)用到產(chǎn)品中,還是需要修改很多。
例如谷歌大腦構(gòu)建的 Tensor2Tensor,它最開(kāi)始是為了實(shí)現(xiàn) Transformer,后來(lái)擴(kuò)展到了各種任務(wù)。對(duì)于基于 Tensor2Tensor 實(shí)現(xiàn)翻譯任務(wù)的用戶,他們需要在 10 萬(wàn) 行 TensorFlow 代碼找到需要的部分。
PaddlePaddle 提供的 Transformer 實(shí)現(xiàn),項(xiàng)目代碼只有 2000 行,簡(jiǎn)潔優(yōu)雅。如果我們使用大 Batch Size,那么在預(yù)測(cè)速度上,PaddlePaddle 復(fù)現(xiàn)的模型比 TensorFlow 官方使用 Tensor2Tensor 實(shí)現(xiàn)的模型還要快 4 倍。
項(xiàng)目地址:https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleNLP/neural_machine_translation/transformer
1. Transformer 怎么用
相比此前 Seq2Seq 模型中廣泛使用的循環(huán)神經(jīng)網(wǎng)絡(luò),Transformer 使用深層注意力機(jī)制獲得了更好的效果,目前大多數(shù)神經(jīng)機(jī)器翻譯模型都采用了這一網(wǎng)絡(luò)結(jié)構(gòu)。此外,不論是新興的預(yù)訓(xùn)練語(yǔ)言模型,還是問(wèn)答或句法分析,Transformer 都展現(xiàn)出強(qiáng)大的建模能力。
相比傳統(tǒng) NMT 使用循環(huán)層或卷積層抽取文本信息,Transformer 使用自注意力網(wǎng)絡(luò)抽取并表征這些信息,下圖對(duì)比了不同層級(jí)的特點(diǎn):
不同網(wǎng)絡(luò)的主要性質(zhì),其中 n 表示序列長(zhǎng)度、d 為隱向量維度、k 為卷積核大小。例如單層計(jì)算復(fù)雜度,一般句子長(zhǎng)度 n 都小于隱向量維度 d,那么自注意力層級(jí)的計(jì)算復(fù)雜度最小。
如上所示,Transformer 使用的自注意力模型主要擁有以下優(yōu)點(diǎn),1)網(wǎng)絡(luò)結(jié)構(gòu)的計(jì)算復(fù)雜度最低;2)由于序列操作數(shù)復(fù)雜度低,模型的并行度很高;3)最大路徑長(zhǎng)度小,能夠更好地表示長(zhǎng)距離依賴關(guān)系;4)模型更容易訓(xùn)練。
現(xiàn)在,如果我們需要訓(xùn)練一個(gè) Transformer,那么最好的方法是什么?當(dāng)然是直接跑已復(fù)現(xiàn)的模型了,下面我們將跑一跑 PaddlePaddle 實(shí)現(xiàn)的 Transformer。
1.1 處理數(shù)據(jù)
在 PaddlePaddle 的復(fù)現(xiàn)中,百度采用原論文測(cè)試的 WMT’16 EN-DE 數(shù)據(jù)集,它是一個(gè)中等規(guī)模的數(shù)據(jù)集。這里比較方便的是,百度將數(shù)據(jù)下載和預(yù)處理等過(guò)程都放到了 gen_data.sh 腳本中,包括 Tokenize 和 BPE 編碼。
在這個(gè)項(xiàng)目中,我們既可以通過(guò)腳本預(yù)處理數(shù)據(jù),也可以使用百度預(yù)處理好的數(shù)據(jù)集。首先最簡(jiǎn)單的方式是直接運(yùn)行 gen_data.sh 腳本,運(yùn)行后可以生成 gen_data 文件夾,該文件夾主要包含以下文件:
其中 wmt16_ende_data_bpe 文件夾包含最終使用的英德翻譯數(shù)據(jù)。
如果我們從頭下載并預(yù)處理數(shù)據(jù),那么大概需要花 1 到 2 個(gè)小時(shí)完成預(yù)處理。為此,百度也提供了預(yù)處理好的 WMT’16 EN-DE 數(shù)據(jù)集,它包含訓(xùn)練、驗(yàn)證和測(cè)試所需要的 BPE 數(shù)據(jù)和字典。
其中,BPE 策略會(huì)把稀疏詞拆分為高頻的子詞,這樣既能解決低頻詞無(wú)法訓(xùn)練的問(wèn)題,也能合理降低詞表規(guī)模。
如果不采用 BPE 的策略,要么詞表的規(guī)模變得很大,從而使訓(xùn)練速度變慢或者顯存太小而無(wú)法訓(xùn)練;要么一些低頻詞會(huì)當(dāng)作未登錄詞處理,從而得不到訓(xùn)練。
預(yù)處理數(shù)據(jù)地址:https://transformer-res.bj.bcebos.com/wmt16_ende_data_bpe_clean.tar.gz
如果我們有其它數(shù)據(jù)集,例如中英翻譯數(shù)據(jù),也可以根據(jù)特定的格式進(jìn)行定義。例如用空格分隔不同的 token(對(duì)于中文而言需要提前用分詞工具進(jìn)行分詞),用t 分隔源語(yǔ)言與目標(biāo)語(yǔ)句對(duì)。
1.2 訓(xùn)練模型
如果需要執(zhí)行模型訓(xùn)練,我們也可以直接運(yùn)行訓(xùn)練主函數(shù) train.py。如下簡(jiǎn)要配置了數(shù)據(jù)路徑以及各種模型參數(shù):
# 顯存使用的比例,顯存不足可適當(dāng)增大,最大為1export FLAGS_fraction_of_gpu_memory_to_use=0.8# 顯存清理的閾值,顯存不足可適當(dāng)減小,最小為0,為負(fù)數(shù)時(shí)不啟用export FLAGS_eager_delete_tensor_gb=0.7python -u train.py –src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 –trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 –special_token ‘<s>’ ‘<e>’ ‘<unk>’ –train_file_pattern gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de –token_delimiter ‘ ‘ –use_token_batch True –batch_size 1600 –sort_type pool –pool_size 200000 n_head 8 n_layer 4 d_model 512 d_inner_hid 1024 prepostprocess_dropout 0.3
此外,如果顯存不夠大,那么我們可以將 Batch Size 減小一點(diǎn)。為了快速測(cè)試訓(xùn)練效果,我們將模型調(diào)得比 Base Transformer 還?。ń档途W(wǎng)絡(luò)的層數(shù)、head 的數(shù)量、以及隱層的大小)。
上面僅展示了小部分的超參設(shè)置,更多的配置可以在 GitHub 項(xiàng)目 config.py 文件中找到。默認(rèn)情況下,模型每迭代一萬(wàn)次保存一次模型,每個(gè) epoch 結(jié)束后也會(huì)保存一次 cheekpoint。此外,在我們訓(xùn)練的過(guò)程中,默認(rèn)每一百次迭代會(huì)打印一次模型信息,其中 ppl 表示的是困惑度,困惑度越小模型效果越好。
在單機(jī)訓(xùn)練中,默認(rèn)使用所有 GPU,可以通過(guò) CUDA_VISIBLE_DEVICES 環(huán)境變量來(lái)設(shè)置使用的 GPU,例如 CUDA_VISIBLE_DEVICES=’0,1’,表示使用 0 號(hào)和 1 號(hào)卡進(jìn)行訓(xùn)練。
1.3 預(yù)測(cè)推斷
訓(xùn)練完 Transformer 后就可以執(zhí)行推斷了,我們需要運(yùn)行對(duì)應(yīng)的推斷文件 infer.py。我們也可以在推斷過(guò)程中配置超參數(shù),但注意超參需要和前面訓(xùn)練時(shí)保持一致。
python -u infer.py –src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 –trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 –special_token ‘<s>’ ‘<e>’ ‘<unk>’ –test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de –token_delimiter ‘ ‘ –batch_size 32 model_path trained_models/iter_100000.infer.model n_head 8 n_layer 4 d_model 512 d_inner_hid 1024 prepostprocess_dropout 0.3 beam_size 5 max_out_len 255
相比模型的訓(xùn)練,推斷過(guò)程需要一些額外的超參數(shù),例如配置 model_path 指定模型所在目錄、設(shè)置 beam_size 和 max_out_len 來(lái)指定 Beam Search 每一步候選詞的個(gè)數(shù)和最大翻譯長(zhǎng)度。這些超參數(shù)也可以在 config.py 中找到,該文件對(duì)這些超參都有注釋說(shuō)明。
執(zhí)行以上預(yù)測(cè)命令會(huì)將翻譯結(jié)果直接打出來(lái),每行輸出是對(duì)應(yīng)行輸入得分最高的翻譯。對(duì)于使用 BPE 的英德數(shù)據(jù),預(yù)測(cè)出的翻譯結(jié)果也將是 BPE 表示的數(shù)據(jù),所以需要還原成原始數(shù)據(jù)才能進(jìn)行正確評(píng)估。如下命令可以將 predict.txt 內(nèi)的翻譯結(jié)果(BPE 表示)恢復(fù)到 predict.tok.txt 文件中(tokenize 后的數(shù)據(jù)):
sed -r ‘s/(@@ )|(@@ ?$)//g’ predict.txt > predict.tok.txt
在未使用集成方法的情況下,百度表示 base model 和 big model 在收斂后,測(cè)試集的 BLEU 值參考如下:
這兩個(gè)預(yù)訓(xùn)練模型也提供了下載地址:
- Base:https://transformer-res.bj.bcebos.com/base_model.tar.gz
- Big:https://transformer-res.bj.bcebos.com/big_model.tar.gz
2. Transformer 怎么改
如果我們想要訓(xùn)練自己的 Transformer,那么又該怎樣理解并修改 PaddlePaddle 代碼呢?如果我們需要根據(jù)自己的數(shù)據(jù)集和任務(wù)改代碼,除了前面數(shù)據(jù)預(yù)處理過(guò)程,模型結(jié)構(gòu)等模塊有時(shí)也需要修改。這就需要我們先理解源代碼了,PaddlePaddle 的源代碼基本都是基礎(chǔ)的函數(shù)或運(yùn)算,我們很容易理解并使用。
對(duì)于 PaddlePaddle 不熟悉的讀者可查閱文檔,也可以看看入門教程,了解基本編寫模式后就可以看懂整個(gè)實(shí)現(xiàn)了。
PaddlePaddle 官網(wǎng)地址:http://paddlepaddle.org/paddle
如 Seq2Seq 一樣,原版 Transformer 也采用了編碼器-解碼器框架,但它們會(huì)使用多個(gè) Multi-Head 注意力、前饋網(wǎng)絡(luò)、層級(jí)歸一化和殘差連接等。下圖從左到右展示了原論文所提出的 Transformer 架構(gòu)、Multi-Head 注意力和標(biāo)量點(diǎn)乘注意力。
上圖右邊的點(diǎn)乘注意力就是標(biāo)準(zhǔn) Seq2Seq 模型中的注意力機(jī)制,中間的 Multi-head 注意力其實(shí)就是將一個(gè)隱層信息切分為多份,并單獨(dú)計(jì)算注意力信息,使得一個(gè)詞與其它多個(gè)目標(biāo)詞的注意力信息計(jì)算更精確。最左邊為 Transformer 的整體架構(gòu),編碼器與解碼器由多個(gè)類似的模塊組成,后面將簡(jiǎn)要介紹這些模塊與對(duì)應(yīng)的 PaddlePaddle 代碼。
2.1 點(diǎn)乘注意力
注意力機(jī)制目前在機(jī)器翻譯中已經(jīng)極其流行了,我們可以認(rèn)為 Transformer 是一種堆疊多層注意力網(wǎng)絡(luò)的模型,它采用的是一種名為經(jīng)縮放的點(diǎn)乘注意力機(jī)制。這種注意力機(jī)制使用經(jīng)縮放的點(diǎn)乘作為作為評(píng)分函數(shù),從而評(píng)估各隱藏狀態(tài)對(duì)當(dāng)前預(yù)測(cè)的重要性,如下是該注意力的表達(dá)式:
其中 Query 向量與 (Key, Value ) 向量在 NMT 中相當(dāng)于目標(biāo)語(yǔ)輸入序列與源語(yǔ)輸入序列,Query 與 Key 向量的點(diǎn)乘,經(jīng)過(guò) SoftMax 函數(shù)后可得出一組歸一化的概率。這些概率相當(dāng)于給源語(yǔ)輸入序列做加權(quán)平均,即表示在生成新的隱層信息的時(shí)候需要關(guān)注哪些詞。
在 Transformer 的 PaddlePaddle 實(shí)現(xiàn)中,經(jīng)縮放的點(diǎn)乘注意力是在 Multi-head 注意力函數(shù)下實(shí)現(xiàn)的,如下所示為上述表達(dá)式的實(shí)現(xiàn)代碼:
在這個(gè)函數(shù)中,q、k、v 和公式中的一樣,attn_bias 用于 Mask 掉選定的特定位置(encode 的 self attention 和 decoder 端的 encode attention 都是屏蔽掉 padding 的詞;decoder 的 self attention 屏蔽掉當(dāng)前詞后面的詞,目的是為了和解碼的過(guò)程保持一致),因此在給不同輸入加權(quán)時(shí)忽略該位置的輸入。
如上 product 計(jì)算的是 q 和 k 之間的點(diǎn)乘,且經(jīng)過(guò)根號(hào)下 d_key(key 的維度)的縮放。這里我們可以發(fā)現(xiàn)參數(shù) alpha 可以直接對(duì)矩陣乘法的結(jié)果進(jìn)行縮放,默認(rèn)情況下它為 1.0,即不進(jìn)行縮放。在 Transformer 原論文中,作者表示如果 d_key 比較小,那么直接點(diǎn)乘和帶縮放的點(diǎn)乘差別不大,所以他們認(rèn)為高維情況下可能不帶縮放的乘積太大而令 Softmax 函數(shù)飽和。
weights 表示對(duì)輸入的不同元素加權(quán),即不同輸入對(duì)當(dāng)前預(yù)測(cè)的重要性,訓(xùn)練中也可以對(duì)該權(quán)重進(jìn)行 Dropout。最后 out 表示按照 weights 對(duì)輸入 V 進(jìn)行加權(quán)和,得出來(lái)就是當(dāng)前注意力的運(yùn)算結(jié)果。
2.2 Muti-head 注意力
Multi-head 注意力其實(shí)就是多個(gè)點(diǎn)乘注意力并行地處理并最后將結(jié)果拼接在一起。一般而言,我們可以對(duì)三個(gè)輸入矩陣 Q、V、K 分別進(jìn)行線性變換,然后分別將它們投入 h 個(gè)點(diǎn)乘注意力函數(shù)并拼接所有的輸出結(jié)果。
這種注意力允許模型聯(lián)合關(guān)注不同位置的不同表征子空間信息,我們可以理解為在參數(shù)不共享的情況下,多次執(zhí)行點(diǎn)乘注意力。如下所示為 Muti-head 注意力的表達(dá)式:
其中每一個(gè) head 都為一個(gè)點(diǎn)乘注意力,不同 head 的輸入是相同 Q、K、V 的不同線性變換。
總體而言,PaddlePaddle 的 Multi-head 注意力實(shí)現(xiàn)分為幾個(gè)步驟:先為 Q、K、V 執(zhí)行線性變換;再變換維度以計(jì)算點(diǎn)乘注意力;最后計(jì)算各 head 的注意力輸出并合并在一起。
2.2.1 線性變換
如前公式所示,Muti-head 首先要執(zhí)行線性變換,從而令不同的 head 關(guān)注不同表征空間的信息。這種線性變換即乘上不同的權(quán)重矩陣,且模型在訓(xùn)練過(guò)程中可以學(xué)習(xí)和更新這些權(quán)重矩陣。在如下的 PaddlePaddle 代碼中,我們可以直接調(diào)用全連接層 layers.fc() 完成線性變換。
直接調(diào)用全連接層會(huì)自動(dòng)為輸入創(chuàng)建權(quán)重,且我們要求不使用偏置項(xiàng)和激活函數(shù)。這里比較方便的是,PaddlePaddle 的 layers.fc() 函數(shù)可以接受高維輸入,省略了手動(dòng)展平輸入向量的操作。因此這里有 num_flatten_dims=2,即將前兩個(gè)維度展平為一個(gè)維度,第三個(gè)維度保持不變。
例如對(duì)于輸入張量 q 而言,線性變換的輸出維度應(yīng)該是 [batch_size,max_sequence_length,d_key * n_head],最后一個(gè)維度即 n_head 個(gè) d_key 維的 Query 向量。每一個(gè) d_key 維的向量都會(huì)饋送到不同的 head,并最后拼接起來(lái)。
2.2.2 維度變換
為了進(jìn)行 Multi-Head 的運(yùn)算,我們需要將線性變換的結(jié)果進(jìn)行 reshape 和轉(zhuǎn)置操作?,F(xiàn)在我們將這幾個(gè)張量的最后一個(gè)維度分割成不同的 head,并做轉(zhuǎn)置以便于后續(xù)運(yùn)算。
具體而言,輸入張量 q、k 和 v 的維度信息為 [bs, max_sequence_length, n_head * hidden_dim],我們希望把它們轉(zhuǎn)換為 [bs, n_head, max_sequence_length, hidden_dim]。
如上使用 layers.reshape() 和 layers.transpose() 函數(shù)完成分割與轉(zhuǎn)置。其中 layers.reshape() 在接收輸入張量后會(huì)按照形狀 [0, 0, n_head, d_key] 進(jìn)行轉(zhuǎn)換,其中 0 表示從輸入張量對(duì)應(yīng)維數(shù)復(fù)制出來(lái)。此外,因?yàn)?inplace 設(shè)置為 True,那么 reshape 操作就不會(huì)進(jìn)行數(shù)據(jù)的復(fù)制,從而提升運(yùn)算效率。
后面的轉(zhuǎn)置就比較簡(jiǎn)單了,只需要按照維度索引將第「1」個(gè)維度和第「2」個(gè)維度交換就行了。此外為了更快地執(zhí)行推斷,PaddlePaddle 實(shí)現(xiàn)代碼還做了非常多的優(yōu)化,例如這部分后續(xù)會(huì)對(duì)推斷過(guò)程的緩存和處理流程進(jìn)行優(yōu)化。
2.2.3 合并
前面已經(jīng)介紹過(guò)點(diǎn)乘注意力了,那么上面對(duì) q、k、v 執(zhí)行維度變換后就可直接傳入點(diǎn)乘注意力函數(shù),并計(jì)算出 head_1、head_2 等注意力結(jié)果?,F(xiàn)在最后一步只需要將這些 head 拼接起來(lái)就完成了整個(gè)過(guò)程,也就完成了上面 Multi-head 注意力的計(jì)算式。
因?yàn)槊恳粋€(gè)批量、head 和時(shí)間步都會(huì)計(jì)算得出一個(gè)注意力向量,因此總體上注意力計(jì)算結(jié)果的維度信息為 [bs, n_head, max_sequence_length, hidden_dim]。如果要將不同的 head 拼接在一起,即將 head 這個(gè)維度合并到 hidden_dim 中去,因此合并的過(guò)程和前面維度變換的過(guò)程正好相反。
如上合并過(guò)程會(huì)先檢驗(yàn)維度信息,然后先轉(zhuǎn)置再 reshape 合并不同的 head。注意在原論文中,合并不同的 head 后,還需要再做一個(gè)線性變換,這個(gè)線性變換的結(jié)果就是 Muti-head 注意力的輸出了。
最后,我們?cè)賹⑸厦娴乃牟糠执饋?lái)就是 Transformer 最核心的 Multi-head 注意力。理解了各個(gè)模塊后,下面串起來(lái)就能愉快地看懂整個(gè)過(guò)程了:
當(dāng)然,如果編碼器和解碼器輸入到 Multi-head 注意力的 q 與 (k、v) 是相同的,那么它又可稱為自注意力網(wǎng)絡(luò)。
2.3 前饋網(wǎng)絡(luò)
對(duì)于每一個(gè)編碼器和解碼器模塊,除了殘差連接與層級(jí)歸一化外,重要的就是堆疊 Muti-head 注意力和前饋網(wǎng)絡(luò)(FFN)。前面我們已經(jīng)解決了 Multi-head 注意力,現(xiàn)在需要理解主位置的前饋網(wǎng)絡(luò)了。直觀而言,F(xiàn)FN 的作用是整合 Multi-head 注意力生成的上下文向量,因此能更好地利用從源語(yǔ)句子和目標(biāo)語(yǔ)句子抽取的深度信息。
如下所示在原論文中,前饋網(wǎng)絡(luò)的計(jì)算過(guò)程可以表達(dá)為以下方程:
前饋網(wǎng)絡(luò)的結(jié)構(gòu)很簡(jiǎn)單,一個(gè) ReLU 激活函數(shù)加兩次線性變換就完成了。如下基本上只需要調(diào)用PaddlePaddle 的 layers.fc() 就可以了:
現(xiàn)在基本上核心操作就定義完了,后面還有更多模塊與架構(gòu),例如怎樣利用核心操作搭建編碼器模塊與解碼器模塊、如何搭建整體 Transformer 模型等,讀者可繼續(xù)閱讀原項(xiàng)目中的簡(jiǎn)潔代碼。整體而言,包括上面代碼在內(nèi),千行代碼就可以完全弄懂 Transformer,PaddlePaddle 的 Transformer 復(fù)現(xiàn)值得我們仔細(xì)讀一讀。
此外,在這千行模型代碼中,為了給訓(xùn)練和推斷加速,還有很多特殊技巧。例如在 Decoder 中加入對(duì) Encoder 計(jì)算結(jié)果的緩存等。加上這些技巧,PaddlePaddle 的實(shí)現(xiàn)才能在大 Batch Size 下實(shí)現(xiàn) 4 倍推斷加速。
因?yàn)楸旧?PaddlePaddle 代碼就已經(jīng)非常精煉,通過(guò)它們也很容易理解這些技巧?;旧峡春瘮?shù)名稱就能知道大致的作用,再結(jié)合文檔使用就能完全讀懂了。
最后,除了模型架構(gòu),整個(gè)項(xiàng)目還會(huì)有其它組成部分,例如訓(xùn)練、推斷、數(shù)據(jù)預(yù)處理等等。這些代碼同樣非常簡(jiǎn)潔,我們可以根據(jù)實(shí)際需求閱讀并修改它們??傮w而言,PaddlePaddle 的 Transformer 實(shí)現(xiàn)確實(shí)非常適合理解與修改。想要跑一跑神經(jīng)機(jī)器翻譯的同學(xué),PaddlePaddle 的 Transformer 實(shí)現(xiàn)確實(shí)值得推薦。