Skip to content

Commit

Permalink
fix: use torch.from_numpy() instead of torch.Tensor() to keep dat…
Browse files Browse the repository at this point in the history
…a type (#2951)

When running pytorch runner with `async_run` method in production mode(not in dev mode), tensors with other type than `Float32` type is converted into `Float32` type.(Ex. `LongTensor` is converted to `FloatTensor`) And I think this bug is originated by [this](https://github.com/bentoml/BentoML/blob/3de0c9a33d595eb70f4618b8a043d66a3d48befa/bentoml/_internal/frameworks/common/pytorch.py#L170) line.

I made small test to check this issue. (numpy == 1.22.3 / torch == 1.12.1)

```python
import torch
import numpy as np

arr = np.array([1, 2, 3, 4], dtype=np.int64)
t1 = torch.from_numpy(arr)
t2 = torch.Tensor(arr)

assert t1.dtype == torch.int64
assert t2.dtype == torch.float32
```

Thus I think changing `torch.Tensor(ret)` to `torch.from_numpy(ret)` should resolve this bug. Please tell me if it was a intended behavior.(i.e converting every tensor into `Float32` tensor). Thank you!

Co-authored-by: dante.l <dante.l@kakaocorp.com>
Co-authored-by: Sean Sheng <s3sheng@gmail.com>
  • Loading branch information
3 people committed Aug 31, 2022
1 parent 43c4eae commit ef67576
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion bentoml/_internal/frameworks/common/pytorch.py
Expand Up @@ -167,7 +167,7 @@ def from_payload( # pylint: disable=arguments-differ

else:
ret = pickle.loads(payload.data)
return torch.Tensor(ret).requires_grad_(False)
return torch.from_numpy(ret).requires_grad_(False)

@classmethod
@inject
Expand Down

0 comments on commit ef67576

Please sign in to comment.