韓松等提出FlashMoBA,比MoBA快7.4倍,序列擴到512K也不會溢出
今年 2 月,月之暗面提出了一種名為 MoBA 的注意力機制,即 Mixture of Block Attention,可以直譯為「塊注意力混合」。
據介紹,MoBA 是「一種將混合專家(MoE)原理應用于注意力機制的創新方法。」該方法遵循「更少結構」原則,并不會引入預定義的偏見,而是讓模型自主決定關注哪些位置。
MoBA 在處理長上下文時表現出極強的潛力,它允許 Query 只稀疏地關注少量 Key-Value 塊,從而大幅降低計算成本。
然而,目前業界對 MoBA 性能背后的設計原則仍缺乏深入理解,同時也缺少高效的 GPU 實現,這限制了其實際應用。
在這篇論文中,來自 MIT、NVIDIA 機構的研究者首先建立了一個統計模型,用于分析 MoBA 的內部機制。模型顯示,其性能關鍵取決于路由器是否能夠基于 Query-Key 的相似度,準確區分相關塊與無關塊。研究者進一步推導出一個信噪比,將架構參數與檢索準確率建立起形式化聯系。
基于這一分析,本文識別出兩條主要的改進路徑:一是采用更小的塊大小,二是在 Key 上應用短卷積,使語義相關信號在塊內聚集,從而提升路由準確性。
然而,盡管小塊尺寸在理論上更優,但在現有的 GPU 實現中,小塊會導致嚴重的內存訪問碎片化和低并行度,速度甚至慢于稠密注意力。
為解決這一矛盾,研究者進一步提出了 FlashMoBA,一種硬件友好的 CUDA kernel,可在小塊配置下仍然高效地執行 MoBA。
結果顯示優化后的 MoBA 在性能上可與密集注意力基線相匹敵。對于小塊場景,FlashMoBA 相比 FlashAttention-2 可實現最高 14.7 倍加速。

- 論文地址:https://arxiv.org/pdf/2511.11571
- 項目地址:https://github.com/mit-han-lab/flash-moba
- 論文標題:OPTIMIZING MIXTURE OF BLOCK ATTENTION
FLASHMOBA:一種面向小塊 MoBA 的優化內核
理論模型表明,較小的塊尺寸能帶來顯著的質量提升,但樸素的 GPU 實現效率低下。由月之暗面發布的原始 MoBA 實現,在配置小塊尺寸時會遭遇性能瓶頸,這些瓶頸抵消了稀疏性帶來的計算節省,導致執行速度比稠密注意力更慢。
研究者推出了 FlashMoBA,這是一種硬件感知的 CUDA 內核,旨在使小塊 MoBA 變得實用且高效。
小塊帶來的性能挑戰
小塊尺寸引入了幾個關鍵的性能挑戰,要在實際部署中應用必須解決這些問題。
首先,在為每個查詢收集稀疏、不連續的鍵值塊時,會出現低效的內存訪問,導致從 HBM 讀取數據時出現非合并內存讀取。
其次,隨著較小的塊尺寸
導致路由器必須評分的塊數量(
)增加,Top-k 選擇和門控的開銷變得棘手。原始實現顯式生成了一個巨大的
分數矩陣,產生了巨大的內存開銷。
最后,由于每個塊的工作量減少以及啟動大量獨立內核的開銷,導致 GPU 占用率低,進而造成并行度差和硬件利用率低。
FLASHMOBA 內核設計
為了克服這些挑戰,FlashMoBA 采用了三個融合內核,以最大限度地減少 HBM 往返次數,并使計算與 GPU 架構相對齊,如圖 1 所示。
分塊 Top-K 選擇
Top-k 選擇過程是原始 MoBA 實現中的主要瓶頸,該實現顯式生成了完整的分數矩陣并串行處理批次序列。研究者將其替換為 Flash TopK(圖 1 中的步驟 1),這是一個由融合內核組成的高度優化的三階段流水線。

首先,一個 Triton 內核計算鍵塊的質心,生成一個更小的矩陣
。
其次,受 FlashAttention-2 啟發的分塊內核通過計算
和
之間的分數來為每個查詢找到 Top-k 個鍵塊,且無需將完整的分數矩陣顯式寫入 HBM,如算法 3 所述。

最后,一個高效的后處理步驟將以查詢為中心的索引重新格式化為以鍵塊為中心的變長布局,以便進行主注意力傳遞。整個流水線在批次和注意力頭之間完全并行化,消除了原始的性能瓶頸。
采用「收集并致密化」策略的前向傳播
為了處理 MoBA 的不規則稀疏性,前向內核使用了一種基于兩級分塊機制的「收集并致密化」策略,詳見算法 1。

要區分兩種類型的塊:
邏輯塊:內核在其外層循環中迭代的大型連續查詢塊
和鍵塊
。一個邏輯鍵塊對應一個 MoBA 鍵塊。
物理塊:加載到 SRAM 中用于矩陣乘法的較小圖塊(Tiles,例如
或
。它們的最佳尺寸取決于 GPU 架構和注意力頭的維度。
內核將一個邏輯查詢塊
分配給每個線程塊,并遍歷所有邏輯鍵塊
。對于每一對塊,它使用變長索引來查找相關的查詢。該子集被分批處理成稠密的物理塊:從 HBM 收集物理查詢塊并放入稠密 SRAM 緩沖區進行計算。
這種兩級方法是關鍵所在,因為在 SRAM 中緩存查詢允許在邏輯鍵塊的所有物理圖塊之間復用數據,從而通過高效的稠密 GEMM(通用矩陣乘法)分攤昂貴的不規則內存訪問成本。
帶重計算的反向傳播
反向傳播利用了 FlashAttention-2 的內存高效設計,并實現為三個內核的序列(算法 5)。

主內核在鍵維度上并行化計算,每個線程塊處理一個鍵塊。為了處理稀疏性,它鏡像了前向傳播的「收集并致密化」策略,使用變長索引收集查詢子集并將梯度輸出到片上圖塊中。
遵循 FlashAttention-2 的方法,研究者在反向傳播期間重計算注意力分數,以避免將完整的注意力矩陣存儲在內存中。雖然鍵和值的梯度直接寫入 HBM,但部分查詢梯度
需要跨多個鍵塊進行累加,這是通過對高精度全局緩沖區使用原子加法來高效且安全地處理的。
這種設計確保了反向傳播在序列長度上保持線性復雜度,這是相對于標準注意力的二次復雜度的一個關鍵改進。由于反向傳播通常構成優化注意力實現的主要性能瓶頸(通常比前向傳播慢 2-3 倍),因此我們需要反向內核的高效率對于實現長序列的實際訓練至關重要。
實驗及結果
本文從零開始預訓練模型,并進行可控實驗來驗證 MoBA 的設計原則。實驗共訓練了兩個模型,所有實驗均在 8× H100 80GB GPU 上完成:
- 340M 參數模型(hidden size 1024,16 heads,中間層規模 2816);
- 1B 參數模型(hidden size 2048,32 heads,中間層規模 8192)。
質量評估結果
本文在語言建模、長上下文檢索以及真實任務上對 MoBA 的表現進行了評估。實驗結果表明,改進后的模型在多種基準測試中提高了性能。
首先是塊大小的影響。圖 2 展示了塊大小對 340M 模型在 WikiText 困惑度(perplexity)和 RULER 準確率上的影響。正如
的理論預測,將塊大小從 512 縮小到 128,使困惑度從 20.9 降至 19.7,RULER 準確率從 38.8% 提升到 56.0%。更小的塊能夠幫助路由器更精準地識別相關內容。

這一趨勢在所有基準和不同模型規模上都保持一致。對 340M 模型來說,將塊大小從 512 縮小到原來的 1/4 到 128,可帶來如下提升:
- 語言建模準確率從 44.6% 提升到 45.6%(表 1);
- RULER 準確率從 38.8% 提升到 63.9%(表 3);
- LongBench 綜合得分從 13.2 提升到 15.3(表 5)。



總體來看,小塊尺寸對于 MoBA 達到與密集注意力相當的性能是必要的。
Key Convolution 。Key Convolution 在不同任務中都能帶來性能提升,而且具有任務偏好特性。對于 340M 模型:
- kconv3 將語言建模準確率從 45.1% 提升到 45.6%(表 1);
- kconv5 在 64K 長度檢索任務中達到 100% 的檢索率(表 3);
- 在 LongBench 上,kconv3 得分達到 15.3%(表 5)。
對于 1B 模型:
- kconv3 將語言建模準確率提升到 52.7%(表 2);
- 將 RULER 準確率提升到 68.2%(表 4)。


這些結果表明,卷積通過使相關 token 在塊內聚集,提升了有效均值差異
,從而顯著提高路由準確性。
注:卷積核寬度 W∈{3,5},分別記作 kconv3 和 kconv5。
稀疏匹配密集注意力機制。在多個基準測試和規模下,MoBA 的表現與密集注意力機制相當甚至更勝一籌。

效率結果
雖然理論上小塊尺寸能夠帶來更高的模型質量,但此前由于 GPU 利用率低下,小塊一直難以在實際中使用。FlashMoBA 的出現讓這些配置真正變得可行。
端到端性能。圖 3 對比了不同序列長度(8K 至 512K token)下的延遲和內存占用。FlashMoBA 在兩項指標上都顯著優于原始實現。
在 N=64K 且 B=128 的配置下:FlashMoBA 比原始 MoBA 快 7.4 倍,內存占用減少 6.1 倍,原始 MoBA 在 128K 序列就會 OOM(內存溢出),而 FlashMoBA 能擴展到 512K。
隨著序列越長、塊越小,優勢更明顯,因為 FlashMoBA 消除了全局 reindex 的開銷,在長序列條件下可實現最高 14.7× 快于 FlashAttention-2 的速度。

為了理解 FlashMoBA 的提速來源,圖 4 展示了在 N=64K 下前向傳播的耗時分布。
原始 MoBA 包含 5 個階段:(1)計算質心并執行 top-k、(2)全局 reindex、(3)在路由后的索引上執行注意力、(4)局部因果注意力以及(5)合并結果。
其中步驟 (1)、(2)、(5) 占據了超過 70% 的執行時間。
FlashMoBA 則使用兩個融合 kernel,這種融合設計將 64K 序列下的前向傳播時間降至 49 ms,而 FlashAttention-2 在相同設置下為 99 ms。































