Upload the deployment code and model used for digit recognition
This commit is contained in:
parent 8cb1c1df70
commit 6c51d5b254
@@ -0,0 +1,60 @@
import argparse
import gzip
import json
import os

import numpy as np


def load_data(data_dir):
    files = [
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
    ]
    paths = []
    for f in files:
        paths.append(os.path.join(data_dir, f))
    with gzip.open(paths[0], 'rb') as f:
        y_train = np.frombuffer(f.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as f:
        x_train = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    with gzip.open(paths[2], 'rb') as f:
        y_test = np.frombuffer(f.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as f:
        x_test = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
    return (x_train, y_train), (x_test, y_test)


def setup_config():
    parser = argparse.ArgumentParser(
        description="MNIST digit classification evaluation using prediction results and labels"
    )
    parser.add_argument('--label-dir', required=True, help='Directory containing the MNIST label files')
    parser.add_argument('--predictions', required=True, help='Path to the prediction results file (a NumPy array saved with np.save)')
    parser.add_argument('--output-dir', required=True, help='Directory to save the evaluation result')
    args = parser.parse_args()
    return args


def main():
    args = setup_config()

    # Only the test labels are needed for evaluation.
    (_, _), (_, labels) = load_data(args.label_dir)
    predictions = np.load(args.predictions)

    labels = np.reshape(labels, (-1,))
    predictions = np.reshape(predictions, (-1,))

    accuracy = (labels == predictions).sum() / labels.shape[0]
    print(f'Accuracy is {accuracy}')

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
        result_file = os.path.join(args.output_dir, "result.json")
        with open(result_file, 'w') as f:
            json.dump({'accuracy': accuracy}, f)
        print(f"Saved evaluation result to {result_file}")


if __name__ == '__main__':
    main()
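A minimal invocation sketch for the evaluation script above. The script name evaluate.py and all paths are illustrative (the diff does not show file names); --predictions expects the prediction_results.npy file written by the batch prediction script later in this commit:

python evaluate.py \
    --label-dir=/data/mnist \
    --predictions=/output/prediction_results.npy \
    --output-dir=/output/eval

The accuracy is printed to stdout and also written to result.json under --output-dir.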
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,126 @@
import argparse
import gzip
import json
import os

import numpy as np
import tensorflow as tf

num_workers = 1


def load_data(data_dir):
    files = [
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
    ]
    paths = []
    for f in files:
        paths.append(os.path.join(data_dir, f))
    with gzip.open(paths[0], 'rb') as f:
        y_train = np.frombuffer(f.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as f:
        x_train = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    with gzip.open(paths[2], 'rb') as f:
        y_test = np.frombuffer(f.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as f:
        x_test = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
    return (x_train, y_train), (x_test, y_test)


def mnist_test_images(batch_size=64, data_dir=None):
    # Load the dataset either from a local directory or from tf.keras.
    if data_dir:
        print(f'Loading mnist data from {data_dir}')
        (x_train, y_train), (x_test, y_test) = load_data(data_dir)
    else:
        print('Loading mnist data from tf.keras.datasets.mnist')
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    test_images = tf.data.Dataset.from_tensor_slices(x_test).batch(batch_size)

    return test_images


def get_strategy(strategy='off'):
    strategy = strategy.lower()
    # multiple nodes, each node with multiple GPUs
    if strategy == "multi_worker_mirrored":
        return tf.distribute.experimental.MultiWorkerMirroredStrategy()
    # single node with multiple GPUs
    if strategy == "mirrored":
        return tf.distribute.MirroredStrategy()
    # single node with a single GPU (default strategy)
    return tf.distribute.get_strategy()


def setup_env(args):
    tf.config.set_soft_device_placement(True)

    # Let GPU memory grow on demand instead of pre-allocating all of it.
    try:
        gpus = tf.config.experimental.list_physical_devices('GPU')
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f"Detected {len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as e:
        print(e)

    if args.strategy == 'multi_worker_mirrored':
        # Build TF_CONFIG from the worker index and host list injected by the batch scheduler.
        index = int(os.environ['VK_TASK_INDEX'])
        task_name = os.environ["VC_TASK_NAME"].upper()
        ips = os.environ[f'VC_{task_name}_HOSTS']
        ips = ips.split(',')
        global num_workers
        num_workers = len(ips)
        ips = [f'{ip}:20000' for ip in ips]
        os.environ["TF_CONFIG"] = json.dumps({
            "cluster": {
                "worker": ips
            },
            "task": {"type": "worker", "index": index}
        })


def setup_config():
    parser = argparse.ArgumentParser(description='MNIST digits classification using a trained model')
    parser.add_argument('--model-dir', help='Directory of the trained model')
    parser.add_argument('--data-dir', help='Directory of the MNIST dataset')
    parser.add_argument('--output-dir', help='Directory to save prediction results')
    parser.add_argument(
        '--strategy',
        default='off',
        choices=['off', 'mirrored', 'multi_worker_mirrored'],
        help='TensorFlow distributed training strategies'
    )
    args = parser.parse_args()
    return args


def main():
    args = setup_config()
    # tf2 limitation: collective ops must be configured at program startup
    strategy = get_strategy(args.strategy)
    setup_env(args)

    with strategy.scope():
        test_images = mnist_test_images(batch_size=64 * num_workers, data_dir=args.data_dir)
        model = tf.keras.models.load_model(args.model_dir)
        model.summary()

        logits = model.predict(test_images, verbose=2)
        probabilities = tf.nn.softmax(logits).numpy()
        predictions = np.argmax(probabilities, 1)

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
        prediction_results_file = os.path.join(args.output_dir, "prediction_results.npy")
        np.save(prediction_results_file, predictions)
        print(f"Saved prediction results to {prediction_results_file}")


if __name__ == '__main__':
    print("TensorFlow version:", tf.__version__)
    main()
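A minimal invocation sketch for the batch prediction script above, again with an illustrative script name and placeholder paths; --model-dir should point at a model saved by the training script, and the predictions are written to prediction_results.npy under --output-dir:

python predict.py \
    --model-dir=/output/model \
    --data-dir=/data/mnist \
    --output-dir=/output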
Binary file not shown.
@@ -0,0 +1,26 @@
FROM tensorflow/tensorflow:2.4.3-gpu

USER root
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 SHELL=/bin/bash

RUN sed -i "s@http://.*archive.ubuntu.com@http://mirrors.tuna.tsinghua.edu.cn@g" /etc/apt/sources.list && \
    sed -i "s@http://.*security.ubuntu.com@http://mirrors.tuna.tsinghua.edu.cn@g" /etc/apt/sources.list

RUN rm -rf /var/lib/apt/lists/* \
        /etc/apt/sources.list.d/cuda.list \
        /etc/apt/sources.list.d/nvidia-ml.list && \
    apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        apt-utils ca-certificates wget curl vim git openssh-server && \
    ldconfig && \
    apt-get clean && \
    apt-get autoremove && \
    rm -rf /var/lib/apt/lists/* /tmp/* ~/*

RUN python -m pip --no-cache-dir install -i https://mirrors.ustc.edu.cn/pypi/web/simple \
    Flask \
    Pillow

COPY . /mnist-serving/

CMD ["bash", "-c", "cd /mnist-serving && python serving.py"]
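A hedged build-and-run sketch for the image above. The image tag and the mounted model path are illustrative, and --gpus all assumes the NVIDIA container toolkit is installed; MODEL_PATH and port 5000 come from serving.py:

docker build -t mnist-serving:latest .
docker run --gpus all -p 5000:5000 \
    -v /path/to/saved_model:/model \
    -e MODEL_PATH=/model \
    mnist-serving:latest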
@@ -0,0 +1,100 @@
import os
import sys
import traceback

import numpy as np
import tensorflow as tf
from flask import Flask
from flask import jsonify
from flask import request

import utils

app = Flask(__name__)
model = None


def load_model():
    tf.config.set_soft_device_placement(True)

    # Let GPU memory grow on demand instead of pre-allocating all of it.
    try:
        gpus = tf.config.experimental.list_physical_devices('GPU')
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f"Detected {len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as e:
        print(e)

    global model
    model = tf.keras.models.load_model(os.environ['MODEL_PATH'])
    model.summary()
    return model


@app.route('/')
def hello_world():
    return jsonify({'status': {'code': 0, 'msg': 'success'}, 'data': 'hello world'})


@app.route('/image', methods=['POST'])
def digits_classification():
    global model
    if model is None:
        model = load_model()

    # Accept images as URLs, base64-encoded strings, or file uploads.
    if request.content_type and request.content_type.startswith('application/json'):
        data = request.get_json()
        image_files = None
        image_base64s = data.get('base64s')
        image_urls = data.get('urls')
    else:
        # content-type: "application/x-www-form-urlencoded" or "multipart/form-data"
        data = request.form
        image_files = request.files.getlist('files')
        image_base64s = data.getlist('base64s')
        image_urls = data.getlist('urls')

    # image bytes
    images, request_image_type = utils.get_request_images_as_file(
        image_files=image_files,
        image_base64s=image_base64s,
        image_urls=image_urls
    )

    # preprocessing: resize, convert to grayscale, normalize
    images = utils.mnist_preprocess(images)

    logits = model.predict(images, verbose=0)
    probability = tf.nn.softmax(logits).numpy()
    prediction = np.argmax(probability, 1)

    return jsonify(
        {
            'status': {'code': 0, 'msg': 'success'},
            'data': {
                'probabilities': probability.tolist(),
                'predictions': prediction.tolist()
            }
        }
    )


@app.errorhandler(Exception)
def handle_unknown_error(e):
    return jsonify({
        'status': {'code': 500, 'msg': repr(traceback.format_exception(*sys.exc_info()))},
        'data': None
    })


if __name__ == '__main__':
    # When the service runs behind a proxy such as Nginx, some proxy-set HTTP headers
    # may prevent requests from being forwarded correctly. The fix below ensures
    # forwarding works, but use it only when both the proxy and the clients are trusted.
    # https://flask.palletsprojects.com/en/1.1.x/deploying/wsgi-standalone/#proxy-setups
    # from werkzeug.contrib.fixers import ProxyFix
    # app.wsgi_app = ProxyFix(app.wsgi_app)
    model = load_model()
    app.run(host='0.0.0.0', port=5000)
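A minimal client sketch for the /image endpoint above, using the multipart 'files' field that serving.py reads via request.files.getlist('files'); the host, port, and image path are illustrative:

import requests

with open('digit.png', 'rb') as f:
    resp = requests.post(
        'http://localhost:5000/image',
        files=[('files', ('digit.png', f, 'image/png'))]
    )
# Expected response shape:
# {'status': {...}, 'data': {'probabilities': [...], 'predictions': [...]}}
print(resp.json())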
@@ -0,0 +1,102 @@
import base64
from io import BytesIO

import numpy as np
import requests
from PIL import Image


def download_image_as_bytes(url, max_file_size=0):
    """Download an image from a URL.

    Args:
        url: string, the image URL.
        max_file_size: integer, maximum image file size; 0 means no limit.

    Returns:
        bytes, the image content.

    Raises:
        Exception, the maximum image file size was exceeded.
        ImageUrlDownloadFailError, downloading the image from the URL failed.
    """
    try:
        # requests, streaming request and streaming uploads
        # https://requests.readthedocs.io/en/stable/user/advanced/
        res = requests.get(url, allow_redirects=True, stream=True)
        if res.status_code != 200:
            raise Exception('ImageUrlDownloadFailError')
    except:
        raise Exception('ImageUrlDownloadFailError')

    file_size = int(res.headers['content-length'])
    if 0 < max_file_size < file_size:
        raise Exception(
            f'image size must be less than {max_file_size / 1024 / 1024}MB, '
            f'but got {file_size / 1024 / 1024}MB, url is {url}'
        )

    image = res.content
    res.close()
    return image


def get_request_images_as_file(
    image_files=None, image_base64s=None, image_urls=None, max_file_size=0
):
    """Clients may send images as files, base64 strings, or URLs, but a single request
    must use only one of these ways. Detect which way was used and convert the data
    into image files of type bytes.

    Args:
        image_files: list of objects implementing `.read()`, `.tell()`, `.seek()`, such as
            werkzeug.FileStorage, the image files.
        image_base64s: list of strings, base64-encoded image files.
        image_urls: list of strings, image URLs.
        max_file_size: integer, maximum size of a single image file; an
            InvalidContentLengthError is raised when it is exceeded.

    Returns:
        Tuple[List[bytes], string], the image files and the way the images were sent,
        one of "file", "base64", "url".

    Raises:
        InvalidContentLengthError, the maximum image file size was exceeded.
        InvalidArgumentError, the image argument or its value is empty.
        ImageUrlDownloadFailError, downloading an image from a URL failed.
    """
    if image_files and len(image_files) > 0 and None not in image_files:
        image_bytes = [f.read() for f in image_files]
        type_ = 'file'

    elif image_base64s and len(image_base64s) > 0 and None not in image_base64s:
        image_bytes = [base64.b64decode(s) for s in image_base64s]
        type_ = 'base64'

    elif image_urls and len(image_urls) > 0 and None not in image_urls:
        image_bytes = [download_image_as_bytes(url, max_file_size=max_file_size)
                       for url in image_urls]
        type_ = 'url'

    else:
        raise Exception('no image')

    if max_file_size > 0:
        for idx, f in enumerate(image_bytes):
            if len(f) > max_file_size:
                raise Exception(f'image size must be less than {max_file_size / 1024 / 1024}MB, '
                                f'but image-{idx + 1} is {len(f) / 1024 / 1024}MB.')
    return image_bytes, type_


def mnist_preprocess(images):
    # Resize to 28x28, convert to grayscale, and scale pixel values to [0, 1].
    results = []
    for img in images:
        if isinstance(img, bytes):
            image_file = BytesIO(img)
        else:
            image_file = img

        image = Image.open(image_file)
        image = image.resize((28, 28))
        image = image.convert('L')
        image = np.asarray(image)[:, :] / 255.
        results.append(image)
    return np.asarray(results)
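A small local sketch of the helpers above, showing the (N, 28, 28) batch shape produced by mnist_preprocess; 'digit.png' is an illustrative path, not a file in the commit:

from utils import get_request_images_as_file, mnist_preprocess

with open('digit.png', 'rb') as f:
    images, image_type = get_request_images_as_file(image_files=[f])
batch = mnist_preprocess(images)
print(image_type, batch.shape)  # 'file', (1, 28, 28)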
@@ -0,0 +1,148 @@
import argparse
import gzip
import json
import os

import numpy as np
import tensorflow as tf

num_workers = 1


def load_data(data_dir):
    files = [
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
    ]
    paths = []
    for f in files:
        paths.append(os.path.join(data_dir, f))
    with gzip.open(paths[0], 'rb') as f:
        y_train = np.frombuffer(f.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as f:
        x_train = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    with gzip.open(paths[2], 'rb') as f:
        y_test = np.frombuffer(f.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as f:
        x_test = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
    return (x_train, y_train), (x_test, y_test)


def mnist_dataset(batch_size=64, data_dir=None):
    # Load the dataset either from a local directory or from tf.keras.
    if data_dir:
        print(f'Loading mnist data from {data_dir}')
        (x_train, y_train), (x_test, y_test) = load_data(data_dir)
    else:
        print('Loading mnist data from tf.keras.datasets.mnist')
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

    return train_dataset, test_dataset


def get_strategy(strategy='off'):
    strategy = strategy.lower()
    # multiple nodes, each node with multiple GPUs
    if strategy == "multi_worker_mirrored":
        return tf.distribute.experimental.MultiWorkerMirroredStrategy()
    # single node with multiple GPUs
    if strategy == "mirrored":
        return tf.distribute.MirroredStrategy()
    # single node with a single GPU (default strategy)
    return tf.distribute.get_strategy()


def setup_env(args):
    tf.config.set_soft_device_placement(True)

    # Let GPU memory grow on demand instead of pre-allocating all of it.
    try:
        gpus = tf.config.experimental.list_physical_devices('GPU')
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f"Detected {len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as e:
        print(e)

    if args.strategy == 'multi_worker_mirrored':
        # Build TF_CONFIG from the worker index and host list injected by the batch scheduler.
        index = int(os.environ['VK_TASK_INDEX'])
        task_name = os.environ["VC_TASK_NAME"].upper()
        ips = os.environ[f'VC_{task_name}_HOSTS']
        ips = ips.split(',')
        global num_workers
        num_workers = len(ips)
        ips = [f'{ip}:20000' for ip in ips]
        os.environ["TF_CONFIG"] = json.dumps({
            "cluster": {
                "worker": ips
            },
            "task": {"type": "worker", "index": index}
        })


def setup_config():
    parser = argparse.ArgumentParser(description='Train MNIST digits classification')
    parser.add_argument('--data-dir', help='Directory of the MNIST dataset')
    parser.add_argument('--output-dir', help='Directory to save models and logs')
    parser.add_argument('--epochs', type=int, default=2, help='Number of epochs')
    parser.add_argument('--eval', action='store_true', help='whether to run evaluation after training finishes')
    parser.add_argument(
        '--strategy',
        default='off',
        choices=['off', 'mirrored', 'multi_worker_mirrored'],
        help='TensorFlow distributed training strategies'
    )
    args = parser.parse_args()
    return args


def main():
    args = setup_config()
    # tf2 limitation: collective ops must be configured at program startup
    strategy = get_strategy(args.strategy)
    setup_env(args)

    with strategy.scope():
        # build dataset
        train_dataset, test_dataset = mnist_dataset(
            batch_size=64 * num_workers,
            data_dir=args.data_dir
        )

        # build model
        model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10)
        ])
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy']
        )
        model.summary()

        # training
        model.fit(train_dataset, epochs=args.epochs, steps_per_epoch=70)

        # evaluation
        if args.eval:
            model.evaluate(test_dataset, verbose=2)

        # save model
        if args.output_dir:
            os.makedirs(args.output_dir, exist_ok=True)
            model.save(args.output_dir)
            print(f'Saved model to {args.output_dir}')


if __name__ == '__main__':
    # python train.py --data-dir=/path/to/MNIST/dataset --output-dir=/path/to/output
    print("TensorFlow version:", tf.__version__)
    main()
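A hedged launch sketch for the training script above on a single node with multiple GPUs; the script name train.py comes from its own usage comment, while the paths and epoch count are placeholders:

python train.py --data-dir=/data/mnist --output-dir=/output/model --epochs=5 --eval --strategy=mirrored

The directory passed as --output-dir can then be fed to the prediction script's --model-dir.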