大廠都在用它搞分布式AI!開源框架Ray到底牛在哪?
一、發展歷程
2016年,伯克利Rise實驗室的學生和研究人員開展了一項課程項目,旨在擴展分布式神經網絡訓練和強化學習的規模。
該項目催生了論文《Ray:面向新興AI應用的分布式框架》。
如今,Ray是一款開源分布式計算框架,能輕松實現Python機器學習工作負載的產品化和規模化部署。
它解決了分布式機器學習中的三大核心挑戰:
- 突破計算限制:可遠程訪問近乎無限的計算資源
- 具備容錯能力:自動將失敗任務重新路由到集群中的其他機器
- 狀態管理:支持任務間的數據共享和跨數據的協同
在深入探討之前,我們先了解為何需要Ray。
二、硬件危機
隨著大語言模型(LLM)和生成式AI的爆發式增長,計算資源的供需缺口日益擴大。
請看下方圖表:

Ray:現代 AI 技術棧的統一分布式框架
我們可以看到,機器學習系統的訓練計算需求每18個月增長10倍。
最先進(SOTA)模型的訓練需求與單核性能之間存在巨大差距。盡管專業硬件的性能已有顯著提升,但仍無法滿足計算需求,且這一缺口還將呈指數級擴大。
即便模型規模停止增長,專業硬件也需要數十年時間才能追趕上來。
當前最佳解決方案是對AI工作負載進行分布式處理。
同時這也帶來了新的挑戰。
三、AI應用的挑戰
構建AI應用需要開發者整合多個環節的工作負載,包括數據采集、預處理、訓練、微調、預測和部署。
這一過程極具挑戰性,因為每個環節都需要不同的系統,且每個系統都有其專屬的API、語義和約束條件。

借助Ray,只需一個系統即可支持所有這些工作負載。

Ray 庫套件——機器學習工作負載的統一工具包
根據官方文檔:
Ray的五個原生庫分別針對特定的機器學習任務提供分布式支持:
- Data:跨訓練、調優和預測環節,提供可擴展、框架無關的數據加載和轉換功能
- Train:支持多節點、多核的分布式模型訓練,具備容錯能力,可與主流訓練庫集成
- Tune:可擴展的超參數調優工具,用于優化模型性能
- Serve:可擴展、可編程的部署工具,支持在線推理模型部署,可選微批處理以提升性能
- RLlib:支持可擴展的分布式強化學習工作負載
四、企業如何使用Ray

OpenAI使用Ray協調ChatGPT的訓練工作。
Cohere結合PyTorch、JAX和TPU,通過Ray實現了大規模大語言模型訓練。
下圖展示了Alpa如何利用Ray為分布式訓練調度GPU資源:

Ray如何解決生成式 AI 與大語言模型的基礎設施挑戰
Ray解決了生成式模型分布式訓練中兩個最常見的挑戰:
- 如何在多個加速器之間有效劃分模型?
- 如何在搶占式實例上設置具備容錯能力的訓練流程?
Shopify、Spotify、Pinterest和Roblox等企業均借助Ray擴展其機器學習基礎設施:
- Shopify在其Merlin平臺中使用Ray,簡化了從原型設計到生產部署的機器學習工作流,利用Ray Train和Tune實現分布式訓練和超參數調優
- Spotify采用Ray進行并行模型訓練和調優,以優化其推薦系統
- Pinterest利用Ray實現高效的數據處理和可擴展的基礎設施管理
- 在Roblox,Ray支持混合云環境下的大規模AI推理,助力構建穩健、可擴展的機器學習解決方案
五、Ray的核心特性
極簡API
Ray的核心API僅包含6個調用:
ray.init() # 初始化
@ray.remote # 遠程函數裝飾器
def big_function():
...
futures = slow_function.remote() # 調用遠程函數
ray.get(futures) # 獲取返回對象
ray.put() # 將對象存儲到對象存儲中
ray.wait() # 獲取已就緒的對象
ray.shutdown() # 關閉連接以下是將Python Counter類轉換為異步函數的示例:

六、使用Ray
了解了Ray的功能后,來實際嘗試一下。
假設我們有一個??is_prime??函數,用于計算小于n的所有質數之和:
def is_prime(n):
if n < 2:
return False
for i in range(2, int(math.sqrt(n)) + 1):
if n % i == 0:
return False
return True
def sum_primes(limit):
return sum(num for num in range(2, limit) if is_prime(num))測試普通Python計算1000萬以內質數和的性能,并重復計算8次:
%%time
# 串行執行
n_calculations = 8
limit = 10_000_000
sequential_results = [sum_primes(limit) for _ in range(n_calculations)]
# CPU耗時:用戶態12分31秒,系統態7.56秒,總計12分39秒
# 墻鐘時間:12分58秒總計耗時13分鐘!
再看看Ray能帶來多大的速度提升:

%%time
# 并行執行
futures = [sum_primes.remote(limit) for _ in range(n_calculations)]
parallel_results = ray.get(futures)
# CPU耗時:用戶態477毫秒,系統態366毫秒,總計843毫秒
# 墻鐘時間:4分2秒速度提升了3倍!
點擊儀表盤鏈接,還可以查看任務進度:

Ray儀表盤的任務進度展示圖
本文轉載自??AI科技論談???,作者:AI科技論談

















