Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle shared memory cases in MathBitFallback #66667

Merged
merged 2 commits into from Oct 15, 2021

Commits on Oct 14, 2021

  1. Handle shared memory cases in MathBithFallback (pytorch#63602)

    Summary:
    Pull Request resolved: pytorch#63602
    
    This PR fixes the case when a read and write is performed on a memory shared between mutable and (or) non-mutable arguments. Example:
    ```
    a=torch.tensor([1+1j])
    b=a.conj()
    b.add_(a) # should return tensor([2]) but returns tensor ([2-2j])
    ```
    
    The issue here is that in the conjugate fallback, we resolve the conjugation in-place for mutable arguments which can be a problem as shown above in the case when other input arguments share memory with the mutable argument(s).
    This PR fixes this issue by:
    1. first scanning through the operator input arguments and creating a vector of mutable arguments that have the conj bit set to `True` (and accordingly setting the flag `check_for_alias_with_mut_arg ` to `True` or `False`).
    2. Iterating through all the arguments. At this time we only look at the non-mutable arguments. If `check_for_alias_with_mut_arg` is set to `True`, then we iterate through `mutable_inputs` to check if the current arg tensor in question doesn't alias any of the entries in `mutable_inputs`. If yes, then we clone the non-mutable tensor arg, else we resolve the conjugation as before.
    3. Now we look through the mutable_inputs vector (which contains only mutable input tensors with conj bit set to `True`). We in-place conjugate each of the entries in the vector.
    4. Do the computation.
    5. Re-conjugate the mutable argument tensors.
    
    NOTE: `TensorLists` are not fully handled in ConjugateFallback. Please see the in-line comment for more details.
    
    Fixes pytorch#59943
    
    Test Plan: Imported from OSS
    
    Reviewed By: gmagogsfm
    
    Differential Revision: D30466905
    
    Pulled By: anjali411
    
    fbshipit-source-id: 58058e5e6481da04a12d03f743c1491942a6cc9b
    anjali411 authored and malfet committed Oct 14, 2021
    Configuration menu
    Copy the full SHA
    bb8e60f View commit details
    Browse the repository at this point in the history
  2. fix lint (pytorch#66572)

    Summary: Pull Request resolved: pytorch#66572
    
    Test Plan: Imported from OSS
    
    Reviewed By: seemethere
    
    Differential Revision: D31624043
    
    Pulled By: suo
    
    fbshipit-source-id: 9db9cee3140d78c2a2f0c937be84755206fee1dd
    suo authored and malfet committed Oct 14, 2021
    Configuration menu
    Copy the full SHA
    6583aaf View commit details
    Browse the repository at this point in the history