From 28f0bc320637851a158cafaed90ef15bb87fe82a Mon Sep 17 00:00:00 2001 From: Zhao Shenyang Date: Thu, 14 Jul 2022 19:18:29 +0800 Subject: [PATCH] fix(framework): fix how pytorch DataContainer convert GPU tensor (#2739) --- bentoml/_internal/frameworks/common/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bentoml/_internal/frameworks/common/pytorch.py b/bentoml/_internal/frameworks/common/pytorch.py index 38e5627948a..7413f3101b2 100644 --- a/bentoml/_internal/frameworks/common/pytorch.py +++ b/bentoml/_internal/frameworks/common/pytorch.py @@ -144,7 +144,7 @@ def to_payload( # pylint: disable=arguments-differ batch_dim: int = 0, plasma_db: "ext.PlasmaClient" | None = Provide[BentoMLContainer.plasma_db], ) -> Payload: - batch = batch.numpy() + batch = batch.cpu().numpy() if plasma_db: return cls.create_payload( plasma_db.put(batch).binary(),