8.1 Callback
使用 Keras callbacks 監控訓練、提早停止、儲存最佳模型、調整 learning rate,並保存訓練紀錄。
1. 學習目標
模型訓練不是只呼叫一次 model.fit()。實務上常需要在訓練過程中自動做幾件事:
- validation loss 長時間沒有改善時停止訓練。
- 只保存 validation 表現最好的模型。
- 模型卡住時自動降低 learning rate。
- 把每個 epoch 的 loss 與 accuracy 存成紀錄。
Keras callback 就是放在 model.fit() 裡的訓練控制工具。它不改變模型架構,而是在每個 epoch 或 batch 前後觀察訓練狀態並執行指定動作。
2. 常用 callback
| Callback | 常見用途 | 常見參數 |
|---|---|---|
EarlyStopping |
指標不再改善時停止訓練 | monitor、patience、restore_best_weights |
ModelCheckpoint |
儲存最佳模型或每個 epoch 的模型 | filepath、monitor、save_best_only |
ReduceLROnPlateau |
指標停滯時降低 learning rate | monitor、factor、patience、min_lr |
CSVLogger |
將每個 epoch 的訓練結果寫入 CSV | filename |
最常見的監控指標是 val_loss。它比訓練集 loss 更能反映模型是否真的泛化到未見資料。
3. 標準使用方式
Callback 會先建立成 list,再傳入 model.fit():
callbacks = [
tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=3,
restore_best_weights=True
),
tf.keras.callbacks.ModelCheckpoint(
filepath='best_model.keras',
monitor='val_loss',
save_best_only=True
),
tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=0.5,
patience=2,
min_lr=1e-5
),
]
history = model.fit(
x_train,
y_train,
validation_split=0.2,
epochs=30,
callbacks=callbacks
)
4. 本篇範例流程
Notebook 使用 Fashion-MNIST 小型子集建立 CNN 分類模型,示範同一個訓練流程中如何加入:
EarlyStopping:避免訓練過久與過擬合。ModelCheckpoint:保存 validation loss 最低的模型。ReduceLROnPlateau:當 validation loss 停滯時降低 learning rate。CSVLogger:把訓練紀錄輸出成 CSV,方便後續比較。
5. 如何套用到自己的資料?
替換成自己的任務時,callback 大多不需要大改。主要需要確認三件事:
monitor是否存在於history.history,例如val_loss、val_accuracy。patience是否符合資料規模;資料小或訓練波動大時,patience不宜太小。filepath是否放在可寫入的位置;在 Colab 中可改成 Google Drive 路徑。
6. 常見錯誤
- 沒有 validation data,卻監控
val_loss。 patience太小,模型還沒收斂就提早停止。- 只看 training accuracy,沒有觀察 validation loss。
ModelCheckpoint儲存了很多模型檔,卻沒有設定save_best_only=True。
7. 小結
Callback 是 Keras 訓練流程中的自動化控制層。它能讓訓練更穩定、更可追蹤,也能把「何時停止」「何時存模型」「何時降低 learning rate」這些決策明確寫進程式。