上传数字识别使用的部署代码及模型
This commit is contained in:
parent
8cb1c1df70
commit
6c51d5b254
|
@ -0,0 +1,60 @@
|
|||
import argparse
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_data(data_dir):
|
||||
files = [
|
||||
'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
|
||||
't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
|
||||
]
|
||||
paths = []
|
||||
for f in files:
|
||||
paths.append(os.path.join(data_dir, f))
|
||||
with gzip.open(paths[0], 'rb') as f:
|
||||
y_train = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
with gzip.open(paths[1], 'rb') as f:
|
||||
x_train = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
|
||||
with gzip.open(paths[2], 'rb') as f:
|
||||
y_test = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
with gzip.open(paths[3], 'rb') as f:
|
||||
x_test = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
|
||||
return (x_train, y_train), (x_test, y_test)
|
||||
|
||||
|
||||
def setup_config():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="MNIST digit classification Evaluation using prediction results and labels"
|
||||
)
|
||||
parser.add_argument('--label-dir', required=True, help='Directory to labels')
|
||||
parser.add_argument('--predictions', required=True, help='Path to prediction results numpy-array-like file')
|
||||
parser.add_argument('--output-dir', required=True, help='Directory to evaluation result')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = setup_config()
|
||||
|
||||
(_, _), (_, labels) = load_data(args.label_dir)
|
||||
predictions = np.load(args.predictions)
|
||||
|
||||
labels = np.reshape(labels, (-1,))
|
||||
predictions = np.reshape(predictions, (-1,))
|
||||
|
||||
accuracy = (labels == predictions).sum() / labels.shape[0]
|
||||
print(f'Accuracy is {accuracy}')
|
||||
|
||||
if args.output_dir:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
result_file = os.path.join(args.output_dir, "result.json")
|
||||
with open(result_file, 'w') as f:
|
||||
json.dump({'accuracy': accuracy}, f)
|
||||
print(f"Saved evaluation result to {result_file}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,126 @@
|
|||
import argparse
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
num_workers = 1
|
||||
|
||||
|
||||
def load_data(data_dir):
|
||||
files = [
|
||||
'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
|
||||
't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
|
||||
]
|
||||
paths = []
|
||||
for f in files:
|
||||
paths.append(os.path.join(data_dir, f))
|
||||
with gzip.open(paths[0], 'rb') as f:
|
||||
y_train = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
with gzip.open(paths[1], 'rb') as f:
|
||||
x_train = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
|
||||
with gzip.open(paths[2], 'rb') as f:
|
||||
y_test = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
with gzip.open(paths[3], 'rb') as f:
|
||||
x_test = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
|
||||
return (x_train, y_train), (x_test, y_test)
|
||||
|
||||
|
||||
def mnist_test_images(batch_size=64, data_dir=None):
|
||||
# load dataset
|
||||
if data_dir:
|
||||
print(f'Loading mnist data from {data_dir}')
|
||||
(x_train, y_train), (x_test, y_test) = load_data(data_dir)
|
||||
else:
|
||||
print('Loading mnist data from tf.keras.datasets.mnist')
|
||||
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
|
||||
x_train, x_test = x_train / 255.0, x_test / 255.0
|
||||
|
||||
test_images = tf.data.Dataset.from_tensor_slices(x_test).batch(batch_size)
|
||||
|
||||
return test_images
|
||||
|
||||
|
||||
def get_strategy(strategy='off'):
|
||||
strategy = strategy.lower()
|
||||
# multiple nodes, every nodes have multiple GPUs
|
||||
if strategy == "multi_worker_mirrored":
|
||||
return tf.distribute.experimental.MultiWorkerMirroredStrategy()
|
||||
# single node with multiple GPUs
|
||||
if strategy == "mirrored":
|
||||
return tf.distribute.MirroredStrategy()
|
||||
# single node with single GPU
|
||||
return tf.distribute.get_strategy()
|
||||
|
||||
|
||||
def setup_env(args):
|
||||
tf.config.set_soft_device_placement(True)
|
||||
|
||||
# limit the gpu memory usage as much as it need.
|
||||
try:
|
||||
gpus = tf.config.experimental.list_physical_devices('GPU')
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
logical_gpus = tf.config.list_logical_devices('GPU')
|
||||
print(f"Detected {len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
|
||||
if args.strategy == 'multi_worker_mirrored':
|
||||
index = int(os.environ['VK_TASK_INDEX'])
|
||||
task_name = os.environ["VC_TASK_NAME"].upper()
|
||||
ips = os.environ[f'VC_{task_name}_HOSTS']
|
||||
ips = ips.split(',')
|
||||
global num_workers
|
||||
num_workers = len(ips)
|
||||
ips = [f'{ip}:20000' for ip in ips]
|
||||
os.environ["TF_CONFIG"] = json.dumps({
|
||||
"cluster": {
|
||||
"worker": ips
|
||||
},
|
||||
"task": {"type": "worker", "index": index}
|
||||
})
|
||||
|
||||
|
||||
def setup_config():
|
||||
parser = argparse.ArgumentParser(description='MNIST digits classification using trained model')
|
||||
parser.add_argument('--model-dir', help='Directory to model')
|
||||
parser.add_argument('--data-dir', help='Directory to MNIST dataset')
|
||||
parser.add_argument('--output-dir', help='Directory to save models and logs')
|
||||
parser.add_argument(
|
||||
'--strategy',
|
||||
default='off',
|
||||
choices=['off', 'mirrored', 'multi_worker_mirrored'],
|
||||
help='TensorFlow distributed training strategies'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = setup_config()
|
||||
# tf2 limitation: Collective ops must be configured at program startup
|
||||
strategy = get_strategy(args.strategy)
|
||||
setup_env(args)
|
||||
|
||||
with strategy.scope():
|
||||
test_images = mnist_test_images(batch_size=64 * num_workers, data_dir=args.data_dir)
|
||||
model = tf.keras.models.load_model(args.model_dir)
|
||||
model.summary()
|
||||
|
||||
logits = model.predict(test_images, verbose=2)
|
||||
probabilities = tf.nn.softmax(logits).numpy()
|
||||
predictions = np.argmax(probabilities, 1)
|
||||
|
||||
if args.output_dir:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
prediction_results_file = os.path.join(args.output_dir, "prediction_results.npy")
|
||||
np.save(prediction_results_file, predictions)
|
||||
print(f"Saved prediction results to {prediction_results_file}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("TensorFlow version:", tf.__version__)
|
||||
main()
|
Binary file not shown.
|
@ -0,0 +1,26 @@
|
|||
FROM tensorflow/tensorflow:2.4.3-gpu
|
||||
|
||||
USER root
|
||||
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 SHELL=/bin/bash
|
||||
|
||||
RUN sed -i "s@http://.*archive.ubuntu.com@http://mirrors.tuna.tsinghua.edu.cn@g" /etc/apt/sources.list && \
|
||||
sed -i "s@http://.*security.ubuntu.com@http://mirrors.tuna.tsinghua.edu.cn@g" /etc/apt/sources.list
|
||||
|
||||
RUN rm -rf /var/lib/apt/lists/* \
|
||||
/etc/apt/sources.list.d/cuda.list \
|
||||
/etc/apt/sources.list.d/nvidia-ml.list && \
|
||||
apt-get update && \
|
||||
DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
apt-utils ca-certificates wget curl vim git openssh-server && \
|
||||
ldconfig && \
|
||||
apt-get clean && \
|
||||
apt-get autoremove && \
|
||||
rm -rf /var/lib/apt/lists/* /tmp/* ~/*
|
||||
|
||||
RUN python -m pip --no-cache-dir install -i https://mirrors.ustc.edu.cn/pypi/web/simple \
|
||||
Flask \
|
||||
Pillow
|
||||
|
||||
COPY . /mnist-serving/
|
||||
|
||||
CMD ['cd /mnist-serving; python serving.py']
|
|
@ -0,0 +1,100 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from flask import Flask
|
||||
from flask import jsonify
|
||||
from flask import request
|
||||
|
||||
import utils
|
||||
|
||||
app = Flask(__name__)
|
||||
model = None
|
||||
|
||||
|
||||
def load_model():
|
||||
tf.config.set_soft_device_placement(True)
|
||||
|
||||
# limit the gpu memory usage as much as it need.
|
||||
try:
|
||||
gpus = tf.config.experimental.list_physical_devices('GPU')
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
logical_gpus = tf.config.list_logical_devices('GPU')
|
||||
print(f"Detected {len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
|
||||
global model
|
||||
model = tf.keras.models.load_model(os.environ['MODEL_PATH'])
|
||||
model.summary()
|
||||
return model
|
||||
|
||||
|
||||
@app.route('/')
|
||||
def hello_world():
|
||||
return jsonify({'status': {'code': 0, 'msg': 'success'}, 'data': 'hello world'})
|
||||
|
||||
|
||||
@app.route('/image', methods=['POST'])
|
||||
def digits_classification():
|
||||
global model
|
||||
if model is None:
|
||||
model = load_model()
|
||||
|
||||
# accept files as URL, base64 encoded or file stream
|
||||
if request.content_type and request.content_type.startswith('application/json'):
|
||||
data = request.get_json()
|
||||
image_files = None
|
||||
image_base64s = data.get('base64s')
|
||||
image_urls = data.get('urls')
|
||||
else:
|
||||
# content-type: "application/x-www-form-unlencoded" or "multipart/form-data"
|
||||
data = request.form
|
||||
image_files = request.files.getlist('files')
|
||||
image_base64s = data.getlist('base64s')
|
||||
image_urls = data.getlist('urls')
|
||||
|
||||
# image bytes
|
||||
images, request_image_type = utils.get_request_images_as_file(
|
||||
image_files=image_files,
|
||||
image_base64s=image_base64s,
|
||||
image_urls=image_urls
|
||||
)
|
||||
|
||||
# preprocessing: resize, convert to gray, normalize
|
||||
images = utils.mnist_preprocess(images)
|
||||
|
||||
logits = model.predict(images, verbose=0)
|
||||
probability = tf.nn.softmax(logits).numpy()
|
||||
prediction = np.argmax(probability, 1)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
'status': {'code': 0, 'msg': 'success'},
|
||||
'data': {
|
||||
'probabilities': probability.tolist(),
|
||||
'predictions': prediction.tolist()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.errorhandler(Exception)
|
||||
def handle_unknown_error(e):
|
||||
return jsonify({
|
||||
'status': {'code': 500, 'msg': repr(traceback.format_exception(*sys.exc_info()))},
|
||||
'data': None
|
||||
})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 如果服务运行于 Nginx 之类的代理下,由于代理设置的某些 HTTP 报文头,可能存在请求不能转发的情况
|
||||
# 使用下面的方式能保证正常转发,但是仅在信任代理、信任请求方的情况下使用
|
||||
# https://flask.palletsprojects.com/en/1.1.x/deploying/wsgi-standalone/#proxy-setups
|
||||
# from werkzeug.contrib.fixers import ProxyFix
|
||||
# app.wsgi_app = ProxyFix(app.wsgi_app)
|
||||
model = load_model()
|
||||
app.run(host='0.0.0.0', port=5000)
|
|
@ -0,0 +1,102 @@
|
|||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def download_image_as_bytes(url, max_file_size=0):
|
||||
"""根据 URL 下载图片。
|
||||
|
||||
Args:
|
||||
url: string, 图片URL.
|
||||
max_file_size: integer, 图片文件最大值,0则不限制图片大小。
|
||||
|
||||
Returns:
|
||||
bytes, 图片
|
||||
|
||||
Raises:
|
||||
超过最大图片文件大小限制
|
||||
ImageUrlDownloadFailError 图片通过 URL 下载失败
|
||||
"""
|
||||
try:
|
||||
# requests, streaming request and streaming uploads
|
||||
# https://requests.readthedocs.io/en/stable/user/advanced/
|
||||
res = requests.get(url, allow_redirects=True, stream=True)
|
||||
if res.status_code != 200:
|
||||
raise Exception('ImageUrlDownloadFailError')
|
||||
except:
|
||||
raise Exception('ImageUrlDownloadFailError')
|
||||
|
||||
file_size = int(res.headers['content-length'])
|
||||
if 0 < max_file_size < file_size:
|
||||
raise Exception(
|
||||
f'image size must less than {max_file_size / 1024 / 1024}MB, '
|
||||
f'but get {file_size / 1024 / 1024}MB, url is {url}'
|
||||
)
|
||||
|
||||
image = res.content
|
||||
res.close()
|
||||
return image
|
||||
|
||||
|
||||
def get_request_images_as_file(
|
||||
image_files=None, image_base64s=None, image_urls=None, max_file_size=0
|
||||
):
|
||||
"""由于允许客户端以 file, base64, url 三种方式发送图片,但同一请求只能以其中一种方式传输图片,
|
||||
判断上传图片方式,并将数据转换成 bytes 类型的图片文件。
|
||||
|
||||
Args:
|
||||
image_files: list of object which implements `.read()`, `.tell()`, `.seek()`, such as
|
||||
werkzeug.FileStorage, 图片文件
|
||||
image_base64s: list of string, base64 编码后的图片文件
|
||||
image_urls: list of string, 图片URL地址
|
||||
max_file_size: integer, 单张图片文件大小最大值,如果超过了会抛 InvalidContentLengthError 异常
|
||||
|
||||
Returns:
|
||||
Tuple[List[bytes], string], 图片文件对象,传入的图片数据类型,包括 "file", "base64", "url".
|
||||
|
||||
Raises:
|
||||
InvalidContentLengthError, 超过最大图片文件大小限制
|
||||
InvalidArgumentError, 图片参数或值为空
|
||||
ImageUrlDownloadFailError, 图片通过 URL 下载失败
|
||||
"""
|
||||
if image_files and len(image_files) > 0 and None not in image_files:
|
||||
image_bytes = [f.read() for f in image_files]
|
||||
type_ = 'file'
|
||||
|
||||
elif image_base64s and len(image_base64s) > 0 and None not in image_base64s:
|
||||
image_bytes = [base64.b64decode(s) for s in image_base64s]
|
||||
type_ = 'base64'
|
||||
|
||||
elif image_urls and len(image_urls) > 0 and None not in image_urls:
|
||||
image_bytes = [download_image_as_bytes(url, max_file_size=max_file_size)
|
||||
for url in image_urls]
|
||||
type_ = 'url'
|
||||
|
||||
else:
|
||||
raise Exception('no image')
|
||||
|
||||
if max_file_size > 0:
|
||||
for idx, f in enumerate(image_bytes):
|
||||
if len(f) > max_file_size:
|
||||
raise Exception(f'image size must less than {max_file_size / 1024 / 1024}MB, '
|
||||
f'but image-{idx + 1} is {len(f) / 1024 / 1024}MB.')
|
||||
return image_bytes, type_
|
||||
|
||||
|
||||
def mnist_preprocess(images):
|
||||
results = []
|
||||
for img in images:
|
||||
if isinstance(img, bytes):
|
||||
image_file = BytesIO(img)
|
||||
else:
|
||||
image_file = img
|
||||
|
||||
image = Image.open(image_file)
|
||||
image = image.resize((28, 28))
|
||||
image = image.convert('L')
|
||||
image = np.asarray(image)[:, :] / 255.
|
||||
results.append(image)
|
||||
return np.asarray(results)
|
|
@ -0,0 +1,148 @@
|
|||
import argparse
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
num_workers = 1
|
||||
|
||||
|
||||
def load_data(data_dir):
|
||||
files = [
|
||||
'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
|
||||
't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
|
||||
]
|
||||
paths = []
|
||||
for f in files:
|
||||
paths.append(os.path.join(data_dir, f))
|
||||
with gzip.open(paths[0], 'rb') as f:
|
||||
y_train = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
with gzip.open(paths[1], 'rb') as f:
|
||||
x_train = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
|
||||
with gzip.open(paths[2], 'rb') as f:
|
||||
y_test = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
with gzip.open(paths[3], 'rb') as f:
|
||||
x_test = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
|
||||
return (x_train, y_train), (x_test, y_test)
|
||||
|
||||
|
||||
def mnist_dataset(batch_size=64, data_dir=None):
|
||||
# load dataset
|
||||
if data_dir:
|
||||
print(f'Loading mnist data from {data_dir}')
|
||||
(x_train, y_train), (x_test, y_test) = load_data(data_dir)
|
||||
else:
|
||||
print('Loading mnist data from tf.keras.datasets.mnist')
|
||||
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
|
||||
x_train, x_test = x_train / 255.0, x_test / 255.0
|
||||
|
||||
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
|
||||
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)
|
||||
|
||||
return train_dataset, test_dataset
|
||||
|
||||
|
||||
def get_strategy(strategy='off'):
|
||||
strategy = strategy.lower()
|
||||
# multiple nodes, every nodes have multiple GPUs
|
||||
if strategy == "multi_worker_mirrored":
|
||||
return tf.distribute.experimental.MultiWorkerMirroredStrategy()
|
||||
# single node with multiple GPUs
|
||||
if strategy == "mirrored":
|
||||
return tf.distribute.MirroredStrategy()
|
||||
# single node with single GPU
|
||||
return tf.distribute.get_strategy()
|
||||
|
||||
|
||||
def setup_env(args):
|
||||
tf.config.set_soft_device_placement(True)
|
||||
|
||||
# limit the gpu memory usage as much as it need.
|
||||
try:
|
||||
gpus = tf.config.experimental.list_physical_devices('GPU')
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
logical_gpus = tf.config.list_logical_devices('GPU')
|
||||
print(f"Detected {len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
|
||||
if args.strategy == 'multi_worker_mirrored':
|
||||
index = int(os.environ['VK_TASK_INDEX'])
|
||||
task_name = os.environ["VC_TASK_NAME"].upper()
|
||||
ips = os.environ[f'VC_{task_name}_HOSTS']
|
||||
ips = ips.split(',')
|
||||
global num_workers
|
||||
num_workers = len(ips)
|
||||
ips = [f'{ip}:20000' for ip in ips]
|
||||
os.environ["TF_CONFIG"] = json.dumps({
|
||||
"cluster": {
|
||||
"worker": ips
|
||||
},
|
||||
"task": {"type": "worker", "index": index}
|
||||
})
|
||||
|
||||
|
||||
def setup_config():
|
||||
parser = argparse.ArgumentParser(description='Train MNIST digits classification')
|
||||
parser.add_argument('--data-dir', help='Directory to MNIST dataset')
|
||||
parser.add_argument('--output-dir', help='Directory to save models and logs')
|
||||
parser.add_argument('--epochs', default=2, help='Number of epochs')
|
||||
parser.add_argument('--eval', action='store_true', help='whether do evaluation after training finished')
|
||||
parser.add_argument(
|
||||
'--strategy',
|
||||
default='off',
|
||||
choices=['off', 'mirrored', 'multi_worker_mirrored'],
|
||||
help='TensorFlow distributed training strategies'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = setup_config()
|
||||
# tf2 limitation: Collective ops must be configured at program startup
|
||||
strategy = get_strategy(args.strategy)
|
||||
setup_env(args)
|
||||
|
||||
with strategy.scope():
|
||||
# build dataset
|
||||
train_dataset, test_dataset = mnist_dataset(
|
||||
batch_size=64 * num_workers,
|
||||
data_dir=args.data_dir
|
||||
)
|
||||
|
||||
# build model
|
||||
model = tf.keras.models.Sequential([
|
||||
tf.keras.layers.Flatten(input_shape=(28, 28)),
|
||||
tf.keras.layers.Dense(128, activation='relu'),
|
||||
tf.keras.layers.Dropout(0.2),
|
||||
tf.keras.layers.Dense(10)
|
||||
])
|
||||
model.compile(
|
||||
optimizer='adam',
|
||||
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
|
||||
metrics=['accuracy']
|
||||
)
|
||||
model.summary()
|
||||
|
||||
# training
|
||||
model.fit(train_dataset, epochs=args.epochs, steps_per_epoch=70)
|
||||
|
||||
# evaluation
|
||||
if args.eval:
|
||||
model.evaluate(test_dataset, verbose=2)
|
||||
|
||||
# save model
|
||||
if args.output_dir:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
model.save(args.output_dir)
|
||||
print(f'Saved model to {args.output_dir}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# python train.py --data-dir=/path/to/MNIST/dataset --output-dir=/path/to/output
|
||||
print("TensorFlow version:", tf.__version__)
|
||||
main()
|
Loading…
Reference in New Issue