103 lines
3.5 KiB
Python
103 lines
3.5 KiB
Python
|
import base64
|
|||
|
from io import BytesIO
|
|||
|
|
|||
|
import numpy as np
|
|||
|
import requests
|
|||
|
from PIL import Image
|
|||
|
|
|||
|
|
|||
|
def download_image_as_bytes(url, max_file_size=0, timeout=30):
    """Download an image from a URL and return its raw bytes.

    The body is streamed in chunks, so an oversized file is rejected as soon
    as the limit is crossed. This holds even when the server omits or
    misreports the ``Content-Length`` header.

    Args:
        url: string, image URL.
        max_file_size: integer, maximum image file size in bytes; 0 means
            no limit.
        timeout: seconds (or a ``(connect, read)`` tuple) forwarded to
            ``requests.get`` so a stalled server cannot block the caller
            forever; ``None`` disables the timeout.

    Returns:
        bytes, the image file content.

    Raises:
        Exception('ImageUrlDownloadFailError'): the request failed, the
            server answered with a non-200 status, or the connection broke
            while the body was being read.
        Exception: the image exceeds ``max_file_size`` (the message starts
            with "image size must less than").
    """
    try:
        # stream=True defers downloading the body so its size can be checked
        # before and while it is read.
        # https://requests.readthedocs.io/en/stable/user/advanced/
        res = requests.get(url, allow_redirects=True, stream=True, timeout=timeout)
    except requests.RequestException as err:
        raise Exception('ImageUrlDownloadFailError') from err

    # Always release the connection, including on every error path below.
    try:
        if res.status_code != 200:
            raise Exception('ImageUrlDownloadFailError')

        # Cheap early rejection when the server announces the size up front.
        declared_size = _declared_content_length(res)
        if declared_size is not None and 0 < max_file_size < declared_size:
            raise Exception(
                f'image size must less than {max_file_size / 1024 / 1024}MB, '
                f'but get {declared_size / 1024 / 1024}MB, url is {url}'
            )

        return _read_body_with_limit(res, max_file_size, url)
    finally:
        res.close()


def _declared_content_length(res):
    """Return the response's Content-Length as an int, or None if absent or malformed."""
    raw = res.headers.get('content-length')
    if raw is None:
        return None
    try:
        return int(raw)
    except ValueError:
        # A malformed header is ignored; the streamed byte count still applies.
        return None


def _read_body_with_limit(res, max_file_size, url, chunk_size=64 * 1024):
    """Read the streamed body of ``res``, aborting once it exceeds ``max_file_size`` bytes (0 = no limit)."""
    chunks = []
    received = 0
    try:
        for chunk in res.iter_content(chunk_size=chunk_size):
            received += len(chunk)
            if 0 < max_file_size < received:
                # `received` is only a lower bound on the real size here.
                raise Exception(
                    f'image size must less than {max_file_size / 1024 / 1024}MB, '
                    f'but get at least {received / 1024 / 1024}MB, url is {url}'
                )
            chunks.append(chunk)
    except requests.RequestException as err:
        # Broken connections / bad encodings while reading the body.
        raise Exception('ImageUrlDownloadFailError') from err
    return b''.join(chunks)
|
|||
|
|
|||
|
|
|||
|
def get_request_images_as_file(
    image_files=None, image_base64s=None, image_urls=None, max_file_size=0
):
    """Normalize images sent by a client into a list of raw bytes.

    A client may send images as uploaded files, as base64 strings, or as
    URLs, but a single request uses only one of these channels. The channels
    are tried in that order. The first argument that is non-empty and
    contains no ``None`` wins, and each of its items is converted to
    ``bytes``.

    Args:
        image_files: list of file-like objects exposing ``.read()`` (for
            example ``werkzeug.FileStorage``). Each object is read from its
            current position to the end.
        image_base64s: list of base64-encoded strings, decoded with
            ``base64.b64decode``.
        image_urls: list of image URLs, fetched with
            ``download_image_as_bytes`` (which also receives
            ``max_file_size``).
        max_file_size: integer, per-image size limit in bytes; a value
            <= 0 disables the check.

    Returns:
        Tuple[List[bytes], str]: the image contents, plus the channel they
        came from, one of ``"file"``, ``"base64"`` or ``"url"``.

    Raises:
        Exception('no image'): none of the three arguments supplied usable data.
        Exception: an image is larger than ``max_file_size``. The message
            names the 1-based position of the offending image.
        Exception('ImageUrlDownloadFailError'): propagated from
            ``download_image_as_bytes`` when a URL cannot be fetched.
    """
    # (channel label, values supplied by the caller, how to turn one value into bytes),
    # listed in priority order.
    channels = (
        ('file', image_files, lambda stream: stream.read()),
        ('base64', image_base64s, base64.b64decode),
        ('url', image_urls,
         lambda link: download_image_as_bytes(link, max_file_size=max_file_size)),
    )

    for label, supplied, load in channels:
        if supplied and len(supplied) > 0 and None not in supplied:
            blobs = [load(item) for item in supplied]
            source = label
            break
    else:
        raise Exception('no image')

    if max_file_size > 0:
        for position, blob in enumerate(blobs, start=1):
            if len(blob) > max_file_size:
                raise Exception(f'image size must less than {max_file_size / 1024 / 1024}MB, '
                                f'but image-{position} is {len(blob) / 1024 / 1024}MB.')

    return blobs, source
|
|||
|
|
|||
|
|
|||
|
def mnist_preprocess(images):
    """Convert images into an MNIST-style batch of normalized 28x28 grayscale arrays.

    Args:
        images: iterable whose items are either raw encoded image data
            (``bytes``, ``bytearray`` or ``memoryview``) or anything
            ``PIL.Image.open`` accepts directly (a path or a binary file
            object).

    Returns:
        numpy.ndarray of shape ``(N, 28, 28)`` and dtype float64, where N is
        the number of input images. Pixel values are scaled into ``[0, 1]``.
        An empty input yields shape ``(0, 28, 28)``.
    """
    results = []
    for img in images:
        # Raw encoded bytes must be wrapped in a file object for PIL.
        if isinstance(img, (bytes, bytearray, memoryview)):
            source = BytesIO(img)
        else:
            source = img

        # `with` closes the underlying file when PIL opened it from a path.
        with Image.open(source) as image:
            # Convert to grayscale before resizing. PIL always uses nearest
            # neighbour when resizing "P"/"1" mode images, and a single
            # channel is cheaper to resample than three.
            gray = image.convert('L').resize((28, 28))
        results.append(np.asarray(gray) / 255.)

    # reshape keeps the (N, 28, 28) layout for an empty input, where
    # np.asarray([]) alone would produce shape (0,).
    return np.asarray(results, dtype=np.float64).reshape(-1, 28, 28)
|