fix: use `torch.from_numpy()` instead of `torch.Tensor()` to keep data type (#2951)

When running a pytorch runner with the `async_run` method in production mode (not in dev mode), tensors with a type other than `Float32` are converted to `Float32` (e.g. a `LongTensor` becomes a `FloatTensor`). I think this bug originates from [this](https://github.com/bentoml/BentoML/blob/3de0c9a33d595eb70f4618b8a043d66a3d48befa/bentoml/_internal/frameworks/common/pytorch.py#L170) line.

I wrote a small test to check this issue (numpy == 1.22.3 / torch == 1.12.1):

```python
import torch
import numpy as np

arr = np.array([1, 2, 3, 4], dtype=np.int64)
t1 = torch.from_numpy(arr)  # preserves the source dtype
t2 = torch.Tensor(arr)      # always produces the default float32 dtype

assert t1.dtype == torch.int64
assert t2.dtype == torch.float32
```

Thus I think changing `torch.Tensor(ret)` to `torch.from_numpy(ret)` should resolve this bug. Please tell me if this was intended behavior (i.e. converting every tensor into a `Float32` tensor). Thank you!

Co-authored-by: dante.l <dante.l@kakaocorp.com>
Co-authored-by: Sean Sheng <s3sheng@gmail.com>
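One design note worth keeping in mind with this change: unlike `torch.Tensor()`, which copies its input, `torch.from_numpy()` returns a tensor that shares memory with the source array, so an in-place write on either side is visible through the other. A minimal sketch of this behavior (variable names are illustrative, not from the BentoML code):

```python
import numpy as np
import torch

arr = np.array([1, 2, 3, 4], dtype=np.int64)
t = torch.from_numpy(arr)

# dtype is preserved, which is the point of the fix
assert t.dtype == torch.int64

# from_numpy shares the underlying buffer with the array:
# an in-place write to the array is visible through the tensor.
arr[0] = 99
assert t[0].item() == 99
```

If the runner later mutates the NumPy buffer in place, callers holding the tensor would observe the change; this is usually fine for a request/response path where the array is not reused, but it is a difference from the old copying behavior.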