跳轉到

3.3 Multi-class Classification

範例程式:Open In Colab

多類別分類是「多個互斥類別選一個」的任務。模型輸入一筆資料後,需要判斷它屬於哪一個類別,例如酒的品種、產品類別、故障類型、客戶分群或文件主題。

本篇使用 scikit-learn 內建的 Wine Dataset 建立表格資料多類別分類模型。Wine Dataset 每筆資料是一款葡萄酒的化學分析結果,目標是判斷它屬於哪一個葡萄酒類別。它是乾淨、小型、可重現的表格資料集,很適合用來練習 DNN 多類別分類流程。

1. 多類別分類在學什麼?

多類別分類的答案只能是其中一個類別。假設任務有三個類別,模型最後會輸出三個機率:

class_0: 0.05
class_1: 0.90
class_2: 0.05

預測時通常選機率最高的類別,也就是 argmax

輸出層通常使用:

tf.keras.layers.Dense(num_classes, activation='softmax')

softmax 會讓所有類別機率加總為 1,因此適合「只能選一類」的互斥分類問題。

2. Wine Dataset 資料說明

Notebook 使用:

data = load_wine(as_frame=True)
df = data.frame

Wine Dataset 共有 178 筆資料、13 個數值特徵與 3 個目標類別。特徵來自葡萄酒的化學分析,例如 alcohol、malic acid、ash、flavanoids、color intensity、proline 等。目標欄位 target 代表葡萄酒類別,數值為 012

這份資料的教學價值在於:

  1. 它是表格資料,不需要處理圖片或文字。
  2. 特徵都是數值欄位,適合示範 StandardScaler
  3. 標籤是整數多類別,適合示範 sparse_categorical_crossentropy
  4. 資料小,Notebook 可以快速執行。

3. 切分資料與標準化

表格資料訓練前通常要切分 train/test,並且只用訓練集 fitting scaler:

x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)

stratify=y 會讓 train/test 保持相近的類別比例。fit_transform() 只用在訓練集,測試集只做 transform(),避免測試資料資訊提前進入訓練流程。

4. 標籤格式與 Loss 的對應

多類別分類最常見的錯誤,是標籤格式和 loss function 不一致。

標籤格式 範例 Loss
整數類別 0, 1, 2 sparse_categorical_crossentropy
one-hot [1, 0, 0], [0, 1, 0] categorical_crossentropy

本篇保留 Wine Dataset 的整數標籤,因此模型編譯時使用:

loss='sparse_categorical_crossentropy'

如果你的資料已經把標籤轉成 one-hot,才改用 categorical_crossentropy

5. 建立 DNN 多類別分類模型

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(x_train.shape[1],)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

前兩層 Dense 負責學習特徵組合,最後一層輸出每個類別的機率。num_classes 必須等於目標類別數,在 Wine Dataset 中是 3。

6. 訓練與 EarlyStopping

Notebook 使用 validation split 觀察模型在驗證資料上的表現,並加入 EarlyStopping

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True
)

restore_best_weights=True 會在訓練結束時恢復 validation loss 最好的模型權重,避免最後幾個 epoch 過擬合或震盪造成評估結果變差。

7. 評估模型

多類別分類除了 accuracy,也應該觀察 confusion matrix 與 classification report:

y_prob = model.predict(x_test, verbose=0)
y_pred = np.argmax(y_prob, axis=1)

print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=data.target_names))

Confusion matrix 可以看出模型最常混淆哪些類別。Classification report 則會列出每個類別的 precision、recall 與 F1-score,適合用來檢查模型是否只對某些類別表現好。

8. 預測新資料

模型輸出的是每個類別的機率:

sample_prob = model.predict(x_test[:5], verbose=0)
predicted_class = np.argmax(sample_prob, axis=1)
confidence = np.max(sample_prob, axis=1)

predicted_class 是預測類別,confidence 是該預測類別的機率。正式專案中通常會同時保存預測類別與機率,方便後續排序、人工覆核或設定低信心警示。

9. 如何套用自己的資料?

df = pd.read_csv('your_data.csv')
target_col = 'target'
feature_cols = [col for col in df.columns if col != target_col]

X = df[feature_cols].values.astype('float32')
y = df[target_col].values.astype('int64')
num_classes = len(np.unique(y))

套用時請確認:

  1. 目標欄位是互斥類別,一筆資料只屬於一類。
  2. 標籤已轉成整數類別,例如 012
  3. num_classes 等於類別數。
  4. 最後一層使用 Dense(num_classes, activation='softmax')
  5. 整數標籤搭配 sparse_categorical_crossentropy

10. 小結

Multi-class Classification 的核心是 softmaxsparse_categorical_crossentropyargmax。Wine Dataset 則提供了一個乾淨的表格資料範例,讓讀者能完整練習資料切分、標準化、DNN 建模、訓練、評估與新資料預測。