forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MPS] Revamp copy_to_mps_ implementation (pytorch#87475)
* [MPS] Copy from CPU always add storageOffset (pytorch#86958) Because why wouldn't it? Fixes pytorch#86052 Pull Request resolved: pytorch#86958 Approved by: https://github.com/kulinseth (cherry picked from commit 13cff2e) * [MPS] Revamp copy_to_mps_ implementation (pytorch#86956) Tensor's view in linear storage is represented by the following parameters: `.shape`, `.stride()` and `.storage_offset()`. Only tensors that are representable as 1d-views can be copied from host to device (and vice versa) using single [`copy(from:sourceOffset:to:destinationOffset:size:)`](https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc) call. Modify `copy_to_mps_` function to do the following steps: - Cast `src` tensor to dst data type if needed - Expand `src` tensor to `dst` tensor shape - Clone `src` tensor if it is not stride contiguous (i.e. can not be represented by `src.view(src.numel())`) - Create an empty tensor if `dst` is not stride-contiguous or if its strides are different then potentially cloned `src` strides - Do 1d copy for `src` to (potentiall temp) `dst` - Finally do re-striding/copy on MPS if needed Add test to cover cases where stide-contiguous permuted tensor is copied to MPS, non-stride-contiguous tensor is copied to MPS and if permuted CPU tensor is copied to differently permuted MPS tensor Fixes pytorch#86954 Pull Request resolved: pytorch#86956 Approved by: https://github.com/kulinseth (cherry picked from commit ae62cf7)
- Loading branch information
Showing
2 changed files
with
86 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters