[MPS] Revamp copy_to_mps_ implementation #87475

Merged
merged 2 commits into from Oct 21, 2022

malfet
Contributor

@malfet malfet commented Oct 21, 2022

Tensor's view in linear storage is represented by the following parameters: .shape, .stride() and .storage_offset().
Only tensors that are representable as 1d-views can be copied from host to device (and vice versa) using a single copy(from:sourceOffset:to:destinationOffset:size:) call.
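As a rough illustration of the view parameters above, the following pure-Python sketch computes the storage offset of every element of a view and checks whether those offsets form one dense block. The helper names `element_offsets` and `is_stride_contiguous` are hypothetical and are not taken from the PR; the check assumes that "stride contiguous" means the elements occupy a single dense, non-overlapping range of storage in some dimension order.

```python
from itertools import product


def element_offsets(shape, strides, storage_offset=0):
    """Storage offset of each element, visiting indices in row-major order."""
    return [
        storage_offset + sum(i * s for i, s in zip(index, strides))
        for index in product(*(range(n) for n in shape))
    ]


def is_stride_contiguous(shape, strides):
    """True if the view covers one dense, non-overlapping block of storage.

    Assumption: this is an illustrative definition, not the PR's actual check.
    """
    if any(n == 0 for n in shape):
        return True  # an empty view has nothing to copy
    # Size-1 dimensions never advance the offset, so their strides are ignored.
    dims = sorted((s, n) for n, s in zip(shape, strides) if n != 1)
    expected = 1
    for stride, size in dims:
        if stride != expected:
            return False
        expected *= size
    return True


# A 2x3 row-major tensor viewed as its 3x2 transpose: dense, just reordered.
print(element_offsets((3, 2), (1, 3)))       # [0, 3, 1, 4, 2, 5]
print(is_stride_contiguous((3, 2), (1, 3)))  # True

# The first two columns of a 2x3 row-major tensor: storage has gaps.
print(element_offsets((2, 2), (3, 1)))       # [0, 1, 3, 4]
print(is_stride_contiguous((2, 2), (3, 1)))  # False
```

In this sketch the permuted view still occupies offsets 0 through 5, so one flat buffer copy would move all of its data, while the sliced view skips offset 2 and would need its elements gathered first.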
Modify copy_to_mps_ function to do the following steps:

  • Cast src tensor to dst data type if needed
  • Expand src tensor to dst tensor shape
  • Clone src tensor if it is not stride contiguous (i.e. cannot be represented by src.view(src.numel()))
  • Create an empty tensor if dst is not stride-contiguous or if its strides are different than potentially cloned src strides
  • Do a 1d copy from src to the (potentially temporary) dst
  • Finally do re-striding/copy on MPS if needed
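The decision logic in the steps above can be sketched in plain Python. This is only an illustration under stated assumptions, not the PR's implementation: `plan_copy_to_device`, `is_dense` and `row_major_strides` are hypothetical names, the cast and expand steps are assumed to have already happened (so both views share one shape), and a cloned `src` is assumed to come back with row-major strides.

```python
def row_major_strides(shape):
    """Strides of a freshly allocated row-major tensor of the given shape."""
    strides, step = [], 1
    for n in reversed(shape):
        strides.append(step)
        step *= max(n, 1)
    return tuple(reversed(strides))


def is_dense(shape, strides):
    """True if the view covers one dense, non-overlapping block of storage."""
    if any(n == 0 for n in shape):
        return True
    expected = 1
    for stride, size in sorted((s, n) for n, s in zip(shape, strides) if n != 1):
        if stride != expected:
            return False
        expected *= size
    return True


def plan_copy_to_device(shape, src_strides, dst_strides):
    """List the steps needed to copy a host view into a device view."""
    steps = []
    src_strides = tuple(src_strides)
    if not is_dense(shape, src_strides):
        steps.append("clone src")
        # Assumption: the clone is laid out row-major.
        src_strides = row_major_strides(shape)
    needs_temp = (
        not is_dense(shape, dst_strides) or tuple(dst_strides) != src_strides
    )
    steps.append("1d copy into temp" if needs_temp else "1d copy into dst")
    if needs_temp:
        steps.append("strided copy temp -> dst on device")
    return steps


# Same dense layout on both sides: a single flat copy suffices.
print(plan_copy_to_device((2, 3), (3, 1), (3, 1)))
# Permuted (but dense) src into a row-major dst: copy flat, then re-stride.
print(plan_copy_to_device((2, 3), (1, 2), (3, 1)))
# Sliced src with gaps: clone first, then a flat copy straight into dst.
print(plan_copy_to_device((2, 2), (3, 1), (2, 1)))
```

The point of the sketch is that the host-to-device transfer is always one flat copy; any layout mismatch is resolved either before it (a clone on the host) or after it (a strided copy on the device).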

Add tests covering the cases where a stride-contiguous permuted tensor is copied to MPS, a non-stride-contiguous tensor is copied to MPS, and a permuted CPU tensor is copied to a differently permuted MPS tensor.

malfet and others added 2 commits October 21, 2022 09:25
Because why wouldn't it?
Fixes #86052

Pull Request resolved: #86958
Approved by: https://github.com/kulinseth

(cherry picked from commit 13cff2e)
Tensor's view in linear storage is represented by the following parameters: `.shape`, `.stride()` and `.storage_offset()`.

Only tensors that are representable as 1d-views can be copied from host to device (and vice versa) using a single [`copy(from:sourceOffset:to:destinationOffset:size:)`](https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc) call.

Modify `copy_to_mps_` function to do the following steps:
- Cast `src` tensor to dst data type if needed
- Expand `src` tensor to `dst` tensor shape
- Clone `src` tensor if it is not stride contiguous (i.e. cannot be represented by `src.view(src.numel())`)
- Create an empty tensor if `dst` is not stride-contiguous or if its strides are different than potentially cloned `src` strides
- Do a 1d copy from `src` to the (potentially temporary) `dst`
- Finally do re-striding/copy on MPS if needed

Add tests covering the cases where a stride-contiguous permuted tensor is copied to MPS, a non-stride-contiguous tensor is copied to MPS, and a permuted CPU tensor is copied to a differently permuted MPS tensor.

Fixes #86954

Pull Request resolved: #86956
Approved by: https://github.com/kulinseth

(cherry picked from commit ae62cf7)

pytorch-bot bot commented Oct 21, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87475

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures, 8 Pending

As of commit 1ad30ba:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/mps (Run MPS tests (subset of trunk)) and release notes: mps (Release notes category) labels Oct 21, 2022
Contributor

@atalman atalman left a comment

LGTM

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Oct 21, 2022
@malfet malfet merged commit 8569a44 into release/1.13 Oct 21, 2022
@malfet malfet deleted the malfet/cp-86954 branch October 21, 2022 17:26
Labels
ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request release notes: mps Release notes category

2 participants