keras分布式訓練
先來個簡單的分布式訓練
keras分布式訓練
#導入依賴?
from __future__ import absolute_import, division, print_function, unicode_literals
# 導入 TensorFlow 和 TensorFlow 數(shù)據(jù)集
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
import os
如果把第一行屏蔽會有什么效果?用沒有影響
2020-05-28 21:34:05.468570: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-05-28 21:34:05.497031: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fe6ff957c10 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-05-28 21:34:05.497053: I tensorflow/compiler/xla/service/service.cc:176]? ?StreamExecutor device (0): Host, Default Version
Number of devices: 1
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
Epoch 1/12
? ? ? 1/Unknown - 2s 2s/step - loss: 2.3115 - accuracy: 0.14062020-05-28 21:34:08.723148: I tensorflow/core/profiler/lib/profiler_session.cc:225] Profiler session started.
下載數(shù)據(jù)集
下載 MNIST 數(shù)據(jù)集并從?TensorFlow Datasets?加載。 這會返回?tf.data
?格式的數(shù)據(jù)集。
將?with_info
?設置為?True
?會包含整個數(shù)據(jù)集的元數(shù)據(jù),其中這些數(shù)據(jù)集將保存在?info
?中。 除此之外,該元數(shù)據(jù)對象包括訓練和測試示例的數(shù)量。?
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
定義分配策略
創(chuàng)建一個?MirroredStrategy
?對象。這將處理分配策略,并提供一個上下文管理器(tf.distribute.MirroredStrategy.scope
)來構建你的模型。
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
設置輸入管道(pipeline)
在訓練具有多個 GPU 的模型時,您可以通過增加批量大小(batch size)來有效地使用額外的計算能力。通常來說,使用適合 GPU 內存的最大批量大?。╞atch size),并相應地調整學習速率。
0-255 的像素值,?必須標準化到 0-1 范圍。在函數(shù)中定義標準化。
生成模型
在?strategy.scope
?的上下文中創(chuàng)建和編譯 Keras 模型。
with strategy.scope():
? model = tf.keras.Sequential([
? ? ? tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
? ? ? tf.keras.layers.MaxPooling2D(),
? ? ? tf.keras.layers.Flatten(),
? ? ? tf.keras.layers.Dense(64, activation='relu'),
? ? ? tf.keras.layers.Dense(10, activation='softmax')
? ])
? model.compile(loss='sparse_categorical_crossentropy',
? ? ? ? ? ? ? ? optimizer=tf.keras.optimizers.Adam(),
? ? ? ? ? ? ? ? metrics=['accuracy'])
訓練和評估
在該部分,以普通的方式訓練模型,在模型上調用?fit
?并傳入在教程開始時創(chuàng)建的數(shù)據(jù)集。 無論您是否分布式訓練,此步驟都是相同的。
model.fit(train_dataset, epochs=12, callbacks=callbacks)
要查看模型的執(zhí)行方式,請加載最新的檢查點(checkpoint)并在測試數(shù)據(jù)上調用?evaluate
?。
使用適當?shù)臄?shù)據(jù)集調用?evaluate
?。
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
導出到 SavedModel
將圖形和變量導出為與平臺無關的 SavedModel 格式。 保存模型后,可以在有或沒有 scope 的情況下加載模型。
在無需?strategy.scope
?加載模型
通過load_from_saved_model(載入上面保存的模型)
運行結果,精度為99
運行效果部分截取
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 23s 25ms/step - loss: 0.2050 - accuracy: 0.9399
Epoch 2/12
937/938 [============================>.] - ETA: 0s - loss: 0.0661 - accuracy: 0.9804
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 17s 18ms/step - loss: 0.0661 - accuracy: 0.9804
Epoch 3/12
936/938 [============================>.] - ETA: 0s - loss: 0.0464 - accuracy: 0.9862
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 14s 15ms/step - loss: 0.0464 - accuracy: 0.9862
Epoch 4/12
936/938 [============================>.] - ETA: 0s - loss: 0.0266 - accuracy: 0.9927
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 18s 19ms/step - loss: 0.0266 - accuracy: 0.9927
Epoch 5/12
934/938 [============================>.] - ETA: 0s - loss: 0.0234 - accuracy: 0.9939
Learning rate for epoch 5 is 9.999999747378752e-05
......
938/938 [==============================] - 17s 18ms/step - loss: 0.0151 - accuracy: 0.9966