跳轉到

8.1 Callback

範例程式:Open In Colab

使用 Keras callbacks 監控訓練、提早停止、儲存最佳模型、調整 learning rate,並保存訓練紀錄。

1. 學習目標

模型訓練不是只呼叫一次 model.fit()。實務上常需要在訓練過程中自動做幾件事:

  1. validation loss 長時間沒有改善時停止訓練。
  2. 只保存 validation 表現最好的模型。
  3. 模型卡住時自動降低 learning rate。
  4. 把每個 epoch 的 loss 與 accuracy 存成紀錄。

Keras callback 就是放在 model.fit() 裡的訓練控制工具。它不改變模型架構,而是在每個 epoch 或 batch 前後觀察訓練狀態並執行指定動作。

2. 常用 callback

Callback 常見用途 常見參數
EarlyStopping 指標不再改善時停止訓練 monitorpatiencerestore_best_weights
ModelCheckpoint 儲存最佳模型或每個 epoch 的模型 filepathmonitorsave_best_only
ReduceLROnPlateau 指標停滯時降低 learning rate monitorfactorpatiencemin_lr
CSVLogger 將每個 epoch 的訓練結果寫入 CSV filename

最常見的監控指標是 val_loss。它比訓練集 loss 更能反映模型是否真的泛化到未見資料。

3. 標準使用方式

Callback 會先建立成 list,再傳入 model.fit()

callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=3,
        restore_best_weights=True
    ),
    tf.keras.callbacks.ModelCheckpoint(
        filepath='best_model.keras',
        monitor='val_loss',
        save_best_only=True
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=2,
        min_lr=1e-5
    ),
]

history = model.fit(
    x_train,
    y_train,
    validation_split=0.2,
    epochs=30,
    callbacks=callbacks
)

4. 本篇範例流程

Notebook 使用 Fashion-MNIST 小型子集建立 CNN 分類模型,示範同一個訓練流程中如何加入:

  1. EarlyStopping:避免訓練過久與過擬合。
  2. ModelCheckpoint:保存 validation loss 最低的模型。
  3. ReduceLROnPlateau:當 validation loss 停滯時降低 learning rate。
  4. CSVLogger:把訓練紀錄輸出成 CSV,方便後續比較。

5. 如何套用到自己的資料?

替換成自己的任務時,callback 大多不需要大改。主要需要確認三件事:

  1. monitor 是否存在於 history.history,例如 val_lossval_accuracy
  2. patience 是否符合資料規模;資料小或訓練波動大時,patience 不宜太小。
  3. filepath 是否放在可寫入的位置;在 Colab 中可改成 Google Drive 路徑。

6. 常見錯誤

  • 沒有 validation data,卻監控 val_loss
  • patience 太小,模型還沒收斂就提早停止。
  • 只看 training accuracy,沒有觀察 validation loss。
  • ModelCheckpoint 儲存了很多模型檔,卻沒有設定 save_best_only=True

7. 小結

Callback 是 Keras 訓練流程中的自動化控制層。它能讓訓練更穩定、更可追蹤,也能把「何時停止」「何時存模型」「何時降低 learning rate」這些決策明確寫進程式。