diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 91b486374603..99183d21030c 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -33,6 +33,25 @@
   return (void*)alignedAddress;
 }
 
+/**
+ * Computes number of elements one needs to transfer to preserve all the elements
+ */
+size_t compute_strided_size(const at::Tensor& t) {
+  size_t rc = 1;
+  if (t.numel() == 0) {
+    return 0;
+  }
+  for(const auto i: c10::irange(t.dim())) {
+    assert(t.size(i) > 0);
+    rc += (t.size(i) - 1) * t.stride(i);
+  }
+  return rc;
+}
+
+bool is_strided_contiguous(const at::Tensor& t) {
+  return compute_strided_size(t) == t.numel();
+}
+
 // Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type().
 // The shapes and dtypes are taken from dst and src, but their storage pointers are not used.
 void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
@@ -168,56 +187,60 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
   return dst_;
 }
 
-static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
+// Copies tensor from cpu to mps backed by identical strided-contiguous data
+static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
 {
   MPSStream* stream = getCurrentMPSStream();
-  Tensor src;
   id<MTLDevice> device = MPSDevice::getInstance()->device();
-  auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
-  id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
-  uint64_t src_total_size = 0;
-
-  // This is weird, but sometimes this function can be called
-  // with contiguous destination and non-contiguous source
-  if (src_.is_view() || dst_.is_contiguous() != src_.is_contiguous()) {
-    src = src_.to(dst_.dtype()).expand_as(dst_).contiguous();
-    // Get the actual size of a View (takes into account the storage offset)
-    // For View tensors, the storage offset can be bigger than what's being reported by nbytes
-    src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());
-  } else {
-    TORCH_INTERNAL_ASSERT(src_.strides() == dst_.strides());
-    src = src_;
-    if (src.dtype() != dst_.dtype()) {
-      // In case of dtype change, perform conversion on source device
-      src = src.to(dst_.dtype());
-    }
-    src_total_size = src.nbytes();
-  }
-
+  auto dst_byte_offset = dst.storage_offset() * dst.itemsize();
+  auto src_byte_offset = src.storage_offset() * src.itemsize();
+  id<MTLBuffer> destBuffer = getMTLBufferStorage(dst);
   const size_t size_to_copy = src.nbytes();
-  const void* host_src = src.storage().data();
-  TORCH_INTERNAL_ASSERT(src_total_size >= (src.storage_offset() * src.element_size()));
+  const void* host_src = static_cast<const char*>(src.storage().data()) + src_byte_offset;
+
+  TORCH_INTERNAL_ASSERT(src.dtype() == dst.dtype() && src.strides() == dst.strides() && is_strided_contiguous(src));
 
-  NSUInteger sourceOffset = 0;
   @autoreleasepool {
     MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
     NSUInteger alignedLength = 0;
+    NSUInteger sourceOffset = 0;
 
-    void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
+    void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size_to_copy, &alignedLength);
     id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
                                                            length:alignedLength
                                                           options:options
                                                       deallocator:nil];
     sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
-    if (src_.is_view() || !src_.is_contiguous())
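To see what "strided-contiguous" means in the helpers above, they can be mirrored in Python: a tensor qualifies when its elements occupy a dense block of exactly `numel()` storage slots starting at its storage offset, even if the dimension order is permuted. A minimal sketch for illustration only (the function names mirror the C++ helpers; this is not part of the patch), using the same tensors as the new tests below:

```python
import torch

def compute_strided_size(t):
    # Index of the farthest reachable element, plus one
    # (mirrors the C++ compute_strided_size helper)
    if t.numel() == 0:
        return 0
    return 1 + sum((size - 1) * stride for size, stride in zip(t.size(), t.stride()))

def is_strided_contiguous(t):
    return compute_strided_size(t) == t.numel()

# A permuted tensor is not contiguous, yet still strided-contiguous:
# all 27 elements form one dense block, so it can be transferred as-is.
x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
assert not x.is_contiguous() and is_strided_contiguous(x)

# Slicing leaves holes in storage: 47 slots would have to be read to
# preserve 24 elements, so copy_to_mps_ clones such a source first.
y = torch.arange(4**3).reshape(4, 4, 4).permute(2, 0, 1)[1:, ::2]
assert not is_strided_contiguous(y)
```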
-      sourceOffset += src_.storage_offset() * src_.itemsize();
 
     stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
     [sourceBuffer release];
   }
+}
 
-  return dst_;
+static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
+{
+  // Typecast to dst_ if needed and expand, which is a no-op
+  Tensor src = (src_.dtype() != dst_.dtype() ? src_.to(dst_.dtype()) : src_).expand_as(dst_);
+
+  // If src is not contiguously strided it must be cloned
+  // It does not mean that tensor is contiguous, but rather
+  // that it could be represented as 1d view
+  if (!is_strided_contiguous(src)) {
+    src = src.clone();
+    TORCH_INTERNAL_ASSERT(is_strided_contiguous(src));
+  }
+  Tensor dst = dst_;
+  bool needs_copy = false;
+  // If src and dst_ strides do not match, it means that
+  // either dst_ is not representable as 1d view or its stride order is different
+  // in that case create an empty storage like src, copy it to device and then do
+  // reshaping on the device
+  if (src.strides() != dst_.strides()) {
+    needs_copy = true;
+    dst = at::empty_like(src, at::device(at::kMPS));
+  }
+  copy_to_mps_stride_contig(dst, src, non_blocking && !needs_copy);
+  return needs_copy ? dst_.copy_(dst) : dst_;
 }
 
 void copy_blit_mps(void* dst, const void* src, size_t size) {
diff --git a/test/test_mps.py b/test/test_mps.py
index 61ca9e14c543..76716154ba05 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1596,9 +1596,9 @@ def test_storage_offset_greater_than_src_nbytes(self):
             tensor_list.append(t)
 
         for i in range(0, n_tensors - 1):
-            t = tensor_list[i].view(1, 784)
+            t = tensor_list[i].view(1, n_tensor_elems)
             t_mps = t.to("mps")
-            self.assertEqual(t, t_mps.cpu())
+            self.assertEqual(t, t_mps.cpu(), f"i={i}")
 
     # See https://github.com/pytorch/pytorch/issues/82427
     # and https://github.com/pytorch/pytorch/issues/83692
@@ -1649,6 +1649,27 @@ def test_from_numpy_non_contiguous(self):
         t_mps = torch.tensor(a, device="mps")
         self.assertEqual(t_cpu, t_mps.to("cpu"))
 
+    # See https://github.com/pytorch/pytorch/issues/86954
+    def test_copy_non_contiguous(self):
+        x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
+        self.assertFalse(x.is_contiguous())
+        y = x.to('mps')
+        self.assertFalse(y.is_contiguous())
+        self.assertEqual(x, y.to('cpu'))
+
+        x = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
+        y = x.to('mps')
+        self.assertEqual(x, y.to('cpu'))
+
+        x = torch.full((4, 4, 4, 4), 13, device="cpu")
+        y = torch.full((4, 4, 4, 4), 13, device="mps")
+        z = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
+        x.permute(3, 2, 1, 0)[1::, ::2] = z
+        # As y is on MPS and z on CPU, this dispatches to a copy operator
+        y.permute(3, 2, 1, 0)[1::, ::2] = z
+        self.assertEqual(x, y.to('cpu'))
+
+
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
@@ -5686,6 +5707,15 @@ def test_view_copy_out(self, device="mps"):
         self.assertEqual(expected1, out1)
         self.assertEqual(expected2, out2)
 
+    def test_detached_view_copy(self, device="mps"):
+        # https://github.com/pytorch/pytorch/issues/86052
+        x = torch.arange(2)
+        # .detach() makes y not a view, but contig tensor
+        # with non-zero offset
+        y = x[1].detach()
+        z = y.to(device)
+        self.assertEqual(y, z.cpu())
+
     def test_empty_reshape(self, device="mps"):
         x = torch.randn(0, 6, device=device)
         self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)
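Assuming an MPS-enabled build, the two cases the patch fixes can be sanity-checked outside the test suite with a couple of lines distilled from the tests above (guarded so this is a no-op elsewhere):

```python
import torch

if torch.backends.mps.is_available():
    # Issue 86052: contiguous tensor with a non-zero storage offset; the copy
    # must start storage_offset() * itemsize() bytes into the source storage.
    y = torch.arange(2)[1].detach()
    assert torch.equal(y, y.to("mps").cpu())

    # Issue 86954: non-contiguous but strided-contiguous source; the data is
    # uploaded as one block and keeps its strides on the MPS side.
    x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
    assert torch.equal(x, x.to("mps").cpu())
```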