[MPS] Revamp copy_to_mps_ implementation (pytorch#87475)
* [MPS] Copy from CPU always add storageOffset (pytorch#86958)

Because why wouldn't it?
Fixes pytorch#86052
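
A minimal illustration of the issue (not part of the patch, mirroring the `test_detached_view_copy` regression test added below; assumes an MPS-capable machine): `.detach()` on an indexed element yields a contiguous tensor whose data starts at a non-zero `storage_offset()`, and the copy must honor that offset.

```python
import torch

x = torch.arange(2)
y = x[1].detach()            # contiguous 0-d tensor, but storage_offset() == 1
print(y.storage_offset())    # 1
z = y.to("mps")              # the host->device copy must start at that offset
assert torch.equal(y, z.cpu())
```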

Pull Request resolved: pytorch#86958
Approved by: https://github.com/kulinseth

(cherry picked from commit 13cff2e)

* [MPS] Revamp copy_to_mps_ implementation (pytorch#86956)

A tensor's view into its linear storage is described by three parameters: `.shape`, `.stride()`, and `.storage_offset()`.

Only tensors that are representable as 1d views can be copied from host to device (and vice versa) using a single [`copy(from:sourceOffset:to:destinationOffset:size:)`](https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc) call.
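
For illustration, a Python rendition of that check (a hypothetical `strided_size` helper mirroring the `compute_strided_size` function added in the diff below): a permuted tensor can still cover one dense 1d region of storage, while a step-sliced one leaves gaps and must be cloned first.

```python
import torch

def strided_size(t):
    # index of the last addressable element + 1; equals numel() iff the
    # tensor's elements occupy one dense region of its storage
    return 0 if t.numel() == 0 else 1 + sum((s - 1) * st for s, st in zip(t.shape, t.stride()))

a = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
print(a.is_contiguous(), strided_size(a) == a.numel())   # False True  -> single blit works
b = torch.arange(27).reshape(3, 3, 3)[::2]
print(b.is_contiguous(), strided_size(b) == b.numel())   # False False -> needs a clone first
```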

Modify the `copy_to_mps_` function to perform the following steps (see the Python sketch after this list):
- Cast the `src` tensor to `dst`'s data type if needed
- Expand the `src` tensor to `dst`'s shape
- Clone the `src` tensor if it is not stride-contiguous (i.e. its elements cannot be copied as one dense block of `src.numel()` elements)
- Create an empty temporary tensor if `dst` is not stride-contiguous or if its strides differ from the (potentially cloned) `src` strides
- Do a 1d copy of `src` to the (potentially temporary) `dst`
- Finally, do the re-striding/copy on MPS if needed
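
A rough Python sketch of that control flow, assuming the `strided_size` helper above and a hypothetical `blit_1d` standing in for the single Metal blit; the authoritative version is the Objective-C++ `copy_to_mps_` in the diff below.

```python
import torch

def copy_to_mps_sketch(dst, src, non_blocking=False):
    # cast to dst's dtype if needed, then expand to dst's shape (a no-op view)
    src = (src if src.dtype == dst.dtype else src.to(dst.dtype)).expand_as(dst)
    # clone if src's elements do not occupy one dense 1d region of storage
    if strided_size(src) != src.numel():
        src = src.clone()
    # if strides still differ, stage through a temporary MPS tensor
    needs_copy = src.stride() != dst.stride()
    tmp = torch.empty_like(src, device="mps") if needs_copy else dst
    blit_1d(src, tmp, non_blocking and not needs_copy)  # single 1d host->device copy (placeholder)
    return dst.copy_(tmp) if needs_copy else dst        # re-stride/copy on MPS if needed
```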

Add tests covering the cases where a stride-contiguous permuted tensor is copied to MPS, where a non-stride-contiguous tensor is copied to MPS, and where a permuted CPU tensor is copied to a differently permuted MPS tensor.

Fixes pytorch#86954

Pull Request resolved: pytorch#86956
Approved by: https://github.com/kulinseth

(cherry picked from commit ae62cf7)
malfet committed Oct 21, 2022
1 parent 6a8be2c commit 8569a44
Showing 2 changed files with 86 additions and 33 deletions.
85 changes: 54 additions & 31 deletions aten/src/ATen/native/mps/operations/Copy.mm
@@ -33,6 +33,25 @@
return (void*)alignedAddress;
}

/**
* Computes number of elements one needs to transfer to preserve all the elements
*/
size_t compute_strided_size(const at::Tensor& t) {
size_t rc = 1;
if (t.numel() == 0) {
return 0;
}
for(const auto i: c10::irange(t.dim())) {
assert(t.size(i) > 0);
rc += (t.size(i) - 1) * t.stride(i);
}
return rc;
}

bool is_strided_contiguous(const at::Tensor& t) {
return compute_strided_size(t) == t.numel();
}

// Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type().
// The shapes and dtypes are taken from dst and src, but their storage pointers are not used.
void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
@@ -168,56 +187,60 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
return dst_;
}

static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
// Copies tensor from cpu to mps backed by identical strided-contiguous data
static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
{
MPSStream* stream = getCurrentMPSStream();
Tensor src;

id<MTLDevice> device = MPSDevice::getInstance()->device();
auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
uint64_t src_total_size = 0;

// This is weird, but sometimes this function can be called
// with contiguous destination and non-contiguous source
if (src_.is_view() || dst_.is_contiguous() != src_.is_contiguous()) {
src = src_.to(dst_.dtype()).expand_as(dst_).contiguous();
// Get the actual size of a View (takes into account the storage offset)
// For View tensors, the storage offset can be bigger than what's being reported by nbytes
src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());
} else {
TORCH_INTERNAL_ASSERT(src_.strides() == dst_.strides());
src = src_;
if (src.dtype() != dst_.dtype()) {
// In case of dtype change, perform conversion on source device
src = src.to(dst_.dtype());
}
src_total_size = src.nbytes();
}

auto dst_byte_offset = dst.storage_offset() * dst.itemsize();
auto src_byte_offset = src.storage_offset() * src.itemsize();
id<MTLBuffer> destBuffer = getMTLBufferStorage(dst);
const size_t size_to_copy = src.nbytes();
const void* host_src = src.storage().data();
TORCH_INTERNAL_ASSERT(src_total_size >= (src.storage_offset() * src.element_size()));
const void* host_src = static_cast<char *>(src.storage().data()) + src_byte_offset;

TORCH_INTERNAL_ASSERT(src.dtype() == dst.dtype() && src.strides() == dst.strides() && is_strided_contiguous(src));

NSUInteger sourceOffset = 0;
@autoreleasepool {
MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
NSUInteger alignedLength = 0;
NSUInteger sourceOffset = 0;

void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size_to_copy, &alignedLength);
id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
length:alignedLength
options:options
deallocator:nil];
sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
if (src_.is_view() || !src_.is_contiguous())
sourceOffset += src_.storage_offset() * src_.itemsize();

stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
[sourceBuffer release];
}
}

return dst_;
static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
// Typecast to dst_ if needed and expand, which is a no-op
Tensor src = (src_.dtype() != dst_.dtype() ? src_.to(dst_.dtype()) : src_).expand_as(dst_);

// If src is not contiguously strided it must be cloned
// It does not mean that tensor is contiguous, but rather
// that it could be represented as 1d view
if (!is_strided_contiguous(src)) {
src = src.clone();
TORCH_INTERNAL_ASSERT(is_strided_contiguous(src));
}
Tensor dst = dst_;
bool needs_copy = false;
// If src and dst_ strides do not match, it means that
// either dst_ is not representable as 1d view or its stride order is different
// in that case create an empty storage like src, copy it to device and then do
// reshaping on the device
if (src.strides() != dst_.strides()) {
needs_copy = true;
dst = at::empty_like(src, at::device(at::kMPS));
}
copy_to_mps_stride_contig(dst, src, non_blocking && !needs_copy);
return needs_copy? dst_.copy_(dst) : dst_;
}

void copy_blit_mps(void* dst, const void* src, size_t size) {
34 changes: 32 additions & 2 deletions test/test_mps.py
@@ -1596,9 +1596,9 @@ def test_storage_offset_greater_than_src_nbytes(self):
tensor_list.append(t)

for i in range(0, n_tensors - 1):
t = tensor_list[i].view(1, 784)
t = tensor_list[i].view(1, n_tensor_elems)
t_mps = t.to("mps")
self.assertEqual(t, t_mps.cpu())
self.assertEqual(t, t_mps.cpu(), f"i={i}")

# See https://github.com/pytorch/pytorch/issues/82427
# and https://github.com/pytorch/pytorch/issues/83692
@@ -1649,6 +1649,27 @@ def test_from_numpy_non_contiguous(self):
t_mps = torch.tensor(a, device="mps")
self.assertEqual(t_cpu, t_mps.to("cpu"))

# See https://github.com/pytorch/pytorch/issues/86954
def test_copy_non_contiguous(self):
x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
self.assertFalse(x.is_contiguous())
y = x.to('mps')
self.assertFalse(y.is_contiguous())
self.assertEqual(x, y.to('cpu'))

x = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
y = x.to('mps')
self.assertEqual(x, y.to('cpu'))

x = torch.full((4, 4, 4, 4), 13, device="cpu")
y = torch.full((4, 4, 4, 4), 13, device="mps")
z = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
x.permute(3, 2, 1, 0)[1::, ::2] = z
# As y is on MPS and z on CPU, this dispatches to a copy operator
y.permute(3, 2, 1, 0)[1::, ::2] = z
self.assertEqual(x, y.to('cpu'))



class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
@@ -5686,6 +5707,15 @@ def test_view_copy_out(self, device="mps"):
self.assertEqual(expected1, out1)
self.assertEqual(expected2, out2)

def test_detached_view_copy(self, device="mps"):
# https://github.com/pytorch/pytorch/issues/86052
x = torch.arange(2)
# .detach() makes y not a view, but contig tensor
# with non-zero offset
y = x[1].detach()
z = y.to(device)
self.assertEqual(y, z.cpu())

def test_empty_reshape(self, device="mps"):
x = torch.randn(0, 6, device=device)
self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)
