[MPS] Revamp copy_to_mps_ implementation #87475

Merged 2 commits on Oct 21, 2022

aten/src/ATen/native/mps/operations/Copy.mm: 85 changes (54 additions, 31 deletions)
@@ -33,6 +33,25 @@
return (void*)alignedAddress;
}

/**
 * Computes the number of elements one needs to transfer to preserve all of the tensor's elements
 */
size_t compute_strided_size(const at::Tensor& t) {
size_t rc = 1;
if (t.numel() == 0) {
return 0;
}
for(const auto i: c10::irange(t.dim())) {
assert(t.size(i) > 0);
rc += (t.size(i) - 1) * t.stride(i);
}
return rc;
}

bool is_strided_contiguous(const at::Tensor& t) {
return compute_strided_size(t) == t.numel();
}

// Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type().
// The shapes and dtypes are taken from dst and src, but their storage pointers are not used.
void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
@@ -168,56 +187,60 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
return dst_;
}

static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
// Copies a tensor from CPU to MPS, where source and destination are backed by identical strided-contiguous data
static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
{
MPSStream* stream = getCurrentMPSStream();
Tensor src;

id<MTLDevice> device = MPSDevice::getInstance()->device();
auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
uint64_t src_total_size = 0;

// This is weird, but sometimes this function can be called
// with contiguous destination and non-contiguous source
if (src_.is_view() || dst_.is_contiguous() != src_.is_contiguous()) {
src = src_.to(dst_.dtype()).expand_as(dst_).contiguous();
// Get the actual size of a View (takes into account the storage offset)
// For View tensors, the storage offset can be bigger than what's being reported by nbytes
src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());
} else {
TORCH_INTERNAL_ASSERT(src_.strides() == dst_.strides());
src = src_;
if (src.dtype() != dst_.dtype()) {
// In case of dtype change, perform conversion on source device
src = src.to(dst_.dtype());
}
src_total_size = src.nbytes();
}

auto dst_byte_offset = dst.storage_offset() * dst.itemsize();
auto src_byte_offset = src.storage_offset() * src.itemsize();
id<MTLBuffer> destBuffer = getMTLBufferStorage(dst);
const size_t size_to_copy = src.nbytes();
const void* host_src = src.storage().data();
TORCH_INTERNAL_ASSERT(src_total_size >= (src.storage_offset() * src.element_size()));
const void* host_src = static_cast<char *>(src.storage().data()) + src_byte_offset;

TORCH_INTERNAL_ASSERT(src.dtype() == dst.dtype() && src.strides() == dst.strides() && is_strided_contiguous(src));

NSUInteger sourceOffset = 0;
@autoreleasepool {
MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
NSUInteger alignedLength = 0;
NSUInteger sourceOffset = 0;

void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size_to_copy, &alignedLength);
id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
length:alignedLength
options:options
deallocator:nil];
sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
if (src_.is_view() || !src_.is_contiguous())
sourceOffset += src_.storage_offset() * src_.itemsize();

stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
[sourceBuffer release];
}
}

return dst_;
static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
// Typecast to dst_'s dtype if needed and expand to dst_'s shape, which is a no-op
Tensor src = (src_.dtype() != dst_.dtype() ? src_.to(dst_.dtype()) : src_).expand_as(dst_);

// If src is not strided-contiguous, it must be cloned.
// This does not mean the tensor has to be contiguous, but rather
// that its elements must occupy a single dense 1D block of storage.
if (!is_strided_contiguous(src)) {
src = src.clone();
TORCH_INTERNAL_ASSERT(is_strided_contiguous(src));
}
Tensor dst = dst_;
bool needs_copy = false;
// If src and dst_ strides do not match, then either dst_ is not representable
// as a 1D block or its stride order is different. In that case, allocate an
// empty tensor like src on the MPS device, copy to it, and then do the
// reshaping on the device.
if (src.strides() != dst_.strides()) {
needs_copy = true;
dst = at::empty_like(src, at::device(at::kMPS));
}
copy_to_mps_stride_contig(dst, src, non_blocking && !needs_copy);
return needs_copy? dst_.copy_(dst) : dst_;
}

void copy_blit_mps(void* dst, const void* src, size_t size) {
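The two new helpers above, compute_strided_size and is_strided_contiguous, decide whether a tensor's elements occupy a single gap-free span of its storage; a permuted tensor passes the check even though it is not contiguous in the usual sense, while a strided slice does not. The same check can be sketched in Python against the public size()/stride() API (an illustration of the idea only, not part of the patch):

import torch

def compute_strided_size(t: torch.Tensor) -> int:
    # Number of elements that must be transferred to cover every element of t:
    # 1 + sum((size_i - 1) * stride_i) over all dimensions.
    if t.numel() == 0:
        return 0
    return 1 + sum((size - 1) * stride for size, stride in zip(t.size(), t.stride()))

def is_strided_contiguous(t: torch.Tensor) -> bool:
    # True when the covered span has no holes, i.e. the tensor's elements fill
    # numel() consecutive storage elements.
    return compute_strided_size(t) == t.numel()

x = torch.arange(27).reshape(3, 3, 3)
print(is_strided_contiguous(x))                   # True: contiguous
print(is_strided_contiguous(x.permute(2, 0, 1)))  # True: permuted, but still one gap-free span
print(is_strided_contiguous(x[:, ::2]))           # False: the step-2 slice leaves holes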
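The rewritten copy_to_mps_ above then reduces every CPU-to-MPS copy to at most one flat blit plus, when the stride layouts differ, one device-side copy. A rough Python model of that dispatch follows (illustrative only: copy_to_mps_model is a made-up name, it reuses is_strided_contiguous from the sketch above, and it uses Tensor.copy_() as a stand-in for the page-aligned Metal blit performed by copy_to_mps_stride_contig):

def copy_to_mps_model(dst_: torch.Tensor, src_: torch.Tensor) -> torch.Tensor:
    # Typecast to dst_'s dtype if needed; expand_as is a view and copies nothing.
    src = (src_.to(dst_.dtype) if src_.dtype != dst_.dtype else src_).expand_as(dst_)

    # A tensor that is not strided-contiguous has gaps (or overlaps, after expand)
    # in the span it covers and cannot be sent as one flat block, so materialize it.
    if not is_strided_contiguous(src):
        src = src.clone()

    # If the stride layouts differ, blit into an MPS staging tensor that mirrors
    # src's layout, then let a device-side copy rearrange it into dst_'s layout.
    if src.stride() != dst_.stride():
        staging = torch.empty_like(src, device="mps")
        staging.copy_(src)            # host -> device transfer of the flat block
        return dst_.copy_(staging)    # MPS -> MPS copy performs the layout fix-up

    dst_.copy_(src)                   # layouts match: a single host -> device blit
    return dst_

Either way, the host-side transfer itself always sees a source and destination with identical dtypes and strides, which is exactly the invariant copy_to_mps_stride_contig asserts.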
test/test_mps.py: 34 changes (32 additions, 2 deletions)
@@ -1596,9 +1596,9 @@ def test_storage_offset_greater_than_src_nbytes(self):
tensor_list.append(t)

for i in range(0, n_tensors - 1):
t = tensor_list[i].view(1, 784)
t = tensor_list[i].view(1, n_tensor_elems)
t_mps = t.to("mps")
self.assertEqual(t, t_mps.cpu())
self.assertEqual(t, t_mps.cpu(), f"i={i}")

# See https://github.com/pytorch/pytorch/issues/82427
# and https://github.com/pytorch/pytorch/issues/83692
@@ -1649,6 +1649,27 @@ def test_from_numpy_non_contiguous(self):
t_mps = torch.tensor(a, device="mps")
self.assertEqual(t_cpu, t_mps.to("cpu"))

# See https://github.com/pytorch/pytorch/issues/86954
def test_copy_non_contiguous(self):
x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
self.assertFalse(x.is_contiguous())
y = x.to('mps')
self.assertFalse(y.is_contiguous())
self.assertEqual(x, y.to('cpu'))

x = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
y = x.to('mps')
self.assertEqual(x, y.to('cpu'))

x = torch.full((4, 4, 4, 4), 13, device="cpu")
y = torch.full((4, 4, 4, 4), 13, device="mps")
z = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
x.permute(3, 2, 1, 0)[1::, ::2] = z
# As y is on MPS and z on CPU, this dispatches to a copy operator
y.permute(3, 2, 1, 0)[1::, ::2] = z
self.assertEqual(x, y.to('cpu'))



class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
@@ -5686,6 +5707,15 @@ def test_view_copy_out(self, device="mps"):
self.assertEqual(expected1, out1)
self.assertEqual(expected2, out2)

def test_detached_view_copy(self, device="mps"):
# https://github.com/pytorch/pytorch/issues/86052
x = torch.arange(2)
# .detach() makes y not a view, but a contiguous tensor
# with a non-zero storage offset
y = x[1].detach()
z = y.to(device)
self.assertEqual(y, z.cpu())

def test_empty_reshape(self, device="mps"):
x = torch.randn(0, 6, device=device)
self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)