3.3 Multi-class Classification
多類別分類是「多個互斥類別選一個」的任務。模型輸入一筆資料後,需要判斷它屬於哪一個類別,例如酒的品種、產品類別、故障類型、客戶分群或文件主題。
本篇使用 scikit-learn 內建的 Wine Dataset 建立表格資料多類別分類模型。Wine Dataset 每筆資料是一款葡萄酒的化學分析結果,目標是判斷它屬於哪一個葡萄酒類別。它是乾淨、小型、可重現的表格資料集,很適合用來練習 DNN 多類別分類流程。
1. 多類別分類在學什麼?
多類別分類的答案只能是其中一個類別。假設任務有三個類別,模型最後會輸出三個機率:
預測時通常選機率最高的類別,也就是 argmax。
輸出層通常使用:
softmax 會讓所有類別機率加總為 1,因此適合「只能選一類」的互斥分類問題。
2. Wine Dataset 資料說明
Notebook 使用:
Wine Dataset 共有 178 筆資料、13 個數值特徵與 3 個目標類別。特徵來自葡萄酒的化學分析,例如 alcohol、malic acid、ash、flavanoids、color intensity、proline 等。目標欄位 target 代表葡萄酒類別,數值為 0、1、2。
這份資料的教學價值在於:
- 它是表格資料,不需要處理圖片或文字。
- 特徵都是數值欄位,適合示範
StandardScaler。 - 標籤是整數多類別,適合示範
sparse_categorical_crossentropy。 - 資料小,Notebook 可以快速執行。
3. 切分資料與標準化
表格資料訓練前通常要切分 train/test,並且只用訓練集 fitting scaler:
x_train, x_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)
stratify=y 會讓 train/test 保持相近的類別比例。fit_transform() 只用在訓練集,測試集只做 transform(),避免測試資料資訊提前進入訓練流程。
4. 標籤格式與 Loss 的對應
多類別分類最常見的錯誤,是標籤格式和 loss function 不一致。
| 標籤格式 | 範例 | Loss |
|---|---|---|
| 整數類別 | 0, 1, 2 |
sparse_categorical_crossentropy |
| one-hot | [1, 0, 0], [0, 1, 0] |
categorical_crossentropy |
本篇保留 Wine Dataset 的整數標籤,因此模型編譯時使用:
如果你的資料已經把標籤轉成 one-hot,才改用 categorical_crossentropy。
5. 建立 DNN 多類別分類模型
model = tf.keras.Sequential([
tf.keras.layers.Input(shape=(x_train.shape[1],)),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(32, activation='relu'),
tf.keras.layers.Dense(num_classes, activation='softmax')
])
前兩層 Dense 負責學習特徵組合,最後一層輸出每個類別的機率。num_classes 必須等於目標類別數,在 Wine Dataset 中是 3。
6. 訓練與 EarlyStopping
Notebook 使用 validation split 觀察模型在驗證資料上的表現,並加入 EarlyStopping:
early_stop = tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=15,
restore_best_weights=True
)
restore_best_weights=True 會在訓練結束時恢復 validation loss 最好的模型權重,避免最後幾個 epoch 過擬合或震盪造成評估結果變差。
7. 評估模型
多類別分類除了 accuracy,也應該觀察 confusion matrix 與 classification report:
y_prob = model.predict(x_test, verbose=0)
y_pred = np.argmax(y_prob, axis=1)
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=data.target_names))
Confusion matrix 可以看出模型最常混淆哪些類別。Classification report 則會列出每個類別的 precision、recall 與 F1-score,適合用來檢查模型是否只對某些類別表現好。
8. 預測新資料
模型輸出的是每個類別的機率:
sample_prob = model.predict(x_test[:5], verbose=0)
predicted_class = np.argmax(sample_prob, axis=1)
confidence = np.max(sample_prob, axis=1)
predicted_class 是預測類別,confidence 是該預測類別的機率。正式專案中通常會同時保存預測類別與機率,方便後續排序、人工覆核或設定低信心警示。
9. 如何套用自己的資料?
df = pd.read_csv('your_data.csv')
target_col = 'target'
feature_cols = [col for col in df.columns if col != target_col]
X = df[feature_cols].values.astype('float32')
y = df[target_col].values.astype('int64')
num_classes = len(np.unique(y))
套用時請確認:
- 目標欄位是互斥類別,一筆資料只屬於一類。
- 標籤已轉成整數類別,例如
0、1、2。 num_classes等於類別數。- 最後一層使用
Dense(num_classes, activation='softmax')。 - 整數標籤搭配
sparse_categorical_crossentropy。
10. 小結
Multi-class Classification 的核心是 softmax、sparse_categorical_crossentropy 與 argmax。Wine Dataset 則提供了一個乾淨的表格資料範例,讓讀者能完整練習資料切分、標準化、DNN 建模、訓練、評估與新資料預測。