【TVM 教程】向 Relay 中添加 Compiler Pass 原創
Apache TVM 是一個深度的深度學習編譯框架,適用于 CPU、GPU 和各種機器學習加速芯片。更多 TVM 中文文檔可訪問 →https://tvm.hyper.ai/
Compiler Pass 是擴展 Relay 功能集及優化 Relay 程序的主要接口。通過編寫 compiler pass,用戶可以基于最終目標,修改 AST 或收集 AST 相關信息。事實上,Relay 內置的一些重要特性(如自動微分和類型推斷)都“標準”的 compiler pass。
整體來看,編寫 pass 包括兩個關鍵組成部分:
- 創建一個或多個遍歷程序的 C++ 類
- 將遍歷實現及其在 pass manager API 中的元數據包裝,從而方便與?Pass Infrastructure?輕松交互
首先,我們將概述編寫 compiler pass 的關鍵機制。然后通過 Relay 中常量折疊 pass 的具體示例進行演示。
AST 遍歷器(Traversers)
用于遍歷 Relay 程序的基類是?ExprFunctor。它提供的公共接口是一個?VisitExpr?方法,該方法接收一個表達式以及零個或多個參數,并返回某種類型的實例。擴展此類時,可以通過覆蓋每種表達式類型的?VisitExpr_?實現,來定義 AST 遍歷模式。
VisitExpr?和?VisitExpr_?之間的關系與調度有關。每個?VisitExpr_?定義都針對特定類型的表達式,但用戶無法每次都得知要訪問的節點類型。為了解決這個問題,ExprFunctor?提供了一個?VisitExpr?函數,將給定表達式路由轉換為?VisitExpr_?實例進而解決問題。盡管 C++ 已經提供了動態調度,但?ExprFunctor?定義了自己的虛表供?VisitExp?使用。通過定義虛表可以更好地控制調度。例如,定義一個在每次訪問之前都打印 “Here” 的?PrintVisitor?遍歷器,可以覆蓋?VisitExpr:
void PrintVisitor::VisitExpr(const Expr& expr) {
std::cout << "Here" << std::endl;
ExprFunctor::VisitExpr(expr);
}
ExprFunctor?本身是一個非常通用的類,這就是為什么更多時候你會擴展?ExprVisitor?或?ExprMutator。這些類擴展了?ExprFunctor,并提供了?VisitExpr_?的默認實現,這些實現捕獲了每種表達式類型的常見遍歷模式。有了這些默認的實現,開發者只需針對想要不同行為的表達式類型,提供覆蓋的實現。后續章節將針對每個子類進行詳細描述。
表達式訪問器(Expression Visitors)
ExprVisitor?不用于修改程序的pass,而是用于實施程序分析和收集信息的 pass。使用這個類,VisitExpr?和私有 counterparts 不會返回任何內容。此類提供的?VisitExpr_?實現只是訪問表達式的所有表達式字段。?IfNode?的默認實現如下所示:
void ExprVisitor::VisitExpr_(const IfNode* op) {
this->VisitExpr(op->cond);
this->VisitExpr(op->true_branch);
this->VisitExpr(op->false_branch);
}
注意,這里調用的是?VisitExpr?而非?VisitExpr_,因此用戶可以使用?ExprFunctor?中的虛表進行路由。
如果要編寫一個?CallChecker?類來檢查程序中是否出現函數調用,只需擴展?ExprVisitor?并定義以下?VisitExpr_?方法:
void VisitExpr_(const CallNode* n) final {
result_ = true;
}
其中?result_?是一個字段。在該示例中,無需在?CallNode?字段上進一步遞歸,因為?result_?已經為 true,原始表達式中包含一個調用。為了使該訪問器可用,可以采用以下方法:
bool Check(const Expr& expr) final {
result_ = false;
VisitExpr(expr);
return result_;
}
以上就是全部操作。在調用 top-level 的遞歸之前,定義一個執行一些記錄的公有接口是很常見的操作。用戶也可以通過創建一個生成?CallChecker?實例,并在其上調用?Check?的獨立程序來進一步包裝 API,重要的是用盡可能少的資源用實現目標。
表達式修改器(Expression Mutators)
ExprMutator?用于以某種方式轉換程序的 pass。通過這個類,VisitExpr?及其對應的私有部分返回?Expr。此類提供的默認?VisitExpr_?實現訪問表達式的所有表達式字段,并將字段設置為訪問它們的結果。TupleGetItemNode?的默認實現如下所示:
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
auto t = this->Mutate(g->tuple);
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItem(t, g->index);
}
}
這里有幾點需要注意。首先,Mutate?是?ExprMutator?中?VisitExpr?的別名。其次,如果對?Mutate?的調用修改了?tuple?字段,則只返回一個新節點。這種更新的方法稱為功能更新,這樣做可以避免不必要的分配。
ExprMutator?有、而?ExprVisitor?沒有的一個功能,是用于緩存結果的內置?memo_?字段。ExprMutator?有一個記憶器(memoizer)這是合理的,因為用戶知道正在緩存哪些類型的結果(即?Expr),而?ExprVisitor?的訪問方法不返回任何內容。通常,當用戶要在?ExprVisitor?的子類中緩存結果時,需要自行定義緩存。
如果希望編寫一個?IfCollapser?類,用它的真實分支替換每個 if 語句,用戶將為?IfNode?覆蓋?VisitExpr_:
Expr ExprMutator::VisitExpr_(const IfNode* op) {
return this->Mutate(op->true_branch);
}
注意:返回的表達式不一定是?IfNode,這是正常的,因為返回類型是?Expr。接下來創建一個公有接口:
Expr CollapseIfs(const Expr& expr) final {
return this->Mutate(expr);
}
雖然使用這個修改器無需做任何記錄,但仍然鼓勵用戶將描述性方法作為接口。
示例:常量折疊
為了更好地理解編寫 pass 的過程,本部分將以常量折疊 pass(可在?src/relay/transforms/fold_constant.cc?中找到)作為示例進行講解。常量折疊 pass 相對簡單,且包含兩種類型的遍歷。
常量折疊涉及只包含常量的程序評估表達式(evaluating expression),然后用評估它們的結果替換這些表達式。此過程的目的是預加載可以進行的所有計算。為了實現這一點,常量折疊 pass 使用了一個訪問器(ConstantChecker)和一個修改器(ConstantFolder)。
ConstantChecker?訪問器
此訪問器用于檢查表達式是否為常量。在 Relay 中,用戶將?ConstantNode?或者只有常量字段的?TupleNode?的表達式定義為常量。
使用?memo_?字段從節點映射到它們是否為常量,并緩存這些結果。下面是?ConstantChecker?中的?VisitExpr_?定義。
void VisitExpr_(const ConstantNode* n) final {
memo_[GetRef<Constant>(n)] = true;
}
void VisitExpr_(const TupleNode* n) final {
bool result = true;
for (const auto& field : n->fields) {
if (!Check(field)) {
result = false;
break;
}
}
memo_[GetRef<Tuple>(n)] = result;
}
用于協調這些定義的記錄是一個?Check?方法,它返回給定的表達式是否被認定為常量。
bool Check(const Expr& expr) {
const auto it = memo_.find(expr);
if (it != memo_.end())
return it->second;
VisitExpr(expr);
return memo_[expr];
}
并不是所有遇到的節點都要修改?memo_;相反,用戶只有在遇到的節點有可能是常數時,才修改?memo_。當?memo_?不包含?expr?時,需要依賴默認的 false 值。
ConstantFolder?修改器
這個修改器執行了大部分的常量折疊過程,并在內部使用?ConstantChecker。在 Relay 中,常量折疊涉及三種節點類型:LetNode、TupleItemGetNode?和?CallNode。后續段落中將進行詳細講解。
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
memo_[op->var] = value;
return this->Mutate(op->body);
} else {
Var var = Downcast<Var>(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
if (var.same_as(op->var) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return Let(var, value, body);
}
}
}
在?LetNode?示例里,首先嘗試常量折疊綁定在表達式的值。如果可以,填充?memo_?并返回訪問主體的結果——本質上是將綁定的值傳到主體中的使用點。如果無法常量折疊綁定的值,可以參照默認的實現方法:
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr res = ExprMutator::VisitExpr_(op);
op = res.as<TupleGetItemNode>();
if (const auto* tuple = op->tuple.as<TupleNode>()) {
return tuple->fields[op->index];
} else {
return res;
}
}
在?TupleItemGetNode?的例子里,需要檢查?op->tuple?字段是否為?TupleNode。如果是,我們將 get 元組替換為?op->index?指向的元組的字段。這樣做的原因是因為?op->tuple?可能被錯誤評估為一個元組。
Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
// 我們不使用零參數的常量折疊函數。
// 這是一個很有用的啟發式方法。
// 例如折疊那些 shape=(4, 5) 是有害的。
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
// 跳過有狀態的算子。
if (op_stateful.get(GetRef<Op>(op), false)) return res;
bool all_const_args = true;
for (Expr arg : call->args) {
if (!checker_.Check(arg)) {
all_const_args = false;
}
}
if (all_const_args) {
return ConstEvaluate(res);
} else {
return res;
}
}
在?CallNode?示例中,首先使用?ExprMutator?的?VisitExpr_?來訪問調用,它將調用的所有字段都常量折疊了。之所以使用?ExprMutator::VisitExpr_?而不是?VisitExpr,是因為我們想要繞過虛表(以避免死循環)并使用?ExprMutator?提供的默認實現。只有當所有參數都是常量時,才評估調用(使用?ConstantChecker)。評估調用會產生一個值,因此這里使用輔助方法?ValueToExpr?,將評估的表達式放回 AST 中。
現在,我們為常量文件夾構造了一個更方便的接口?FoldConstant。FoldConstant?是?ConstantFolder?類之外的一個獨立函數,它負責接收表達式并在內部創建和使用?ConstantFolder?實例(其完整的定義在?src/relay/transforms/fold_constant.cc?中)。
用 Pass Manager 注冊 Pass
*注意:更多詳情請參閱?Pass Infrastructure?中的文檔。
編寫 AST 遍歷器后,用以下代碼可將 pass 注冊為 TVM API 端點:
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
} // 命名空間轉換
將上述代碼生成的?Pass?對象提供給 pass 基礎架構,可以使得 AST 遍歷應用于給定 Relay 模塊中的所有函數,這是常量折疊過程預期的行為(它應該盡可能折疊所有常量)。
函數?CreateFunctionPass?允許注冊 pass 的優化級別(在本例中為 2),可用于根據 pass 的一般實用性、 pass 名稱和 pass 中的任何依賴項將 pass 組合在一起。pass 的依賴項以列表形式給出,羅列了當前 pass 運行所必需的所有 pass 的結果。FoldConstant?沒有任何依賴,但是很多 Relay pass 確實依賴有類型信息,所以?InferType?是一個常見的依賴;其他的可能依賴于程序為 A-范式,通過?ToANormalForm?pass。
注意,PassContext?對象包含 pass 用于錯誤報告和配置選項的信息;?FoldConstant?不需要此信息,但其他 pass 可能會引用它們的?PassContext?對象。
現在可以通過 pass 基礎結構調用 pass 了,推薦為 pass 添加 Python 綁定,如以下代碼片段所示:
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);
通過以上方法定義了?Pass?對象后,就可以用 pass 基礎架構的?Sequential?結構來調用了。?Sequential?接收一個 pass 列表,并將其按順序應用于 Relay 模塊,從而獲得轉換后的模塊。例如,下面的代碼將?FoldConstant?和?ToANormalForm?pass 逐一應用于?mod?中的每個函數,并獲得一個新模塊。
seq = transform.Sequential([
relay.transform.FoldConstant(),
relay.transform.ToANormalForm()
])
new_mod = seq(mod)
更多注冊相關的內容,請查看?TVM Runtime 系統;pass 管理器接口相關的更多信息,請查看?Pass 基礎架構; Relay 的標準 pass 列表及實現方式,請分別查看?include/tvm/relay/transform.h?及?src/relay/transforms/。

















