Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable .numpy() and .tolist() for tensor subclasses subclasses and f… #66642

Merged
merged 3 commits into from Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 10 additions & 1 deletion test/test_python_dispatch.py
Expand Up @@ -448,6 +448,15 @@ def test_nested_enable_python_mode(self) -> None:
with enable_python_mode(LoggingTensor):
with enable_python_mode(LoggingTensor):
pass


malfet marked this conversation as resolved.
Show resolved Hide resolved
def test_tolist_numpy_with_python_mode(self) -> None:
x = LoggingTensor(torch.tensor([2.0, 3.0]))
with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
x.tolist()
with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
x.numpy()
with self.assertRaises(AssertionError):
self.assertEqual(x, None)

malfet marked this conversation as resolved.
Show resolved Hide resolved
if __name__ == '__main__':
run_tests()
7 changes: 7 additions & 0 deletions test/test_torch.py
Expand Up @@ -8399,6 +8399,13 @@ def generate_inputs(num_batches):
finally:
torch.set_num_threads(num_threads)

def test_conj_neg_tolist(self):
x = torch.randn(2, dtype=torch.cfloat)
y1 = x.conj()
y1_expect = x.conj_physical()
y2 = y1.imag
self.assertEqual(y1, y1_expect.tolist())
self.assertEqual(y2, y1_expect.imag.tolist())

# TODO: these empy classes are temporarily instantiated for XLA compatibility
# once XLA updates their test suite it should be removed
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/utils/tensor_list.cpp
Expand Up @@ -30,7 +30,8 @@ static PyObject* recursive_to_list(
}

PyObject* tensor_to_list(const Tensor& tensor) {
Tensor data = tensor;
TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".tolist() is not supported for tensor subclasses.");
Tensor data = tensor.resolve_conj().resolve_neg();
if (!data.device().is_cpu()) {
pybind11::gil_scoped_release no_gil;
data = data.toBackend(Backend::CPU);
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/utils/tensor_numpy.cpp
Expand Up @@ -130,6 +130,8 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
"Can't call numpy() on Tensor that has negative bit set. "
"Use tensor.resolve_neg().numpy() instead.");

TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".numpy() is not supported for tensor subclasses.");

auto dtype = aten_to_numpy_dtype(tensor.scalar_type());
auto sizes = to_numpy_shape(tensor.sizes());
auto strides = to_numpy_shape(tensor.strides());
Expand Down
4 changes: 3 additions & 1 deletion torch/testing/_internal/common_utils.py
Expand Up @@ -1781,8 +1781,10 @@ def assertEqual(self, x, y, msg: Optional[str] = None, *,
assert (atol is None) == (rtol is None), "If one of atol or rtol is specified, then the other must be too"
debug_msg: Optional[str] = None

if x is None or y is None:
self.assertTrue(x is None and y is None)
# Tensor x Number and Number x Tensor comparisons
if isinstance(x, torch.Tensor) and isinstance(y, Number):
elif isinstance(x, torch.Tensor) and isinstance(y, Number):
self.assertEqual(x.item(), y, atol=atol, rtol=rtol, msg=msg,
exact_dtype=exact_dtype, exact_device=exact_device)
elif isinstance(y, torch.Tensor) and isinstance(x, Number):
Expand Down