dense-mnist-tf/serving/utils.py

103 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
from io import BytesIO
import numpy as np
import requests
from PIL import Image
def download_image_as_bytes(url, max_file_size=0):
"""根据 URL 下载图片。
Args:
url: string, 图片URL.
max_file_size: integer, 图片文件最大值0则不限制图片大小。
Returns:
bytes, 图片
Raises:
超过最大图片文件大小限制
ImageUrlDownloadFailError 图片通过 URL 下载失败
"""
try:
# requests, streaming request and streaming uploads
# https://requests.readthedocs.io/en/stable/user/advanced/
res = requests.get(url, allow_redirects=True, stream=True)
if res.status_code != 200:
raise Exception('ImageUrlDownloadFailError')
except:
raise Exception('ImageUrlDownloadFailError')
file_size = int(res.headers['content-length'])
if 0 < max_file_size < file_size:
raise Exception(
f'image size must less than {max_file_size / 1024 / 1024}MB, '
f'but get {file_size / 1024 / 1024}MB, url is {url}'
)
image = res.content
res.close()
return image
def get_request_images_as_file(
image_files=None, image_base64s=None, image_urls=None, max_file_size=0
):
"""由于允许客户端以 file, base64, url 三种方式发送图片,但同一请求只能以其中一种方式传输图片,
判断上传图片方式,并将数据转换成 bytes 类型的图片文件。
Args:
image_files: list of object which implements `.read()`, `.tell()`, `.seek()`, such as
werkzeug.FileStorage, 图片文件
image_base64s: list of string, base64 编码后的图片文件
image_urls: list of string, 图片URL地址
max_file_size: integer, 单张图片文件大小最大值,如果超过了会抛 InvalidContentLengthError 异常
Returns:
Tuple[List[bytes], string], 图片文件对象,传入的图片数据类型,包括 "file", "base64", "url".
Raises:
InvalidContentLengthError, 超过最大图片文件大小限制
InvalidArgumentError, 图片参数或值为空
ImageUrlDownloadFailError, 图片通过 URL 下载失败
"""
if image_files and len(image_files) > 0 and None not in image_files:
image_bytes = [f.read() for f in image_files]
type_ = 'file'
elif image_base64s and len(image_base64s) > 0 and None not in image_base64s:
image_bytes = [base64.b64decode(s) for s in image_base64s]
type_ = 'base64'
elif image_urls and len(image_urls) > 0 and None not in image_urls:
image_bytes = [download_image_as_bytes(url, max_file_size=max_file_size)
for url in image_urls]
type_ = 'url'
else:
raise Exception('no image')
if max_file_size > 0:
for idx, f in enumerate(image_bytes):
if len(f) > max_file_size:
raise Exception(f'image size must less than {max_file_size / 1024 / 1024}MB, '
f'but image-{idx + 1} is {len(f) / 1024 / 1024}MB.')
return image_bytes, type_
def mnist_preprocess(images):
results = []
for img in images:
if isinstance(img, bytes):
image_file = BytesIO(img)
else:
image_file = img
image = Image.open(image_file)
image = image.resize((28, 28))
image = image.convert('L')
image = np.asarray(image)[:, :] / 255.
results.append(image)
return np.asarray(results)