101 lines
3.0 KiB
Python
101 lines
3.0 KiB
Python
import os
|
|
import sys
|
|
import traceback
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from flask import Flask
|
|
from flask import jsonify
|
|
from flask import request
|
|
|
|
import utils
|
|
|
|
# Keras model shared by all requests; it stays None until load_model() runs
# (at script start-up, or lazily on the first /predict request).
model = None

# The Flask application every route in this module is registered on.
app = Flask(__name__)
|
|
|
|
|
|
def load_model():
    """Configure TensorFlow devices and load the Keras model from MODEL_PATH.

    Before loading, memory growth is enabled on every visible GPU so that
    TensorFlow only takes as much GPU memory as it needs. The loaded model is
    stored in the module-level ``model`` global and also returned.

    Returns:
        The model produced by ``tf.keras.models.load_model``.

    Raises:
        KeyError: if the MODEL_PATH environment variable is unset or empty.
    """
    global model

    # Fail fast with an actionable message instead of a bare KeyError('MODEL_PATH')
    # (or an obscure load error when the variable is set but empty).
    model_path = os.environ.get('MODEL_PATH')
    if not model_path:
        raise KeyError('MODEL_PATH environment variable must be set to the saved model path')

    # Allow TensorFlow to fall back to another device when an op cannot run
    # on the one it was requested on.
    tf.config.set_soft_device_placement(True)

    # Limit GPU memory usage to what the process actually needs.
    try:
        gpus = tf.config.list_physical_devices('GPU')
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f"Detected {len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as e:
        # Per the TF docs, set_memory_growth raises RuntimeError when the GPUs
        # were already initialized; the error is reported and loading proceeds.
        print(e)

    model = tf.keras.models.load_model(model_path)
    model.summary()
    return model
|
|
|
|
|
|
@app.route('/')
def hello_world():
    """Answer with a constant success envelope whose data is 'hello world'.

    It never touches the model, so it can serve as a quick liveness check.
    """
    ok_status = {'code': 0, 'msg': 'success'}
    body = {'status': ok_status, 'data': 'hello world'}
    return jsonify(body)
|
|
|
|
|
|
def _error_response(code, msg):
    """Build the app's standard JSON envelope for a rejected request.

    Args:
        code: value placed in ``status.code``. The HTTP status of the response
            itself stays at jsonify's default (200), like the other responses.
        msg: human-readable reason placed in ``status.msg``.
    """
    return jsonify({'status': {'code': code, 'msg': msg}, 'data': None})


@app.route('/predict', methods=['POST'])
def digits_classification():
    """Classify digit images posted to ``/predict``.

    Images may be sent as URLs, base64 strings or uploaded files:

    * ``application/json`` body: an object with optional ``base64s`` and
      ``urls`` entries (no file uploads in this form).
    * ``application/x-www-form-urlencoded`` or ``multipart/form-data``:
      repeated ``base64s`` / ``urls`` fields and ``files`` uploads.

    Returns:
        On success, ``{'status': {'code': 0, 'msg': 'success'}, 'data':
        {'probabilities': [...], 'predictions': [...]}}``. These hold the
        softmax probabilities and their argmax (over axis 1) for each image.
        A request whose JSON body is not an object, or that carries no
        images at all, gets ``status.code`` 400 and ``data`` None.
    """
    global model
    # NOTE(review): this lazy initialisation is not guarded by a lock, so on a
    # threaded server concurrent first requests may each call load_model().
    if model is None:
        model = load_model()

    # Accept images as URLs, base64-encoded strings or file streams.
    if request.content_type and request.content_type.startswith('application/json'):
        # silent=True yields None for an unparsable body instead of raising.
        # That case, and JSON that is not an object (null, a list, ...) and so
        # has no .get(), is reported as a client error rather than a crash.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return _error_response(400, 'request body must be a JSON object')
        image_files = None
        image_base64s = data.get('base64s')
        image_urls = data.get('urls')
    else:
        # content-type: "application/x-www-form-urlencoded" or "multipart/form-data"
        data = request.form
        image_files = request.files.getlist('files')
        image_base64s = data.getlist('base64s')
        image_urls = data.getlist('urls')

    if not (image_files or image_base64s or image_urls):
        return _error_response(400, "no images provided; send 'files', 'base64s' or 'urls'")

    # Image bytes. The second value returned by utils (apparently the request
    # image type) is not needed here.
    images, _ = utils.get_request_images_as_file(
        image_files=image_files,
        image_base64s=image_base64s,
        image_urls=image_urls,
    )

    # preprocessing: resize, convert to gray, normalize
    images = utils.mnist_preprocess(images)

    # NOTE(review): assumes the model outputs unnormalised logits. If its last
    # layer already applies softmax, this second softmax distorts the reported
    # probabilities (the argmax is unaffected). Confirm against the model.
    logits = model.predict(images, verbose=0)
    probability = tf.nn.softmax(logits).numpy()
    prediction = np.argmax(probability, axis=1)

    return jsonify({
        'status': {'code': 0, 'msg': 'success'},
        'data': {
            'probabilities': probability.tolist(),
            'predictions': prediction.tolist(),
        },
    })
|
|
|
|
|
|
@app.errorhandler(Exception)
def handle_unknown_error(e):
    """Turn an unhandled exception into the app's JSON error envelope.

    Werkzeug HTTP errors (404 Not Found, 405 Method Not Allowed, 400 Bad
    Request, ...) are returned unchanged so clients receive the real HTTP
    status instead of a fake internal error. Every other exception is
    reported with ``status.code`` 500 and the formatted traceback as
    ``status.msg``. The HTTP status of that response is jsonify's default (200).
    """
    # Duck-typed stand-in for isinstance(e, werkzeug.exceptions.HTTPException)
    # that avoids importing werkzeug directly: HTTP errors carry an int status
    # ``code`` and can render themselves via ``get_response``. Flask's docs
    # recommend passing them through from a generic Exception handler.
    if isinstance(getattr(e, 'code', None), int) and callable(getattr(e, 'get_response', None)):
        return e

    # NOTE(review): the full server traceback is sent to the client, which
    # discloses file paths and code details. Consider logging it server-side
    # and returning a generic message in production.
    return jsonify({
        'status': {'code': 500, 'msg': repr(traceback.format_exception(*sys.exc_info()))},
        'data': None
    })
|
|
|
|
|
|
if __name__ == '__main__':
    # If the service runs behind a proxy such as Nginx, some HTTP headers set
    # by the proxy can prevent requests from being forwarded correctly.
    # Wrapping the app as below ensures correct forwarding, but only use it
    # when both the proxy and the requesting clients are trusted.
    # https://flask.palletsprojects.com/en/1.1.x/deploying/wsgi-standalone/#proxy-setups
    # werkzeug.contrib was removed in Werkzeug 1.0; ProxyFix now lives in
    # werkzeug.middleware.proxy_fix:
    # from werkzeug.middleware.proxy_fix import ProxyFix
    # app.wsgi_app = ProxyFix(app.wsgi_app)

    # Load eagerly so the first /predict request does not pay the load cost.
    model = load_model()
    # Listen on all interfaces so the server is reachable from outside the host.
    app.run(host='0.0.0.0', port=5000)
|