top of page

13 - 梯度檢查係乜?點解要做?


諗下你用 LEGO 砌十層高樓,每層都交畀自己寫程式嘅機械人去擺積木。你唔肯定程式會唔會擺多粒、調轉方向,砌到第六層先發現歪咗就 GG。梯度檢查 (Gradient Checking) 就係請個「監工」:逐粒積木量度,確保反向傳播嘅微分公式寫得啱,避免訓練途中樓歪晒。


🔬 梯度檢查原理=數值微分做驗算


  1. 揀一個參數 θi\theta_iθi​。

  2. 前向傳播算 Loss:

    • J+=J(θi+ε)J^{+} = J(\theta_i + \varepsilon)J+=J(θi​+ε)

    • J−=J(θi−ε)J^{-} = J(\theta_i - \varepsilon)J−=J(θi​−ε)

  3. 數值梯度:

    ∂J∂θi≈J+−J−2ε\frac{\partial J}{\partial \theta_i} \approx \frac{J^{+}-J^{-}}{2\varepsilon}∂θi​∂J​≈2εJ+−J−​

  4. 同你寫嘅後向梯度 gbackpropg_{\text{backprop}}gbackprop​ 比對:

    • 相差 < 10−710^{-7}10−7:OK

    • 相差大:公式或維度有蟲

因為 每個參數都要算兩次 Loss,上萬參數就要跑幾萬次 forward,所以淨係 除錯階段 用,訓練時千祈唔好開。


🌍 真實踩過界例子


項目

Bug

梯度檢查點救命

GAN 生成漫畫

自己手寫 tanh 導數漏咗 1-

數值 vs. 手算差 0.2,立刻對準公式

強化學習 DQN

Target network Loss 忘記 stop-grad

檢查到最後層梯度巨大,定位到 target 分支

圖神經網絡 GAT

Attention softmax backward 抄錯維度

有一條邊梯度爆 NaN,逐層比對揪出

金融時序 Transformer

加 L2 正則化但 Loss 冇加該項

數值梯度全細 1e-3,提醒忘嘢

醫療影像 3D CNN

訓練用 Dropout=0.2,檢查時忘關

前向每次 Loss 漂浮,差異無法收斂 → 關掉即穩定


⚠️ 四大常見陷阱同拆招


1. 數值誤差大到嚇親

  • 逐層打印差異,先鎖定「漏水位」。

  • 檢查轉置 W.T、broadcast 次序、reshape。

2. 忘記將正則化寫入 Loss

  • 用 L2、Label-smoothing、Weight decay 都要加返同一項去 J^+ / J^-。

3. Dropout / BatchNorm 未鎖死

  • 臨時把 keep_prob=1、training=False,確保 determinism。

4. 初期 ok,之後炸鍋

  • 先訓練 100–500 step 令權重變大,再跑梯度檢查,較易現形。


🛠️ 推薦實戰流程(PyTorch 示意)

eps = 1e-7
model.eval()           # 關 BatchNorm / Dropout
for p in model.parameters():
    p.requires_grad = True

# 隨機挑 20 個參數做檢查
for name, param in list(model.named_parameters())[:20]:
    # 手算梯度
    model.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    grad_back = param.grad.clone()

    # 數值梯度
    perturb = torch.zeros_like(param)
    idx = torch.randint(0, param.nelement(), (1,))
    perturb.view(-1)[idx] = eps
    loss_plus  = criterion(model(x, extra_param=param+perturb), y)
    loss_minus = criterion(model(x, extra_param=param-perturb), y)
    grad_num = (loss_plus - loss_minus) / (2*eps)

    diff = torch.norm(grad_back.view(-1)[idx] - grad_num) / \
           (torch.norm(grad_back.view(-1)[idx]) + torch.norm(grad_num))
    print(f"{name}[{idx.item()}] relative diff: {diff:.2e}")

🚗 一條龍「監工」Checklist


  1. 關掉隨機因素:Dropout、data shuffle 固定 seed。

  2. 將正則化全部納入 Loss。

  3. 用細 batch、細 network,快啲跑完。

  4. 相對誤差 < 10−710^{-7}10−7 即合格,> 10−410^{-4}10−4 基本有蟲。

  5. 揪出 Bug 修正,再用真網絡/大 batch 開 full training。


✅ Take-away


  • 梯度檢查唔係訓練招式,係 Debug 專用放大鏡。

  • 一定要 deterministic:Dropout off、BatchNorm eval。

  • 差異大 → 先 lock 目標層,再逐行 check 公式同維度。

  • 訓練前、中、後各跑一次,可捉早期 vs. 後期隱藏錯誤。

做足監工,先至可以放心叫機械人一路砌到 100 層,唔怕中途樓塌!🚀

翱翔醫療 (2).png

Tsim Sha Tsui H Zentre Clinic

Suite 813, 8/F, H Zentre

15 Middle Road, TST

Phone: 28133700

​Whatsapp:+852 95096276

Central Printing House Clinic

Room 303A & 305,

3/F, Printing House,

6 Duddell Street, Central

Phone: 28716733 / 28716788

Whatsapp:+852 62084539

TKO Maritime Bay Clinic

UG18, UG/F,

Maritime Bay Shopping Centre
Hang Hau, Tseung Kwan O
Tel: 98852916; Whatsapp: 98852916

​Phone:98852916

Whatsapp:+852 98852916

Mong Kok T.O.P. Clinic

Room 2001, 20/F,
700 Nathan Road, Mong Kok

​(Going above from the the 3/F elevator of T.O.P. Mall)

Phone:28710277

Whatsapp:+852 98893911

bottom of page