MHLO operation regions need to use scalar arguments #22

Open
MaheshRavishankar opened this issue Dec 8, 2021 · 6 comments

@MaheshRavishankar
Contributor

MHLO operations that have regions use zero-rank tensors to represent what are really scalar values. For example:

func @reduce_one_op_all_locs_same(%arg0: tensor<?x?xf32>, %arg1 : tensor<f32>) -> (tensor<?xf32>) {
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<f32> loc("foo"), %arg3: tensor<f32> loc("foo")):
    %1 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32> loc("foo")
    "mhlo.return"(%1) : (tensor<f32>) -> () loc("foo")
  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> loc("foo")

  return %0: tensor<?xf32>
}

There are a couple of issues here.

  1. The region of the mhlo.reduce here contains an mhlo.add. The way one would lower mhlo.add to, say, the Linalg dialect is very different depending on whether the operation appears within the region of an mhlo op or at the top level. This conflates two distinct uses of the mhlo.add operation. It would be much easier to handle if mhlo.add were only used at the top level and a different operation were used within the regions of mhlo operations.
  2. The region of the mhlo operation in this case is really a sequence of scalar computations. Using zero-rank tensors introduces additional complexity when translating to the Linalg dialect, since the block arguments must be type-converted from zero-rank tensors to scalars. Having scalars before the conversion would remove much of that complexity (see the sketch after this list).
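
For illustration, this is roughly the shape the issue is asking for: the same reduction, but with scalar block arguments and a scalar addition in the payload. This is a hypothetical sketch only; scalar block arguments and an arith op inside the region are the proposal, not existing MHLO syntax:

func @reduce_one_op_all_locs_same(%arg0: tensor<?x?xf32>, %arg1 : tensor<f32>) -> (tensor<?xf32>) {
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: f32, %arg3: f32):  // scalar block arguments instead of tensor<f32>
    %1 = arith.addf %arg2, %arg3 : f32
    "mhlo.return"(%1) : (f32) -> ()
  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>

  return %0: tensor<?xf32>
}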
@joker-eph
Contributor

Not all reductions are scalar though. The zero-rank tensor is just the degenerate case; take for example (from the test suite):

func @reduce_valid(%arg0: tensor<4x4xf32>, %arg1 : tensor<4xf32>)
    -> (tensor<4xf32>) {
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<4xf32>, %arg3: tensor<4xf32> ):
    %1 = "mhlo.add"(%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
    "mhlo.return"(%1) : (tensor<4xf32>) -> ()

  }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>

  return %0: tensor<4xf32>
}

@MaheshRavishankar
Contributor Author

> Not all reductions are scalar though. The zero-rank tensor is just the degenerate case; take for example (from the test suite): [...]


Yes, I did notice that (and actually didn't know that this existed). Specifically, such an operation cannot be lowered to Linalg directly (today). So maybe all that is needed is an MHLO -> MHLO transform before lowering to Linalg that converts the zero-rank tensor case to scalars and converts mhlo.add operations within such regions to, say, arith operations. I think today those would be marked illegal?

@silvasean
Contributor

Is this actually a core issue about how we want to model mhlo.reduce going forward, or is this more about the mechanics of lowering to linalg? It feels like lowering mhlo.add to arith.addf for the payload can be handled by being somewhat more sophisticated about the use of the dialect conversion infrastructure (setting up legality properly, or doing the conversion in two phases, or something).

@rsuderman
Contributor

I don't think you need the dialect conversion framework to handle this. I think the biggest issue is what operations are supported in the MHLO reduce region. I could easily see non-elementwise operations being used in the reduction region, preventing lowering to Linalg.
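
As a hypothetical illustration of that concern (this IR does not appear in the issue; mhlo.reverse is just one example of a non-elementwise op that is valid in a reduce region but has no scalar equivalent):

func @reduce_non_elementwise(%arg0: tensor<4x4xf32>, %arg1 : tensor<4xf32>) -> (tensor<4xf32>) {
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<4xf32>, %arg3: tensor<4xf32>):
    // Non-elementwise op in the payload: there is no per-element scalar
    // form of this computation, so it cannot become a scalar loop body.
    %1 = "mhlo.reverse"(%arg2) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
    %2 = "mhlo.add"(%1, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
    "mhlo.return"(%2) : (tensor<4xf32>) -> ()
  }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %0: tensor<4xf32>
}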

@MaheshRavishankar
Contributor Author

> Is this actually a core issue about how we want to model mhlo.reduce going forward, or is this more about the mechanics of lowering to linalg? It feels like lowering mhlo.add to arith.addf for the payload can be handled by being somewhat more sophisticated about the use of the dialect conversion infrastructure (setting up legality properly, or doing the conversion in two phases, or something).

Not a stakeholder in MHLO per se, but for me an mhlo.reduce whose payload is itself a tensor-based operation is a "higher-level abstraction" that needs to be lowered into something else before it can be lowered into Linalg. Simply put, a payload operating on tensors makes the mhlo.reduce an imperfectly nested loop nest, while a payload operating on scalars makes it a perfectly nested loop nest. A perfectly nested loop nest is a special case of an imperfectly nested one, but lowering an imperfectly nested loop nest is a different starting point from lowering a perfectly nested one.

@MaheshRavishankar
Contributor Author

So if `mhlo.reduce` does support tensor operations in the payload, further MHLO -> MHLO transformations are needed to get it to a state where it can be lowered to Linalg (as an example).
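
For context, this is roughly the perfectly nested form that the scalar case lowers to in Linalg, with the payload as a scalar arith.addf. A sketch assuming the usual linalg.generic reduction form (the function and map names are illustrative):

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
func @reduce_as_generic(%arg0: tensor<4x4xf32>, %init: tensor<4xf32>) -> tensor<4xf32> {
  %0 = linalg.generic {
      indexing_maps = [#map0, #map1],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg0 : tensor<4x4xf32>) outs(%init : tensor<4xf32>) {
    ^bb0(%in: f32, %acc: f32):  // scalar payload: a perfectly nested loop body
      %sum = arith.addf %in, %acc : f32
      linalg.yield %sum : f32
  } -> tensor<4xf32>
  return %0 : tensor<4xf32>
}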

copybara-service bot pushed a commit that referenced this issue Dec 16, 2022
Imported from GitHub PR tensorflow/tensorflow#58720

Enables scaled GEMMs based on `F8E4M3FN` and `F8E5M2` [FP8 data types](https://arxiv.org/abs/2209.05433). The pattern described by steps 1 through 6 in [RFC #22](openxla/xla#22) is rewritten into a Custom Call of the form

(A, B, a_scale, b_scale, d_scale) -> (D, d_amax),

where A, B and D are FP8 matrices and a_scale, b_scale and d_scale are their respective scaling factors. The scalar d_amax gives the maximum of the absolute values in D before rescaling and casting to FP8 and can be used in the calculation of new scaling factors.
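
Reading the interface above, the computation is roughly the following (a sketch inferred from the description; the exact scale convention, i.e. whether d_scale multiplies or divides the result, is an assumption):

  tmp    = (a_scale * A) @ (b_scale * B)
  d_amax = max(abs(tmp))
  D      = cast_to_fp8(tmp / d_scale)   // or tmp * d_scale, depending on the scale convention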
Copybara import of the project:

- f2eb35a9efcaaffdbb7314f99521357840bd49d8 by Philipp Hack <phack@nvidia.com>: Support for FP8 GEMMs in XLA.
- 0afd695b3840417fdb1c00987c8c5e980be0de33 by Philipp Hack <phack@nvidia.com>: Support for FP8 GEMMs in XLA.
- 5aba0882bc624215613c77d73dd23ec3b1d8b0d9 by Philipp Hack <phack@nvidia.com>: Support for FP8 GEMMs in XLA.
- 8d18d22d61b1b440421fc3dd402acdaaf27519b3 by Philipp Hack <phack@nvidia.com>: Support for FP8 GEMMs in XLA.
- 7759e0a5d041c26c632d4e433d5f544e0194ea40 by Philipp Hack <phack@nvidia.com>: Support for FP8 GEMMs in XLA.

Merging this change closes #58720

PiperOrigin-RevId: 495806551