微軟新出熱乎論文:Transformer擴(kuò)展到10億token
當(dāng)大家不斷升級(jí)迭代自家大模型的時(shí)候,LLM(大語(yǔ)言模型)對(duì)上下文窗口的處理能力,也成為一個(gè)重要評(píng)估指標(biāo)。
比如明星大模型 GPT-4 支持 32k token,相當(dāng)于 50 頁(yè)的文字;OpenAI 前成員創(chuàng)立的 Anthropic 更是將 Claude 處理 token 能力提升到 100k,約 75000 個(gè)單詞,大概相當(dāng)于一鍵總結(jié)《哈利波特》第一部。
在微軟最新的一項(xiàng)研究中,他們這次直接將 Transformer 擴(kuò)展到 10 億 token。這為建模非常長(zhǎng)的序列開(kāi)辟了新的可能性,例如將整個(gè)語(yǔ)料庫(kù)甚至整個(gè)互聯(lián)網(wǎng)視為一個(gè)序列。
作為比較,普通人可以在 5 小時(shí)左右的時(shí)間里閱讀 100,000 個(gè) token,并可能需要更長(zhǎng)的時(shí)間來(lái)消化、記憶和分析這些信息。Claude 可以在不到 1 分鐘的時(shí)間里完成這些。要是換算成微軟的這項(xiàng)研究,將會(huì)是一個(gè)驚人的數(shù)字。
圖片
- 論文地址:https://arxiv.org/pdf/2307.02486.pdf
- 項(xiàng)目地址:https://github.com/microsoft/unilm/tree/master
具體而言,該研究提出了 LONGNET,這是一種 Transformer 變體,可以將序列長(zhǎng)度擴(kuò)展到超過(guò) 10 億個(gè) token,而不會(huì)犧牲對(duì)較短序列的性能。文中還提出了 dilated attention,它能指數(shù)級(jí)擴(kuò)展模型感知范圍。
LONGNET 具有以下優(yōu)勢(shì):
1)它具有線性計(jì)算復(fù)雜性;
2)它可以作為較長(zhǎng)序列的分布式訓(xùn)練器;
3)dilated attention 可以無(wú)縫替代標(biāo)準(zhǔn)注意力,并可以與現(xiàn)有基于 Transformer 的優(yōu)化方法無(wú)縫集成。
實(shí)驗(yàn)結(jié)果表明,LONGNET 在長(zhǎng)序列建模和一般語(yǔ)言任務(wù)上都表現(xiàn)出很強(qiáng)的性能。
在研究動(dòng)機(jī)方面,論文表示,最近幾年,擴(kuò)展神經(jīng)網(wǎng)絡(luò)已經(jīng)成為一種趨勢(shì),許多性能良好的網(wǎng)絡(luò)被研究出來(lái)。在這當(dāng)中,序列長(zhǎng)度作為神經(jīng)網(wǎng)絡(luò)的一部分,理想情況下,其長(zhǎng)度應(yīng)該是無(wú)限的。但現(xiàn)實(shí)卻往往相反,因而打破序列長(zhǎng)度的限制將會(huì)帶來(lái)顯著的優(yōu)勢(shì):
- 首先,它為模型提供了大容量的記憶和感受野,使其能夠與人類(lèi)和世界進(jìn)行有效的交互。
- 其次,更長(zhǎng)的上下文包含了更復(fù)雜的因果關(guān)系和推理路徑,模型可以在訓(xùn)練數(shù)據(jù)中加以利用。相反,較短的依賴關(guān)系則會(huì)引入更多虛假的相關(guān)性,不利于模型的泛化性。
- 第三,更長(zhǎng)的序列長(zhǎng)度可以幫助模型探索更長(zhǎng)的上下文,并且極長(zhǎng)的上下文也可幫助模型緩解災(zāi)難性遺忘問(wèn)題。
然而,擴(kuò)展序列長(zhǎng)度面臨的主要挑戰(zhàn)是在計(jì)算復(fù)雜性和模型表達(dá)能力之間找到合適的平衡。
例如 RNN 風(fēng)格的模型主要用于增加序列長(zhǎng)度。然而,其序列特性限制了訓(xùn)練過(guò)程中的并行化,而并行化在長(zhǎng)序列建模中是至關(guān)重要的。
最近,狀態(tài)空間模型對(duì)序列建模非常有吸引力,它可以在訓(xùn)練過(guò)程中作為 CNN 運(yùn)行,并在測(cè)試時(shí)轉(zhuǎn)換為高效的 RNN。然而這類(lèi)模型在常規(guī)長(zhǎng)度上的表現(xiàn)不如 Transformer。
另一種擴(kuò)展序列長(zhǎng)度的方法是降低 Transformer 的復(fù)雜性,即自注意力的二次復(fù)雜性。現(xiàn)階段,一些高效的基于 Transformer 的變體被提出,包括低秩注意力、基于核的方法、下采樣方法、基于檢索的方法。然而,這些方法尚未將 Transformer 擴(kuò)展到 10 億 token 的規(guī)模(參見(jiàn)圖 1)。
圖片
下表為不同計(jì)算方法的計(jì)算復(fù)雜度比較。N 為序列長(zhǎng)度,d 為隱藏維數(shù)。
圖片
方法
該研究的解決方案 LONGNET 成功地將序列長(zhǎng)度擴(kuò)展到 10 億個(gè) token。具體來(lái)說(shuō),該研究提出一種名為 dilated attention 的新組件,并用 dilated attention 取代了 Vanilla Transformer 的注意力機(jī)制。通用的設(shè)計(jì)原則是注意力的分配隨著 token 和 token 之間距離的增加而呈指數(shù)級(jí)下降。該研究表明這種設(shè)計(jì)方法獲得了線性計(jì)算復(fù)雜度和 token 之間的對(duì)數(shù)依賴性。這就解決了注意力資源有限和可訪問(wèn)每個(gè) token 之間的矛盾。
圖片
在實(shí)現(xiàn)過(guò)程中,LONGNET 可以轉(zhuǎn)化成一個(gè)密集 Transformer,以無(wú)縫地支持針對(duì) Transformer 的現(xiàn)有優(yōu)化方法(例如內(nèi)核融合(kernel fusion)、量化和分布式訓(xùn)練)。利用線性復(fù)雜度的優(yōu)勢(shì),LONGNET 可以跨節(jié)點(diǎn)并行訓(xùn)練,用分布式算法打破計(jì)算和內(nèi)存的約束。
最終,該研究有效地將序列長(zhǎng)度擴(kuò)大到 1B 個(gè) token,而且運(yùn)行時(shí)(runtime)幾乎是恒定的,如下圖所示。相比之下,Vanilla Transformer 的運(yùn)行時(shí)則會(huì)受到二次復(fù)雜度的影響。

該研究進(jìn)一步引入了多頭 dilated attention 機(jī)制。如下圖 3 所示,該研究通過(guò)對(duì)查詢 - 鍵 - 值對(duì)的不同部分進(jìn)行稀疏化,在不同的頭之間進(jìn)行不同的計(jì)算。
圖片
分布式訓(xùn)練
雖然 dilated attention 的計(jì)算復(fù)雜度已經(jīng)大幅降低到
,但由于計(jì)算和內(nèi)存的限制,在單個(gè) GPU 設(shè)備上將序列長(zhǎng)度擴(kuò)展到百萬(wàn)級(jí)別是不可行的。有一些用于大規(guī)模模型訓(xùn)練的分布式訓(xùn)練算法,如模型并行 [SPP+19]、序列并行 [LXLY21, KCL+22] 和 pipeline 并行 [HCB+19],然而這些方法對(duì)于 LONGNET 來(lái)說(shuō)是不夠的,特別是當(dāng)序列維度非常大時(shí)。
該研究利用 LONGNET 的線性計(jì)算復(fù)雜度來(lái)進(jìn)行序列維度的分布式訓(xùn)練。下圖 4 展示了在兩個(gè) GPU 上的分布式算法,還可以進(jìn)一步擴(kuò)展到任意數(shù)量的設(shè)備。

實(shí)驗(yàn)
該研究將 LONGNET 與 vanilla Transformer 和稀疏 Transformer 進(jìn)行了比較。架構(gòu)之間的差異是注意力層,而其他層保持不變。研究人員將這些模型的序列長(zhǎng)度從 2K 擴(kuò)展到 32K,與此同時(shí)減小 batch 大小,以保證每個(gè) batch 的 token 數(shù)量不變。
表 2 總結(jié)了這些模型在 Stack 數(shù)據(jù)集上的結(jié)果。研究使用復(fù)雜度作為評(píng)估指標(biāo)。這些模型使用不同的序列長(zhǎng)度進(jìn)行測(cè)試,范圍從 2k 到 32k 不等。當(dāng)輸入長(zhǎng)度超過(guò)模型支持的最大長(zhǎng)度時(shí),研究實(shí)現(xiàn)了分塊因果注意力(blockwise causal attention,BCA)[SDP+22],這是一種最先進(jìn)的用于語(yǔ)言模型推理的外推方法。
此外,研究刪除了絕對(duì)位置編碼。首先,結(jié)果表明,在訓(xùn)練過(guò)程中增加序列長(zhǎng)度一般會(huì)得到更好的語(yǔ)言模型。其次,在長(zhǎng)度遠(yuǎn)大于模型支持的情況下,推理中的序列長(zhǎng)度外推法并不適用。最后,LONGNET 一直優(yōu)于基線模型,證明了其在語(yǔ)言建模中的有效性。

序列長(zhǎng)度的擴(kuò)展曲線
圖 6 繪制了 vanilla transformer 和 LONGNET 的序列長(zhǎng)度擴(kuò)展曲線。該研究通過(guò)計(jì)算矩陣乘法的總 flops 來(lái)估計(jì)計(jì)算量。結(jié)果表明,vanilla transformer 和 LONGNET 都能從訓(xùn)練中獲得更大的上下文長(zhǎng)度。然而,LONGNET 可以更有效地?cái)U(kuò)展上下文長(zhǎng)度,以較小的計(jì)算量實(shí)現(xiàn)較低的測(cè)試損失。這證明了較長(zhǎng)的訓(xùn)練輸入比外推法更具有優(yōu)勢(shì)。實(shí)驗(yàn)表明,LONGNET 是一種更有效的擴(kuò)展語(yǔ)言模型中上下文長(zhǎng)度的方法。這是因?yàn)?LONGNET 可以更有效地學(xué)習(xí)較長(zhǎng)的依賴關(guān)系。

擴(kuò)展模型規(guī)模
大型語(yǔ)言模型的一個(gè)重要屬性是:損失隨著計(jì)算量的增加呈冪律擴(kuò)展。為了驗(yàn)證 LONGNET 是否仍然遵循類(lèi)似的擴(kuò)展規(guī)律,該研究用不同的模型規(guī)模(從 1.25 億到 27 億個(gè)參數(shù)) 訓(xùn)練了一系列模型。27 億的模型是用 300B 的 token 訓(xùn)練的,而其余的模型則用到了大約 400B 的 token。圖 7 (a) 繪制了 LONGNET 關(guān)于計(jì)算的擴(kuò)展曲線。該研究在相同的測(cè)試集上計(jì)算了復(fù)雜度。這證明了 LONGNET 仍然可以遵循冪律。這也就意味著 dense Transformer 不是擴(kuò)展語(yǔ)言模型的先決條件。此外,可擴(kuò)展性和效率都是由 LONGNET 獲得的。

長(zhǎng)上下文 prompt
Prompt 是引導(dǎo)語(yǔ)言模型并為其提供額外信息的重要方法。該研究通過(guò)實(shí)驗(yàn)來(lái)驗(yàn)證 LONGNET 是否能從較長(zhǎng)的上下文提示窗口中獲益。
該研究保留了一段前綴(prefixes)作為 prompt,并測(cè)試其后綴(suffixes)的困惑度。并且,研究過(guò)程中,逐漸將 prompt 從 2K 擴(kuò)展到 32K。為了進(jìn)行公平的比較,保持后綴的長(zhǎng)度不變,而將前綴的長(zhǎng)度增加到模型的最大長(zhǎng)度。圖 7 (b) 報(bào)告了測(cè)試集上的結(jié)果。它表明,隨著上下文窗口的增加,LONGNET 的測(cè)試損失逐漸減少。這證明了 LONGNET 在充分利用長(zhǎng)語(yǔ)境來(lái)改進(jìn)語(yǔ)言模型方面的優(yōu)越性。





























