Skip to content

Commit

Permalink
fix(multipart): use field names in request (#3135)
Browse files Browse the repository at this point in the history
Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
sauyon and aarnphm committed Oct 25, 2022
1 parent 5b7bcaa commit e987397
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 16 deletions.
22 changes: 15 additions & 7 deletions src/bentoml/_internal/io_descriptors/multipart.py
Expand Up @@ -238,13 +238,21 @@ async def from_http_request(self, request: Request) -> dict[str, t.Any]:
f"{self.__class__.__name__} only accepts `multipart/form-data` as Content-Type header, got {ctype} instead."
) from None

to_populate = zip(
self._inputs.values(), (await populate_multipart_requests(request)).values()
)
reqs = await asyncio.gather(
*tuple(io_.from_http_request(req) for io_, req in to_populate)
)
return dict(zip(self._inputs, reqs))
form_values = await populate_multipart_requests(request)

res = {}
for field, descriptor in self._inputs.items():
if field not in form_values:
break
res[field] = descriptor.from_http_request(form_values[field])
else: # NOTE: This is similar to goto, when there is no break.
to_populate = zip(self._inputs.values(), form_values.values())
reqs = await asyncio.gather(
*tuple(io_.from_http_request(req) for io_, req in to_populate)
)
res = dict(zip(self._inputs, reqs))

return res

async def to_http_response(
self, obj: dict[str, t.Any], ctx: Context | None = None
Expand Down
14 changes: 13 additions & 1 deletion tests/e2e/bento_server_http/service.py
Expand Up @@ -124,7 +124,19 @@ async def echo_image(f: PILImage) -> NDArray[t.Any]:
input=Multipart(original=Image(), compared=Image()),
output=Multipart(img1=Image(), img2=Image()),
)
async def predict_multi_images(original: dict[str, Image], compared: dict[str, Image]):
async def predict_multi_images(original: Image, compared: Image):
output_array = await py_model.predict_multi_ndarray.async_run(
np.array(original), np.array(compared)
)
img = fromarray(output_array)
return dict(img1=img, img2=img)


@svc.api(
input=Multipart(original=Image(), compared=Image()),
output=Multipart(img1=Image(), img2=Image()),
)
async def predict_different_args(compared: Image, original: Image):
output_array = await py_model.predict_multi_ndarray.async_run(
np.array(original), np.array(compared)
)
Expand Down
32 changes: 24 additions & 8 deletions tests/e2e/bento_server_http/tests/test_io.py
Expand Up @@ -222,22 +222,38 @@ async def test_image(host: str, img_file: str):
)


@pytest.fixture(name="img_form_data")
def fixture_img_form_data(img_file: str):
with open(img_file, "rb") as f1, open(img_file, "rb") as f2:
form = aiohttp.FormData()
form.add_field("original", f1.read(), content_type="image/bmp")
form.add_field("compared", f2.read(), content_type="image/bmp")
yield form


@pytest.mark.asyncio
async def test_multipart_image_io(host: str, img_file: str):
async def test_multipart_image_io(host: str, img_form_data: aiohttp.FormData):
from starlette.datastructures import UploadFile

with open(img_file, "rb") as f1:
with open(img_file, "rb") as f2:
form = aiohttp.FormData()
form.add_field("original", f1.read(), content_type="image/bmp")
form.add_field("compared", f2.read(), content_type="image/bmp")

_, headers, body = await async_request(
"POST", f"http://{host}/predict_multi_images", data=form, assert_status=200
"POST",
f"http://{host}/predict_multi_images",
data=img_form_data,
assert_status=200,
)

form = await parse_multipart_form(headers=headers, body=body)
for _, v in form.items():
assert isinstance(v, UploadFile)
img = PILImage.open(v.file)
assert np.array(img).shape == (10, 10, 3)


@pytest.mark.asyncio
async def test_multipart_image_io(host: str, img_form_data: aiohttp.FormData):
await async_request(
"POST",
f"http://{host}/predict_different_args",
data=img_form_data,
assert_status=200,
)

0 comments on commit e987397

Please sign in to comment.