# dense-mnist-tf/train.py
#
# Source-viewer metadata for the original file: 149 lines, 4.9 KiB, Python.

import argparse
import gzip
import json
import os
import shutil
import tempfile

import numpy as np
import tensorflow as tf
# Number of training workers; main() multiplies the per-worker batch size of 64
# by this value. setup_env() overwrites it with the number of hosts when the
# multi_worker_mirrored strategy is selected.
num_workers = 1
def load_data(data_dir):
    """Load MNIST from the four gzipped IDX files stored in ``data_dir``.

    Expected files: ``train-labels-idx1-ubyte.gz``, ``train-images-idx3-ubyte.gz``,
    ``t10k-labels-idx1-ubyte.gz`` and ``t10k-images-idx3-ubyte.gz``.

    Args:
        data_dir: directory containing the files listed above.

    Returns:
        ``(x_train, y_train), (x_test, y_test)``: labels are 1-D ``uint8``
        arrays and images are ``uint8`` arrays of shape ``(len(labels), 28, 28)``.
        The arrays come from ``np.frombuffer`` over immutable bytes, so they are
        read-only.

    Raises:
        OSError: if a file is missing or is not valid gzip.
        ValueError: if an image file does not hold ``len(labels) * 28 * 28`` pixels.
    """
    def _read(name, header_size):
        # Skip the fixed-size file header (8 bytes for labels, 16 for images).
        with gzip.open(os.path.join(data_dir, name), 'rb') as fh:
            return np.frombuffer(fh.read(), np.uint8, offset=header_size)

    def _split(prefix):
        # Labels are read first: their count fixes the number of images.
        labels = _read(f'{prefix}-labels-idx1-ubyte.gz', 8)
        images = _read(f'{prefix}-images-idx3-ubyte.gz', 16).reshape(len(labels), 28, 28)
        return images, labels

    return _split('train'), _split('t10k')
def mnist_dataset(batch_size=64, data_dir=None, shuffle_buffer=None):
    """Build the MNIST training and test ``tf.data`` pipelines.

    Args:
        batch_size: number of examples per batch for both datasets.
        data_dir: optional directory with the gzipped MNIST files read by
            ``load_data``; when falsy, ``tf.keras.datasets.mnist.load_data()``
            is used instead.
        shuffle_buffer: size of the training-set shuffle buffer. ``None``
            (default) uses the number of training examples, i.e. a full
            shuffle. For standard MNIST this is 60000, the value previously
            hard-coded.

    Returns:
        ``(train_dataset, test_dataset)``. The training dataset is shuffled,
        repeated indefinitely and batched, so callers must bound each epoch
        (e.g. ``steps_per_epoch``). The test dataset is a single batched pass.
        Pixels are divided by 255.0, so image tensors are float64.
    """
    if data_dir:
        print(f'Loading mnist data from {data_dir}')
        (x_train, y_train), (x_test, y_test) = load_data(data_dir)
    else:
        print('Loading mnist data from tf.keras.datasets.mnist')
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Size the buffer to the actual data so datasets other than the 60000-image
    # MNIST training set are still fully shuffled.
    if shuffle_buffer is None:
        shuffle_buffer = len(x_train)

    train_dataset = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .shuffle(shuffle_buffer)
        .repeat()
        .batch(batch_size)
    )
    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)
    return train_dataset, test_dataset
def get_strategy(strategy='off'):
    """Return the ``tf.distribute`` strategy named by ``strategy``.

    Args:
        strategy: case-insensitive strategy name.
            * ``"multi_worker_mirrored"``: synchronous training across several
              machines. By default TensorFlow discovers the cluster from the
              ``TF_CONFIG`` environment variable, so it must be set beforehand.
            * ``"mirrored"``: synchronous training across the local devices.
            * any other value, including the default ``"off"``: the default
              strategy from ``tf.distribute.get_strategy()``.

    Returns:
        A ``tf.distribute.Strategy`` instance.
    """
    name = strategy.lower()
    if name == 'multi_worker_mirrored':
        # The experimental class is deprecated; use the stable one when this
        # TensorFlow release provides it and fall back otherwise.
        mwms_cls = getattr(tf.distribute, 'MultiWorkerMirroredStrategy', None)
        if mwms_cls is None:
            mwms_cls = tf.distribute.experimental.MultiWorkerMirroredStrategy
        return mwms_cls()
    if name == 'mirrored':
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()
def setup_env(args):
    """Apply process-wide TensorFlow settings; call before ``get_strategy``.

    MultiWorkerMirroredStrategy reads ``TF_CONFIG`` when it is constructed, and
    GPU memory growth can only be changed before the TensorFlow runtime
    initializes its devices. For that reason this function sets ``TF_CONFIG``
    first and avoids calls (such as ``tf.config.list_logical_devices``) that
    would initialize the runtime.

    Args:
        args: parsed options from ``setup_config``; only ``args.strategy`` is read.

    Side effects:
        * for ``multi_worker_mirrored``: writes ``os.environ['TF_CONFIG']``
          (every host on port 20000) and sets the module-level ``num_workers``
          to the number of hosts;
        * enables soft device placement and memory growth on every GPU.

    Raises:
        KeyError: if a required launcher environment variable is missing in
            ``multi_worker_mirrored`` mode.
    """
    global num_workers
    if args.strategy == 'multi_worker_mirrored':
        # NOTE(review): the index is read from a VK_-prefixed variable while the
        # task name and host list use VC_-prefixed ones; confirm the job
        # launcher sets all three.
        index = int(os.environ['VK_TASK_INDEX'])
        task_name = os.environ["VC_TASK_NAME"].upper()
        hosts = os.environ[f'VC_{task_name}_HOSTS'].split(',')
        num_workers = len(hosts)
        os.environ["TF_CONFIG"] = json.dumps({
            "cluster": {"worker": [f'{host}:20000' for host in hosts]},
            "task": {"type": "worker", "index": index},
        })

    tf.config.set_soft_device_placement(True)
    # Let each GPU allocate memory on demand instead of reserving it all up front.
    try:
        gpus = tf.config.experimental.list_physical_devices('GPU')
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"Detected {len(gpus)} Physical GPUs")
    except RuntimeError as e:
        # Raised when the runtime is already initialized; memory growth stays off.
        print(e)


def setup_config(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: optional list of argument strings. ``None`` (default) parses
            ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``data_dir``, ``output_dir``, ``epochs``
        (int), ``eval`` (bool) and ``strategy``.
    """
    parser = argparse.ArgumentParser(description='Train MNIST digits classification')
    parser.add_argument('--data-dir', help='Directory to MNIST dataset')
    parser.add_argument('--output-dir', help='Directory to save models and logs')
    # type=int: without it a value given on the command line stays a str, which
    # Keras' fit() cannot use as an epoch count.
    parser.add_argument('--epochs', type=int, default=2, help='Number of epochs')
    parser.add_argument('--eval', action='store_true', help='whether do evaluation after training finished')
    parser.add_argument(
        '--strategy',
        default='off',
        choices=['off', 'mirrored', 'multi_worker_mirrored'],
        help='TensorFlow distributed training strategies'
    )
    return parser.parse_args(argv)


def _is_chief():
    """Return True if this process should write the model to ``--output-dir``.

    Uses ``TF_CONFIG``: the ``chief`` task if the cluster defines one, otherwise
    worker 0. Without ``TF_CONFIG`` or a task type, the process is the chief.
    """
    tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}')
    task = tf_config.get('task', {})
    task_type = task.get('type')
    if task_type is None or task_type == 'chief':
        return True
    if 'chief' in tf_config.get('cluster', {}):
        return False
    return task_type == 'worker' and int(task.get('index', 0)) == 0


def main():
    """Parse options, train the dense MNIST model, then optionally evaluate and save it."""
    args = setup_config()
    # Order matters. setup_env must run first: it publishes TF_CONFIG, which
    # MultiWorkerMirroredStrategy reads when constructed, and it enables GPU
    # memory growth, which only works before the runtime initializes devices.
    # The strategy is then created before any other TF op, because collective
    # ops must be configured at program startup.
    setup_env(args)
    strategy = get_strategy(args.strategy)
    # Listing logical devices initializes the runtime, so only do it now.
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(f"Detected {len(logical_gpus)} Logical GPUs")

    with strategy.scope():
        # Global batch size: 64 examples per worker.
        train_dataset, test_dataset = mnist_dataset(
            batch_size=64 * num_workers,
            data_dir=args.data_dir,
        )
        model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10),
        ])
        # The last layer has no activation, so the loss works on logits.
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'],
        )
        model.summary()

    # The training dataset repeats forever, so each epoch is capped explicitly.
    model.fit(train_dataset, epochs=args.epochs, steps_per_epoch=70)

    if args.eval:
        model.evaluate(test_dataset, verbose=2)

    if args.output_dir:
        if args.strategy == 'multi_worker_mirrored' and not _is_chief():
            # Every worker must call save(), but only the chief writes to the
            # (possibly shared) output directory; the others save to a
            # throwaway directory so they cannot clobber it.
            tmp_dir = tempfile.mkdtemp()
            try:
                model.save(tmp_dir)
            finally:
                shutil.rmtree(tmp_dir, ignore_errors=True)
        else:
            os.makedirs(args.output_dir, exist_ok=True)
            model.save(args.output_dir)
            print(f'Saved model to {args.output_dir}')
if __name__ == '__main__':
    # Typical invocation:
    #   python train.py --data-dir=/path/to/MNIST/dataset --output-dir=/path/to/output
    tf_version = tf.__version__
    print("TensorFlow version:", tf_version)
    main()