跳轉到

1.5 Compile、Fit、Evaluate、Predict

範例程式:Open In Colab

Keras 的訓練流程可以濃縮成四個方法:compile()fit()evaluate()predict()。幾乎所有 supervised learning 任務都會反覆使用這四個步驟。

1. 學習目標

這篇使用 Wine 多類別分類資料集,示範模型從建立後如何編譯、訓練、評估與預測。Wine Dataset 每筆資料是一款葡萄酒的化學分析數值,目標是判斷它屬於哪一個酒品類別。這份資料乾淨、類別明確、特徵數適中,適合用來練習 Keras 的標準訓練流程。

本篇重點不是追求複雜模型,而是建立可重複套用的四步流程:先用 compile() 定義訓練規則,再用 fit() 訓練模型,接著用 evaluate() 檢查整體表現,最後用 predict() 取得逐筆預測結果。

2. 四個方法各自做什麼?

方法 用途
compile() 指定 optimizer、loss 與 metrics
fit() 使用訓練資料更新模型權重
evaluate() 在訓練集、驗證集或測試集上計算指標
predict() 對新資料輸出模型預測結果

3. compile 的設定邏輯

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

多類別分類若標籤是整數類別,例如 Wine Dataset 中的 0, 1, 2,常用 sparse_categorical_crossentropy。若標籤已經 one-hot encoding,則常用 categorical_crossentropy

compile() 是把任務型態轉成訓練設定的地方。輸出層、loss 與 metrics 必須互相對齊,否則模型即使能跑,也可能不是在解正確的問題。

4. fit 的回傳值很重要

model.fit() 會回傳 History 物件,裡面保存每個 epoch 的 loss、accuracy、validation loss 與 validation accuracy。這些紀錄可以用來畫 learning curve,觀察模型是否收斂或過擬合。

5. evaluate 與 predict 的差異

evaluate() 用來算整體分數,predict() 用來取得每筆資料的預測結果。實務上通常會先用 evaluate() 看整體表現,再用 predict() 分析錯誤案例。

在本篇 Wine 多類別分類範例中,predict() 會輸出每一類的機率,接著用 argmax 取得預測類別。若只看 evaluate() 的 accuracy,會少掉逐筆預測與錯誤分析的資訊。

6. 如何套用自己的資料?

只要把資料整理成:

x_train, y_train
x_test, y_test

接著確認輸出層與 loss 符合任務,就能套用相同流程。

例如:

任務 輸出層 Loss
回歸 Dense(1) mse
二元分類 Dense(1, activation='sigmoid') binary_crossentropy
多類別分類,整數標籤 Dense(num_classes, activation='softmax') sparse_categorical_crossentropy
多類別分類,one-hot 標籤 Dense(num_classes, activation='softmax') categorical_crossentropy

7. 小結

compile()fit()evaluate()predict() 是 Keras supervised learning 的核心四步。後續不論是 DNN、CNN、RNN 或 Transformer,都會圍繞這個流程展開。