2.7 tf.data 效能優化
當模型訓練變慢時,問題不一定在模型,也可能是資料管線供應太慢。本篇示範 cache()、prefetch() 與 tf.data.AUTOTUNE 的基本使用方式,目標是讓模型訓練時不要一直等待下一批資料。
1. 學習目標
訓練神經網路時,GPU 或 CPU 需要不斷取得下一批資料。如果資料讀取與前處理跟不上,硬體會等待資料,整體訓練就會變慢。
2. 常見效能技巧
| 技巧 | 作用 | 注意事項 |
|---|---|---|
cache() |
快取前處理後的資料 | 小資料可放記憶體,大資料可指定檔案 |
prefetch() |
訓練時預先準備下一批資料 | 通常搭配 AUTOTUNE |
num_parallel_calls |
平行執行 map | 適合較重的前處理 |
AUTOTUNE |
讓 TensorFlow 自動調整平行度 | 多數情境可先使用 |
3. 實作流程
Notebook 使用 synthetic dataset 示範基本資料管線與優化資料管線的差異,並用同一個 Keras 模型確認 optimized dataset 可以直接訓練。範例會把穩定的前處理結果快取起來,再視訓練需求 shuffle、batch 與 prefetch。
4. 如何套用自己的資料?
若資料前處理成本高,可將 map(..., num_parallel_calls=tf.data.AUTOTUNE) 放在 batch 前,最後加上 prefetch(tf.data.AUTOTUNE)。若資料可放進記憶體,可考慮 cache()。訓練資料若還需要每個 epoch 重新打亂,通常會把穩定前處理 cache 起來後再 shuffle。
5. 小結
tf.data 效能優化的目標,是讓模型訓練時資料能穩定供應。正式專案可先從 cache、prefetch、AUTOTUNE 三個設定開始。