模型蒸餾:“學(xué)神”老師教出“學(xué)霸”學(xué)生 原創(chuàng) 精華
編者按: 近日,Qwen 3 技術(shù)報(bào)告正式發(fā)布,該系列也采用了從大參數(shù)模型中蒸餾知識(shí)來(lái)訓(xùn)練小參數(shù)模型的技術(shù)路線。那么,模型蒸餾技術(shù)究竟是怎么一回事呢?
今天給大家分享的這篇文章深入淺出地介紹了模型蒸餾的核心原理,即通過(guò)讓學(xué)生模型學(xué)習(xí)教師模型的軟標(biāo)簽而非硬標(biāo)簽,從而傳遞更豐富的知識(shí)信息。作者還提供了一個(gè)基于 TensorFlow 和 MNIST 數(shù)據(jù)集的完整實(shí)踐案例,展示了如何構(gòu)建教師模型和學(xué)生模型,如何定義蒸餾損失函數(shù),以及如何通過(guò)知識(shí)蒸餾方法訓(xùn)練學(xué)生模型。實(shí)驗(yàn)結(jié)果表明,參數(shù)量更少的學(xué)生模型能夠達(dá)到與教師模型相媲美的準(zhǔn)確率。
作者 | Wei-Meng Lee
編譯 | 岳揚(yáng)

Photo by 戸山 神奈 on Unsplash
如果你一直在關(guān)注 DeepSeek 的最新動(dòng)態(tài),可能聽(tīng)說(shuō)過(guò)“模型蒸餾”這個(gè)概念。但究竟什么是模型蒸餾?它為何重要?本文將解析模型蒸餾原理,并通過(guò)一個(gè) TensorFlow 示例進(jìn)行演示。通過(guò)閱讀這篇技術(shù)指南,我相信您將對(duì)模型蒸餾有更深刻的理解。
01 模型蒸餾技術(shù)原理
模型蒸餾通過(guò)讓較小的、較簡(jiǎn)單的模型(學(xué)生模型)學(xué)習(xí)模仿較大的、較復(fù)雜的模型(教師模型)的軟標(biāo)簽(而非原始標(biāo)簽),使學(xué)生模型能以更精簡(jiǎn)的架構(gòu)繼承教師模型的知識(shí),用更少參數(shù)實(shí)現(xiàn)相近性能。以圖像分類任務(wù)為例,學(xué)生模型不僅學(xué)習(xí)“某張圖片是狗還是貓”的硬標(biāo)簽,還會(huì)學(xué)習(xí)教師模型輸出的軟標(biāo)簽(如80%狗,15%貓,5%狐貍),從而掌握更細(xì)粒度的知識(shí)。 這一過(guò)程能在保持高準(zhǔn)確率的同時(shí)大大降低模型體積和計(jì)算資源需求。
下文我們將以使用 MNIST 數(shù)據(jù)集訓(xùn)練卷積神經(jīng)網(wǎng)絡(luò)(CNN)為例進(jìn)行演示。
MNIST 數(shù)據(jù)集(Modified National Institute of Standards and Technology)是機(jī)器學(xué)習(xí)和計(jì)算機(jī)視覺(jué)領(lǐng)域廣泛使用的基準(zhǔn)數(shù)據(jù)集,包含 70,000 張 28x28 像素的手寫數(shù)字(0-9)灰度圖像,其中 60,000 張訓(xùn)練圖像和 10,000 張測(cè)試圖像。
首先構(gòu)建教師模型:

Image by author
教師模型是基于 MNIST 訓(xùn)練的 CNN 網(wǎng)絡(luò)。
同時(shí)構(gòu)建更輕量的學(xué)生模型:

Image by author
模型蒸餾的目標(biāo)是通過(guò)更少的計(jì)算量和訓(xùn)練時(shí)間訓(xùn)練一個(gè)較小的學(xué)生模型,復(fù)現(xiàn)教師模型的性能表現(xiàn)。
接下來(lái),教師模型和學(xué)生模型同時(shí)對(duì)數(shù)據(jù)集進(jìn)行預(yù)測(cè),然后計(jì)算二者輸出的 Kullback-Leibler (KL) 散度(將于后文進(jìn)行詳述)。該數(shù)值(KL 散度)用于計(jì)算梯度,指導(dǎo)模型各層參數(shù)應(yīng)該如何調(diào)整,從而指導(dǎo)學(xué)生模型的參數(shù)更新:

Image by author
訓(xùn)練完成后,學(xué)生模型達(dá)到與教師模型相當(dāng)?shù)臏?zhǔn)確率:

Image by author
02 創(chuàng)建一個(gè)用于模型蒸餾的示例項(xiàng)目
現(xiàn)在,我們對(duì)模型蒸餾的工作原理已經(jīng)有了更清晰的理解,是時(shí)候通過(guò)一個(gè)簡(jiǎn)單的示例來(lái)了解如何實(shí)現(xiàn)模型蒸餾了。我將使用 TensorFlow 和 MNIST 數(shù)據(jù)集訓(xùn)練教師模型,然后應(yīng)用模型蒸餾技術(shù)訓(xùn)練一個(gè)較小的學(xué)生模型,使其在保持教師模型性能的同時(shí)降低資源需求。
2.1 使用 MNIST 數(shù)據(jù)集
確保已安裝 TensorFlow:

下一步加載 MNIST 數(shù)據(jù)集:

以下是從 MNIST 數(shù)據(jù)集中選取的前 9 個(gè)樣本圖像及其標(biāo)簽:

需要對(duì)圖像數(shù)據(jù)進(jìn)行歸一化處理,并擴(kuò)展圖像數(shù)據(jù)的維度,為訓(xùn)練做好準(zhǔn)備:

2.2 定義教師模型
現(xiàn)在我們來(lái)定義教師模型 —— 一個(gè)具有多個(gè)網(wǎng)絡(luò)層的 CNN(卷積神經(jīng)網(wǎng)絡(luò)):

請(qǐng)注意,學(xué)生模型的最后一層有 10 個(gè)神經(jīng)元(對(duì)應(yīng) 10 個(gè)數(shù)字類別),但未使用 softmax 激活函數(shù)。該層直接輸出原始 logits 值,這在模型蒸餾過(guò)程中非常重要,因?yàn)樵谀P驼麴s階段會(huì)應(yīng)用 softmax 計(jì)算教師模型與學(xué)生模型之間的 Kullback-Leibler(KL)散度。
定義完教師神經(jīng)網(wǎng)絡(luò)后,需通過(guò) compile() 方法配置優(yōu)化器(optimizer)、損失函數(shù)(loss function)和評(píng)估指標(biāo)(metric for evaluation):

現(xiàn)在可以使用 fit() 方法訓(xùn)練模型:

本次訓(xùn)練進(jìn)行了 5 個(gè)訓(xùn)練周期:

2.3 定義學(xué)生模型
在教師模型訓(xùn)練完成后,接下來(lái)定義學(xué)生模型。與教師模型相比,學(xué)生模型的結(jié)構(gòu)更簡(jiǎn)單、層數(shù)更少:

2.4 定義蒸餾損失函數(shù)
接下來(lái)定義蒸餾損失函數(shù),該函數(shù)將利用教師模型的預(yù)測(cè)結(jié)果和學(xué)生模型的預(yù)測(cè)結(jié)果計(jì)算蒸餾損失(distillation loss)。該函數(shù)需完成以下操作:
- 使用教師模型對(duì)當(dāng)前批次的輸入數(shù)據(jù)進(jìn)行推理,生成軟標(biāo)簽「硬標(biāo)簽:[0, 0, 1](直接指定類別3)。軟標(biāo)簽:[0.1, 0.2, 0.7](表示模型認(rèn)為70%概率是類別3,但保留其他可能性)。」;
- 使用學(xué)生模型預(yù)測(cè)計(jì)算其軟標(biāo)簽;
- 計(jì)算教師模型與學(xué)生模型軟標(biāo)簽之間的 Kullback-Leibler(KL)散度;
- 返回蒸餾損失。
軟標(biāo)簽(soft probabilities)指的是包含多種可能結(jié)果的概率分布,而非直接分配一個(gè)硬標(biāo)簽。例如在垃圾郵件分類模型中,模型不會(huì)直接判定郵件"是垃圾郵件(1)"或"非垃圾郵件(0)",而是輸出類似"垃圾郵件概率 0.85,非垃圾郵件概率 0.15"的概率分布。 這意味著模型有 85% 的把握認(rèn)為該郵件是垃圾郵件,但仍認(rèn)為有 15% 的可能性不是,從而可以更好地進(jìn)行決策和閾值調(diào)整。
軟標(biāo)簽使用 softmax 函數(shù)進(jìn)行計(jì)算,并由溫度參數(shù)(temperature)控制分布形態(tài)。在知識(shí)蒸餾過(guò)程中,教師模型提供的軟標(biāo)簽?zāi)軒椭鷮W(xué)生模型學(xué)習(xí)到數(shù)據(jù)集各類別間的隱含關(guān)聯(lián),從而獲得更優(yōu)的泛化能力和性能表現(xiàn)。
以下是 distillation_loss() 函數(shù)的具體定義:

Kullback-Leibler(KL)散度 (又稱相對(duì)熵)是衡量?jī)蓚€(gè)概率分布差異程度的數(shù)學(xué)方法。
2.5 使用知識(shí)蒸餾方法訓(xùn)練學(xué)生模型
現(xiàn)在我們可以通過(guò)知識(shí)蒸餾訓(xùn)練學(xué)生模型了。首先定義 train_step() 函數(shù):

該函數(shù)只執(zhí)行了一個(gè)訓(xùn)練步驟:
- 計(jì)算學(xué)生模型的預(yù)測(cè)結(jié)果
- 利用教師模型的預(yù)測(cè)結(jié)果計(jì)算蒸餾損失
- 計(jì)算梯度并更新學(xué)生模型的權(quán)重
要對(duì)學(xué)生模型進(jìn)行訓(xùn)練,需要?jiǎng)?chuàng)建一個(gè)訓(xùn)練循環(huán)(training loop)來(lái)遍歷數(shù)據(jù)集,每一步都會(huì)更新學(xué)生模型的權(quán)重,并在每個(gè) epoch 結(jié)束時(shí)打印損失值以監(jiān)測(cè)訓(xùn)練進(jìn)度:


2.6 評(píng)估學(xué)生模型
訓(xùn)練完成后,你可以使用測(cè)試集(x_test 和 y_test)評(píng)估學(xué)生模型的表現(xiàn):

不出所料,學(xué)生模型的準(zhǔn)確率相當(dāng)高:

2.7 使用教師模型和學(xué)生模型進(jìn)行預(yù)測(cè)
現(xiàn)在可以使用教師模型和學(xué)生模型對(duì) MNIST 測(cè)試集的數(shù)字進(jìn)行預(yù)測(cè),觀察兩者的預(yù)測(cè)能力:

前兩個(gè)樣本的預(yù)測(cè)結(jié)果如下:

若測(cè)試更多數(shù)字圖像樣本,你會(huì)發(fā)現(xiàn)學(xué)生模型的表現(xiàn)與教師模型同樣出色。
03 Summary
在本文,我們探討了模型蒸餾(Model Distillation)這一概念,這是一種讓結(jié)構(gòu)更簡(jiǎn)單、規(guī)模更小的學(xué)生模型復(fù)現(xiàn)或逼近結(jié)構(gòu)更復(fù)雜的教師模型的性能的技術(shù)。我們利用 MNIST 數(shù)據(jù)集訓(xùn)練教師模型,然后應(yīng)用模型蒸餾技術(shù)訓(xùn)練學(xué)生模型。最終,層數(shù)更少、結(jié)構(gòu)更精簡(jiǎn)的學(xué)生模型成功復(fù)現(xiàn)了教師模型的性能表現(xiàn),同時(shí)還大大降低了計(jì)算資源的需求。
希望這篇文章能夠滿足各位讀者對(duì)模型蒸餾技術(shù)的好奇心,也希望本文提供的示例代碼可以直觀展現(xiàn)該技術(shù)的高效與實(shí)用。
About the author
Wei-Meng Lee
ACLP Certified Trainer | Blockchain, Smart Contract, Data Analytics, Machine Learning, Deep Learning, and all things tech (??http://calendar.learn2develop.net??).
END
本期互動(dòng)內(nèi)容 ??
?除了模型蒸餾,剪枝和量化也是常用的模型壓縮方法。在你們的項(xiàng)目中,更傾向于采用哪些方法? 歡迎在評(píng)論區(qū)分享~
本文經(jīng)原作者授權(quán),由 Baihai IDP 編譯。如需轉(zhuǎn)載譯文,請(qǐng)聯(lián)系獲取授權(quán)。
原文鏈接:
??https://ai.gopubby.com/understanding-model-distillation-991ec90019b6??

















