
Implement lowering of torch.aten.kthvalue #3360

Open · wants to merge 4 commits into main
Conversation

ptrifunovic98 (Contributor, Author):
This is an implementation of the torch.aten.kthvalue operation. The operation first finds the k smallest values in the input, and then takes the maximum of those values.
To find the k smallest values, a TMTensorTopkOp was added, which mirrors the IREE LinalgExt TopK op (https://iree.dev/reference/mlir-dialects/LinalgExt/#iree_linalg_exttopk-linalgexttopkop), or rather the min-K variant of it. The maximum of those values is then computed using linalg generic ops.

I'm open to suggestions about improving this implementation, or adding clarification for certain parts of it.
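For intuition, the decomposition described above (kthvalue as the max of the k smallest values) can be sketched as a plain C++ reference model. This is a hypothetical illustration of the semantics only, not the MLIR lowering itself; `kthvalueRef` is an invented name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical reference model of the decomposition used in this PR:
// kthvalue(x, k) is the max of the k smallest elements of x.
// Returns {value, index of that value in the original input}.
std::pair<int, std::size_t> kthvalueRef(const std::vector<int> &x,
                                        std::size_t k) {
    assert(k >= 1 && k <= x.size());
    // Step 1: "min-K" — order indices so the k smallest values come first
    // (this stands in for the TMTensorTopkOp min-K variant).
    std::vector<std::size_t> idx(x.size());
    for (std::size_t i = 0; i < idx.size(); ++i)
        idx[i] = i;
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](std::size_t a, std::size_t b) { return x[a] < x[b]; });
    // Step 2: the k-th smallest value is the largest of those k values
    // (this stands in for the final linalg.generic max reduction).
    std::size_t kth = idx[k - 1];
    return {x[kth], kth};
}
```

For example, `kthvalueRef({5, 1, 4, 2, 3}, 3)` selects the three smallest values {1, 2, 3} and returns the largest of them, 3, together with its original position.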


```cpp
// Linalg generic op for casting topk idx output tensor elements from i32 to
// result idx tensor element type.
auto castedIdxLinalg = rewriter.create<linalg::GenericOp>(
```
rsuderman (Contributor):
Technically you can just simplify this by not bothering with linalg::GenericOp. If I remember correctly you can just do arith::IndexCastOp on tensors.

ptrifunovic98 (Author):

Hey @rsuderman, thanks for the comments!

I tried using arith::IndexCastOp directly on tensors, but I get the following error:

```
error: failed to legalize operation 'arith.index_cast'
note: see current operation: %14 = "arith.index_cast"(%13) : (tensor<3xi32>) -> tensor<3xindex>
error: Module does not conform to the linalg-on-tensors backend contract. See dialect conversion legality information above.
```

Basically, casting from a tensor with i32 element type to a tensor of the same shape with index element type fails. The docs state that the operation can be used for conversion between scalar or vector integers/indexes, but the operation is marked as Tensorizable (as all arith operations are), so it should work on tensors too. Do you have an idea what could have produced this error?
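The elementwise semantics that the linalg.generic fallback computes, widening each i32 index to the platform index type, can be sketched in plain C++. This is just an illustrative model (the function name and use of `std::size_t` for the index type are assumptions):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar model of what the linalg.generic cast does: each
// i32 index produced by the topk op is widened, one element at a time,
// to the platform index type (modeled here as std::size_t).
std::vector<std::size_t> castIdxToIndex(const std::vector<std::int32_t> &in) {
    std::vector<std::size_t> out;
    out.reserve(in.size());
    for (std::int32_t v : in)
        out.push_back(static_cast<std::size_t>(v)); // per-element index_cast
    return out;
}
```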

ptrifunovic98 (Author):

@rsuderman The problem here is that tensor-typed operations from the arith dialect are not considered legal by the torch-verify-linalg-on-tensors-backend-contract pass. The implementation of this pass adds arith as a legal dialect only for scalar ops:

```cpp
auto isLegalScalarOp = [&](Operation *op) {
  // We recognize basic scalar ops by them having the trait "Elementwise",
  // even though we don't expect them to operate on tensors.
  return scalarConverter.isLegal(op) &&
         op->hasTrait<OpTrait::Elementwise>();
};

...

target.addDynamicallyLegalDialect<arith::ArithDialect>(isLegalScalarOp);
```

When I add the following to the pass:

```cpp
target.addDynamicallyLegalDialect<arith::ArithDialect>(opHasLegalTypes);
```

the error disappears. opHasLegalTypes is used in this pass for every dialect that works on tensors. So I guess the only way to avoid linalg.generic here is to modify the pass, but I don't know what other side effects that could have, so I don't think it's the best idea for me to do this.

Is it alright if I stick with linalg.generic here?
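The legality rule discussed above can be modeled as a small standalone predicate. This is a toy sketch, not the MLIR API: the `ToyOp` struct and field names are invented, and the tensor check stands in for `scalarConverter.isLegal` rejecting tensor types.

```cpp
#include <vector>

// Toy model (hypothetical, not MLIR code) of the pass's legality rule:
// an arith op is accepted only if it carries the Elementwise trait AND
// every operand/result type is scalar. Under this rule, arith.index_cast
// on a tensor is rejected and must be lowered some other way, e.g. via
// linalg.generic as done in this PR.
struct ToyOp {
    bool elementwise;                // models OpTrait::Elementwise
    std::vector<bool> typeIsTensor;  // one entry per operand/result type
};

bool isLegalScalarOp(const ToyOp &op) {
    if (!op.elementwise)
        return false;
    for (bool isTensor : op.typeIsTensor)
        if (isTensor)
            return false; // models the scalar converter rejecting tensors
    return true;
}
```

Under this model, an `index_cast` on scalars is legal while the same op on tensors is not, which matches the legalization failure reported above.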

Resolved (outdated) review threads on:
- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
- lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
```cpp
// check if element type is float, int, or unsigned
bool isUnsigned = false;
if (!isa<mlir::FloatType>(inputElementType)) {
  if (isa<mlir::IntegerType>(inputElementType)) {
```
Reviewer (Contributor):

It makes more sense to do:

```cpp
if (!isa<mlir::IntegerType>(inputElementType)) {
  return rewriter.notifyMatchFailure(...);
}

// else condition
```

Successfully merging this pull request may close these issues.

torch.aten.kthvalue
2 participants