Skip to content

Commit

Permalink
feat(io/image): allow restricting mime types (#2999)
Browse files Browse the repository at this point in the history
  • Loading branch information
sauyon committed Sep 18, 2022
1 parent 9d86cef commit fd39950
Showing 1 changed file with 102 additions and 18 deletions.
120 changes: 102 additions & 18 deletions bentoml/_internal/io_descriptors/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import io
import typing as t
import functools
from typing import TYPE_CHECKING
from urllib.parse import quote

from starlette.requests import Request
from multipart.multipart import parse_options_header
from starlette.responses import Response
from starlette.datastructures import UploadFile

from .base import IODescriptor
from ..types import LazyType
Expand Down Expand Up @@ -55,6 +57,24 @@
DEFAULT_PIL_MODE = "RGB"


MIME_EXT_MAPPING: dict[str, str] = None # type: ignore (lazy constant)


@functools.lru_cache(maxsize=1)
def initialize_pillow():
global MIME_EXT_MAPPING # pylint: disable=global-statement

try:
import PIL.Image
except ImportError:
raise InternalServerError(
"`Pillow` is required to use {__name__}\n Instructions: `pip install -U Pillow`"
)

PIL.Image.init()
MIME_EXT_MAPPING = {v: k for k, v in PIL.Image.MIME.items()} # type: ignore (lazy constant)


class Image(IODescriptor[ImageType]):
"""
:obj:`Image` defines API specification for the inputs/outputs of a Service, where either
Expand Down Expand Up @@ -135,36 +155,50 @@ async def predict_image(f: Image) -> NDArray[Any]:
Args:
pilmode: Color mode for PIL. Default to ``RGB``.
mime_type: Return MIME type of the :code:`starlette.response.Response`, only available when used as output descriptor.
mime_type: The MIME type of the file type that this descriptor should return. Only relevant when used as an output descriptor.
allowed_mime_types: A list of MIME types to restrict input to.
Returns:
:obj:`Image`: IO Descriptor that either a :code:`PIL.Image.Image` or a :code:`np.ndarray` representing an image.
"""

MIME_EXT_MAPPING: dict[str, str] = {}

_proto_fields = ("file",)

def __init__(
self,
pilmode: _Mode | None = DEFAULT_PIL_MODE,
mime_type: str = "image/jpeg",
*,
allowed_mime_types: t.Iterable[str] | None = None,
):
PIL.Image.init()
self.MIME_EXT_MAPPING.update({v: k for k, v in PIL.Image.MIME.items()})
initialize_pillow()

if mime_type.lower() not in self.MIME_EXT_MAPPING: # pragma: no cover
raise InvalidArgument(
f"Invalid Image mime_type '{mime_type}'. Supported mime types are {', '.join(PIL.Image.MIME.values())}."
) from None
if pilmode is not None and pilmode not in PIL.Image.MODES: # pragma: no cover
raise InvalidArgument(
f"Invalid Image pilmode '{pilmode}'. Supported PIL modes are {', '.join(PIL.Image.MODES)}."
) from None

self._mime_type = mime_type.lower()
self._allowed_mimes: set[str] = (
set(MIME_EXT_MAPPING.keys())
if allowed_mime_types is None
else {mtype.lower() for mtype in allowed_mime_types}
)
self._allow_all_images = allowed_mime_types is None

if self._mime_type not in MIME_EXT_MAPPING: # pragma: no cover
raise InvalidArgument(
f"Invalid Image mime_type '{mime_type}'; supported mime types are {', '.join(PIL.Image.MIME.values())} "
)

for mtype in self._allowed_mimes:
if mtype not in MIME_EXT_MAPPING: # pragma: no cover
raise InvalidArgument(
f"Invalid Image MIME in allowed_mime_types: '{mtype}'; supported mime types are {', '.join(PIL.Image.MIME.values())} "
)

self._pilmode: _Mode | None = pilmode
self._format = self.MIME_EXT_MAPPING[mime_type]
self._format: str = MIME_EXT_MAPPING[self._mime_type]

def input_type(self) -> UnionType:
return ImageType
Expand All @@ -177,7 +211,10 @@ def openapi_components(self) -> dict[str, t.Any] | None:

def openapi_request_body(self) -> RequestBody:
return RequestBody(
content={self._mime_type: MediaType(schema=self.openapi_schema())},
content={
mtype: MediaType(schema=self.openapi_schema())
for mtype in self._allowed_mimes
},
required=True,
)

Expand All @@ -190,19 +227,66 @@ def openapi_responses(self) -> OpenAPIResponse:
async def from_http_request(self, request: Request) -> ImageType:
content_type, _ = parse_options_header(request.headers["content-type"])
mime_type = content_type.decode().lower()

bytes_: bytes | str | None = None

if mime_type == "multipart/form-data":
form = await request.form()
bytes_ = await next(iter(form.values())).read()
elif mime_type.startswith("image/") or mime_type == self._mime_type:

found_mimes: list[str] = []

for val in form.values():
val_content_type = val.content_type # type: ignore (bad starlette types)
if isinstance(val, UploadFile):
found_mimes.append(val_content_type)

if self._allowed_mimes is None:
if (
val_content_type in MIME_EXT_MAPPING
or val_content_type.startswith("image/")
):
bytes_ = await val.read()
break
elif val_content_type in self._allowed_mimes:
bytes_ = await val.read()
break
else:
if len(found_mimes) == 0:
raise BadInput("no image file found in multipart form")
else:
if self._allowed_mimes is None:
raise BadInput(
f"no multipart image file (supported images are: {', '.join(MIME_EXT_MAPPING.keys())}, or 'image/*'), got files with content types {', '.join(found_mimes)}"
)
else:
raise BadInput(
f"no multipart image file (allowed mime types are: {', '.join(self._allowed_mimes)}), got files with content types {', '.join(found_mimes)}"
)

elif self._allowed_mimes is None:
if mime_type in MIME_EXT_MAPPING or mime_type.startswith("image/"):
bytes_ = await request.body()
elif mime_type in self._allowed_mimes:
bytes_ = await request.body()
else:
raise BadInput(
f"{self.__class__.__name__} should get 'multipart/form-data', '{self._mime_type}' or 'image/*', got '{content_type}' instead."
)
if self._allowed_mimes is None:
raise BadInput(
f"unsupported mime type {mime_type}; supported mime types are: {', '.join(MIME_EXT_MAPPING.keys())}, or 'image/*'"
)
else:
raise BadInput(
f"mime type {mime_type} is not allowed, allowed mime types are: {', '.join(self._allowed_mimes)}"
)

assert bytes_ is not None

if isinstance(bytes_, str):
bytes_ = bytes(bytes_, "UTF-8")

try:
return PIL.Image.open(io.BytesIO(bytes_))
except PIL.UnidentifiedImageError as e:
raise BadInput(f"Failed reading image file uploaded: {e}") from None
except PIL.UnidentifiedImageError: # type: ignore (bad pillow types)
raise BadInput("Failed to parse uploaded image file") from None

async def to_http_response(
self, obj: ImageType, ctx: Context | None = None
Expand Down

0 comments on commit fd39950

Please sign in to comment.