13 - 梯度檢查係乜?點解要做?
諗下你用 LEGO 砌十層高樓,每層都交畀自己寫程式嘅機械人去擺積木。你唔肯定程式會唔會擺多粒、調轉方向,砌到第六層先發現歪咗就 GG。梯度檢查 (Gradient Checking) 就係請個「監工」:逐粒積木量度,確保反向傳播嘅微分公式寫得啱,避免訓練途中樓歪晒。
🔬 梯度檢查原理=數值微分做驗算
揀一個參數 θi\theta_iθi。
前向傳播算 Loss:
J+=J(θi+ε)J^{+} = J(\theta_i + \varepsilon)J+=J(θi+ε)
J−=J(θi−ε)J^{-} = J(\theta_i - \varepsilon)J−=J(θi−ε)
數值梯度:
∂J∂θi≈J+−J−2ε\frac{\partial J}{\partial \theta_i} \approx \frac{J^{+}-J^{-}}{2\varepsilon}∂θi∂J≈2εJ+−J−
同你寫嘅後向梯度 gbackpropg_{\text{backprop}}gbackprop 比對:
相差 < 10−710^{-7}10−7:OK
相差大:公式或維度有蟲
因為 每個參數都要算兩次 Loss,上萬參數就要跑幾萬次 forward,所以淨係 除錯階段 用,訓練時千祈唔好開。
🌍 真實踩過界例子
項目 | Bug | 梯度檢查點救命 |
GAN 生成漫畫 | 自己手寫 tanh 導數漏咗 1- | 數值 vs. 手算差 0.2,立刻對準公式 |
強化學習 DQN | Target network Loss 忘記 stop-grad | 檢查到最後層梯度巨大,定位到 target 分支 |
圖神經網絡 GAT | Attention softmax backward 抄錯維度 | 有一條邊梯度爆 NaN,逐層比對揪出 |
金融時序 Transformer | 加 L2 正則化但 Loss 冇加該項 | 數值梯度全細 1e-3,提醒忘嘢 |
醫療影像 3D CNN | 訓練用 Dropout=0.2,檢查時忘關 | 前向每次 Loss 漂浮,差異無法收斂 → 關掉即穩定 |
⚠️ 四大常見陷阱同拆招
1. 數值誤差大到嚇親
逐層打印差異,先鎖定「漏水位」。
檢查轉置 W.T、broadcast 次序、reshape。
2. 忘記將正則化寫入 Loss
用 L2、Label-smoothing、Weight decay 都要加返同一項去 J^+ / J^-。
3. Dropout / BatchNorm 未鎖死
臨時把 keep_prob=1、training=False,確保 determinism。
4. 初期 ok,之後炸鍋
先訓練 100–500 step 令權重變大,再跑梯度檢查,較易現形。
🛠️ 推薦實戰流程(PyTorch 示意)
eps = 1e-7
model.eval() # 關 BatchNorm / Dropout
for p in model.parameters():
p.requires_grad = True
# 隨機挑 20 個參數做檢查
for name, param in list(model.named_parameters())[:20]:
# 手算梯度
model.zero_grad()
loss = criterion(model(x), y)
loss.backward()
grad_back = param.grad.clone()
# 數值梯度
perturb = torch.zeros_like(param)
idx = torch.randint(0, param.nelement(), (1,))
perturb.view(-1)[idx] = eps
loss_plus = criterion(model(x, extra_param=param+perturb), y)
loss_minus = criterion(model(x, extra_param=param-perturb), y)
grad_num = (loss_plus - loss_minus) / (2*eps)
diff = torch.norm(grad_back.view(-1)[idx] - grad_num) / \
(torch.norm(grad_back.view(-1)[idx]) + torch.norm(grad_num))
print(f"{name}[{idx.item()}] relative diff: {diff:.2e}")
🚗 一條龍「監工」Checklist
關掉隨機因素:Dropout、data shuffle 固定 seed。
將正則化全部納入 Loss。
用細 batch、細 network,快啲跑完。
相對誤差 < 10−710^{-7}10−7 即合格,> 10−410^{-4}10−4 基本有蟲。
揪出 Bug 修正,再用真網絡/大 batch 開 full training。