From e987397e42d45667c8361d0bd16d52061b086fb5 Mon Sep 17 00:00:00 2001 From: Sauyon Lee <2347889+sauyon@users.noreply.github.com> Date: Tue, 25 Oct 2022 00:07:40 -0700 Subject: [PATCH] fix(multipart): use field names in request (#3135) Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> --- .../_internal/io_descriptors/multipart.py | 22 +++++++++---- tests/e2e/bento_server_http/service.py | 14 +++++++- tests/e2e/bento_server_http/tests/test_io.py | 32 ++++++++++++++----- 3 files changed, 52 insertions(+), 16 deletions(-) diff --git a/src/bentoml/_internal/io_descriptors/multipart.py b/src/bentoml/_internal/io_descriptors/multipart.py index 1c9349bb3a1..4f4e4aefc76 100644 --- a/src/bentoml/_internal/io_descriptors/multipart.py +++ b/src/bentoml/_internal/io_descriptors/multipart.py @@ -238,13 +238,21 @@ async def from_http_request(self, request: Request) -> dict[str, t.Any]: f"{self.__class__.__name__} only accepts `multipart/form-data` as Content-Type header, got {ctype} instead." ) from None - to_populate = zip( - self._inputs.values(), (await populate_multipart_requests(request)).values() - ) - reqs = await asyncio.gather( - *tuple(io_.from_http_request(req) for io_, req in to_populate) - ) - return dict(zip(self._inputs, reqs)) + form_values = await populate_multipart_requests(request) + + res = {} + for field, descriptor in self._inputs.items(): + if field not in form_values: + break + res[field] = descriptor.from_http_request(form_values[field]) + else: # NOTE: This is similar to goto, when there is no break. + to_populate = zip(self._inputs.values(), form_values.values()) + reqs = await asyncio.gather( + *tuple(io_.from_http_request(req) for io_, req in to_populate) + ) + res = dict(zip(self._inputs, reqs)) + + return res async def to_http_response( self, obj: dict[str, t.Any], ctx: Context | None = None diff --git a/tests/e2e/bento_server_http/service.py b/tests/e2e/bento_server_http/service.py index b6acaf3e9a7..8785a7ec1e4 100644 --- a/tests/e2e/bento_server_http/service.py +++ b/tests/e2e/bento_server_http/service.py @@ -124,7 +124,19 @@ async def echo_image(f: PILImage) -> NDArray[t.Any]: input=Multipart(original=Image(), compared=Image()), output=Multipart(img1=Image(), img2=Image()), ) -async def predict_multi_images(original: dict[str, Image], compared: dict[str, Image]): +async def predict_multi_images(original: Image, compared: Image): + output_array = await py_model.predict_multi_ndarray.async_run( + np.array(original), np.array(compared) + ) + img = fromarray(output_array) + return dict(img1=img, img2=img) + + +@svc.api( + input=Multipart(original=Image(), compared=Image()), + output=Multipart(img1=Image(), img2=Image()), +) +async def predict_different_args(compared: Image, original: Image): output_array = await py_model.predict_multi_ndarray.async_run( np.array(original), np.array(compared) ) diff --git a/tests/e2e/bento_server_http/tests/test_io.py b/tests/e2e/bento_server_http/tests/test_io.py index a826f7d19a6..68fdb7e7acd 100644 --- a/tests/e2e/bento_server_http/tests/test_io.py +++ b/tests/e2e/bento_server_http/tests/test_io.py @@ -222,18 +222,24 @@ async def test_image(host: str, img_file: str): ) +@pytest.fixture(name="img_form_data") +def fixture_img_form_data(img_file: str): + with open(img_file, "rb") as f1, open(img_file, "rb") as f2: + form = aiohttp.FormData() + form.add_field("original", f1.read(), content_type="image/bmp") + form.add_field("compared", f2.read(), content_type="image/bmp") + yield form + + @pytest.mark.asyncio -async def test_multipart_image_io(host: str, img_file: str): +async def test_multipart_image_io(host: str, img_form_data: aiohttp.FormData): from starlette.datastructures import UploadFile - with open(img_file, "rb") as f1: - with open(img_file, "rb") as f2: - form = aiohttp.FormData() - form.add_field("original", f1.read(), content_type="image/bmp") - form.add_field("compared", f2.read(), content_type="image/bmp") - _, headers, body = await async_request( - "POST", f"http://{host}/predict_multi_images", data=form, assert_status=200 + "POST", + f"http://{host}/predict_multi_images", + data=img_form_data, + assert_status=200, ) form = await parse_multipart_form(headers=headers, body=body) @@ -241,3 +247,13 @@ async def test_multipart_image_io(host: str, img_file: str): assert isinstance(v, UploadFile) img = PILImage.open(v.file) assert np.array(img).shape == (10, 10, 3) + + +@pytest.mark.asyncio +async def test_multipart_image_io(host: str, img_form_data: aiohttp.FormData): + await async_request( + "POST", + f"http://{host}/predict_different_args", + data=img_form_data, + assert_status=200, + )