1.5 Compile、Fit、Evaluate、Predict
Keras 的訓練流程可以濃縮成四個方法:compile()、fit()、evaluate()、predict()。幾乎所有 supervised learning 任務都會反覆使用這四個步驟。
1. 學習目標
這篇使用 Wine 多類別分類資料集,示範模型從建立後如何編譯、訓練、評估與預測。Wine Dataset 每筆資料是一款葡萄酒的化學分析數值,目標是判斷它屬於哪一個酒品類別。這份資料乾淨、類別明確、特徵數適中,適合用來練習 Keras 的標準訓練流程。
本篇重點不是追求複雜模型,而是建立可重複套用的四步流程:先用 compile() 定義訓練規則,再用 fit() 訓練模型,接著用 evaluate() 檢查整體表現,最後用 predict() 取得逐筆預測結果。
2. 四個方法各自做什麼?
| 方法 | 用途 |
|---|---|
compile() |
指定 optimizer、loss 與 metrics |
fit() |
使用訓練資料更新模型權重 |
evaluate() |
在訓練集、驗證集或測試集上計算指標 |
predict() |
對新資料輸出模型預測結果 |
3. compile 的設定邏輯
多類別分類若標籤是整數類別,例如 Wine Dataset 中的 0, 1, 2,常用 sparse_categorical_crossentropy。若標籤已經 one-hot encoding,則常用 categorical_crossentropy。
compile() 是把任務型態轉成訓練設定的地方。輸出層、loss 與 metrics 必須互相對齊,否則模型即使能跑,也可能不是在解正確的問題。
4. fit 的回傳值很重要
model.fit() 會回傳 History 物件,裡面保存每個 epoch 的 loss、accuracy、validation loss 與 validation accuracy。這些紀錄可以用來畫 learning curve,觀察模型是否收斂或過擬合。
5. evaluate 與 predict 的差異
evaluate() 用來算整體分數,predict() 用來取得每筆資料的預測結果。實務上通常會先用 evaluate() 看整體表現,再用 predict() 分析錯誤案例。
在本篇 Wine 多類別分類範例中,predict() 會輸出每一類的機率,接著用 argmax 取得預測類別。若只看 evaluate() 的 accuracy,會少掉逐筆預測與錯誤分析的資訊。
6. 如何套用自己的資料?
只要把資料整理成:
接著確認輸出層與 loss 符合任務,就能套用相同流程。
例如:
| 任務 | 輸出層 | Loss |
|---|---|---|
| 回歸 | Dense(1) |
mse |
| 二元分類 | Dense(1, activation='sigmoid') |
binary_crossentropy |
| 多類別分類,整數標籤 | Dense(num_classes, activation='softmax') |
sparse_categorical_crossentropy |
| 多類別分類,one-hot 標籤 | Dense(num_classes, activation='softmax') |
categorical_crossentropy |
7. 小結
compile()、fit()、evaluate()、predict() 是 Keras supervised learning 的核心四步。後續不論是 DNN、CNN、RNN 或 Transformer,都會圍繞這個流程展開。