12 - 乜嘢係「梯度消失」同「梯度爆炸」?
訓練深度神經網絡,有時會見到兩種極端:
梯度消失 (vanishing gradients)
‣ 反向傳播時,梯度細到近 0 → 權重幾乎唔郁 → 成個網絡好似冇學過嘢。
梯度爆炸 (exploding gradients)
‣ 梯度大到癲 → 權重更新過火 → Loss 發散 / 出 NaN → 模型亂飛。
呢兩隻「洪水猛獸」通常喺 層數多、序列長 嘅網絡特別易出現(RNN、Transformer、深 CNN 皆中招)。
🧠 點解會發生?——層層相乘嘅雪崩效應
反向傳播本質:
∂Loss/∂W_l = ∂Loss/∂a_L × ∂a_L/∂a_{L-1} × … × ∂a_{l+1}/∂a_l × ∂a_l/∂W_l
每一層都乘一次 激活函數斜率 同 權重。
斜率 < 1(例如 sigmoid 最大 0.25):乘落去就愈來愈細 → 梯度消失
斜率 > 1 或權重偏大:乘落去就愈來愈大 → 梯度爆炸
可以想像做複利運算:
連續乘 0.5 十次只剩 0.1% (錢變空氣)。
連續乘 1.5 十次變 57 倍(泡沫爆煲)。
📚 更多貼地比喻
現象 | 生活版比喻 | 結果 |
梯度消失 | 「傳聲筒」每個人細聲 80%,傳 10 個就聽唔到 | 網絡學唔到長期依賴,RNN 忘記前文 |
梯度爆炸 | 八卦新聞人人加鹽加醋 20%,傳幾個已經世界末日 | 參數爆大,訓練 Loss 變 Inf |
🌍 真實案例
長句子機器翻譯
原始 RNN 10–20 步序列就忘記開頭主語 → 梯度早已消失。
GAN 生圖
Generator 梯度爆炸 → 成張圖鋸齒、花屏。
BERT 微調
學習率手滑設 5e-4,前幾層梯度爆炸,訓練兩分鐘全 nan。
自駕車感知
152 層 ResNet 如初始化錯,梯度回傳時已經 0,學唔識夜間行人。
🔍 點樣察覺自己中招?
Loss 曲線
直線水平或緩慢下降 → 可能消失
突然向上或變 NaN → 可能爆炸
權重/梯度直方圖
全貼近 0 → 消失
分佈拉得好闊、尾巴長 → 爆炸
log 梯度範數
torch.nn.utils.clip_grad_norm_ 之前 print,一路細於 1e-6 或大過 1e+3 都係警號。
🛠️ 五大解藥與例子
權重初始化對症下藥
Xavier/Glorot → sigmoid、tanh 網絡
He 初始化 → ReLU、LeakyReLU
例:用 He normal 初始化 ResNet-50,梯度範數由 1e-9 提升到 1e-2,成功收斂。
激活函數升級
換 sigmoid ➜ ReLU / LeakyReLU / GELU
例:情感分析 LSTM 用 tanh → ReLU + LayerNorm,F1 多 5%。
Batch / Layer / RMS Normalization
每個 mini-batch 把均值調 0、方差調 1,控制信號範圍。
例:加入 BatchNorm 後,ImageNet CNN 允許把 LR 由 0.01 加到 0.1。
跳接 (Skip Connections)
ResNet、Transformer 都有 x + F(x),令梯度可以繞過多層直達前面。
例:ResNet-152 如果拆走 skip,top-1 準確率跌 15%。
梯度裁剪 (Gradient Clipping)
clip_grad_norm_ 或 clip_grad_value_ 把範數限制喺 1–5。
例:機械臂強化學習,無 clipping reward 震盪;加 clipping 後 30 分鐘學識抓取。
其他輔助:小批次、學習率預熱 (Warm-up)、AdamW (帶 weight decay)、混合精度避免 overflow。
🏗️ PyTorch 示範片段
model.apply(lambda m: isinstance(m, nn.Conv2d) and
nn.init.kaiming_normal_(m.weight, mode='fan_out'))
for step, (x, y) in enumerate(loader):
out = model(x)
loss = criterion(out, y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
optimizer.step(); optimizer.zero_grad()
✅ 小結
梯度消失 = 訊息被層層稀釋,模型「冇得學」。
梯度爆炸 = 訊息被層層放大,模型「學到癲」。
應對策略:
• 合適初始化 → Xavier / He
• 穩定激活 → ReLU 家族
• 正規化 → BatchNorm / LayerNorm
• 結構改良 → Skip connection
• 實戰保險 → Gradient clipping
記住:深度 ≠ 盲目堆砌,必須確保梯度可以 健康地跑完全程接力賽。搞掂呢關,你嘅網絡先有資格向更深更複雜挑戰!🚀