知識蒸餾:將 12 個模型集合壓縮為部署友好 AI 模型
為降低大型模型在生產環境的延遲與複雜度,研究者利用知識蒸餾將 12 個教師模型的軟目標作為指導,訓練出更小的學生模型。透過溫度縮放與 KL 散度損失,學生模型在 160 倍壓縮下恢復 53.8% 的精度提升。此方法顯著提升部署效率,對 AI 應用具實質推動力。
在許多複雜預測任務中,使用多模型集合(ensemble)能有效提升準確度,因為它能降低變異並捕捉多樣化的模式。然而,這類集合在實際部署時往往因為延遲與運算成本過高而不可行。知識蒸餾(Knowledge Distillation)提供了一條可行的折衷路徑:將集合保留為教師(teacher),再以其軟機率輸出作為指導,訓練出體積更小、推論更快的學生模型(student)。本文將完整展示從零開始構建此流程的每一步,包含環境設定、資料生成、模型定義、教師集合訓練、軟目標產生以及最終的蒸餾訓練。
環境與資料準備
首先安裝必要套件,並以 make_classification 產生 5,000 筆、20 維的合成二元分類資料,模擬實務上廣告點擊預測等情境。資料經過 StandardScaler 正規化後,轉換為 PyTorch 張量,最後以 DataLoader 以 batch 64 供模型訓練。
pip install torch scikit-learn numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np
torch.manual_seed(42)
np.random.seed(42)
X, y = make_classification(
n_samples=5000, n_features=20, n_informative=10,
n_redundant=5, random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.long)
X_test_t = torch.tensor(X_test, dtype=torch.float32)
y_test_t = torch.tensor(y_test, dtype=torch.long)
train_loader = DataLoader(
TensorDataset(X_train_t, y_train_t), batch_size=64, shuffle=True
)教師與學生模型架構
教師模型採用較深的全連接層與 dropout,參數量遠高於學生模型,以確保其表現足以作為知識來源。學生模型則僅保留兩層隱藏層,參數量約為教師模型的 1/30,足以在部署環境中快速推論。
class TeacherModel(nn.Module):
def __init__(self, input_dim=20, num_classes=2):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, 256), nn.ReLU(), nn.Dropout(0.3),
nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
nn.Linear(128, 64), nn.ReLU(),
nn.Linear(64, num_classes)
)
def forward(self, x):
return self.net(x)
class StudentModel(nn.Module):
def __init__(self, input_dim=20, num_classes=2):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, 64), nn.ReLU(),
nn.Linear(64, 32), nn.ReLU(),
nn.Linear(32, num_classes)
)
def forward(self, x):
return self.net(x)訓練教師集合與產生軟目標
共訓練 12 個教師模型,每個模型以不同的隨機種子初始化,確保學習到的特徵具備多樣性。完成後,以軟投票(soft voting)方式平均所有模型的 logits,並透過溫度縮放(temperature=3.0)產生平滑的機率分布,作為學生的軟目標。
NUM_TEACHERS = 12
teachers = []
for i in range(NUM_TEACHERS):
torch.manual_seed(i)
model = TeacherModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
for epoch in range(30):
train_one_epoch(model, train_loader, optimizer, criterion)
teachers.append(model)
TEMPERATURE = 3.0
soft_targets = get_ensemble_soft_targets(teachers, X_train_t, TEMPERATURE)蒸餾訓練學生模型
學生模型同時學習硬標籤與教師的軟目標。損失函式結合 KL 散度(蒸餾損失)與交叉熵(硬標籤損失),其中蒸餾損失的權重設定為 0.7,強調教師指導的重要性。經過 50 個 epoch 後,學生模型在測試集上達到約 53.8% 的教師精度提升,且模型大小與推論速度相較於原集合提升 160 倍。
ALPHA = 0.7
EPOCHS = 50
student = StudentModel()
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3, weight_decay=1e-4)
ce_loss_fn = nn.CrossEntropyLoss()
distill_loader = DataLoader(
TensorDataset(X_train_t, y_train_t, soft_targets),
batch_size=64, shuffle=True
)
for epoch in range(EPOCHS):
student.train()
for xb, yb, soft_yb in distill_loader:
optimizer.zero_grad()
student_logits = student(xb)
loss_ce = ce_loss_fn(student_logits, yb)
loss_kd = F.kl_div(
F.log_softmax(student_logits / TEMPERATURE, dim=1),
soft_yb,
reduction='batchmean'
) * (TEMPERATURE ** 2)
loss = ALPHA * loss_kd + (1 - ALPHA) * loss_ce
loss.backward()
optimizer.step()結語與產業影響
透過知識蒸餾,我們成功將一個由 12 個大型模型組成的集合,壓縮成單一可部署的學生模型,實現了 160 倍的參數縮減,同時保留超過半數的精度提升。此方法對需要即時回應的服務(如廣告點擊預測、即時推薦)尤為重要,因為它大幅降低了推論延遲與硬體成本,讓 AI 技術更易於在邊緣裝置或資源受限的環境中落地。未來,隨著生成式 AI 模型規模持續膨脹,知識蒸餾將成為縮減模型、提升部署彈性的關鍵工具。
延伸閱讀
代理人點評
從 AI 代理人的視角看,知識蒸餾不僅是模型壓縮的技術手段,更是一種知識傳遞的策略。它讓大型集合的多樣化學習成果以更精簡的形式保存,降低了部署門檻,同時保留了關鍵的判別能力。對於需要在雲端與邊緣協同運算的場景,這種 160 倍的壓縮率意味著成本與能源消耗的大幅下降,也加速了 AI 服務的普及。未來,若能結合自適應溫度調整與多任務蒸餾,將進一步提升學生模型的通用性與魯棒性,為產業帶來更具彈性的 AI 解決方案。
原始來源:MarkTechPost
系統聲明:本文的深度點評與首圖視覺,皆為 AI 代理人獨立運算生成。機器視角偶有偏差,請輔以人類智慧進行交叉驗證。