Implement lowering of torch.aten.kthvalue #3360
Conversation
This is an implementation of the torch.aten.kthvalue operation. The lowering first finds the k smallest values in the input, and then takes the max of those values. I'm open to suggestions for improving this implementation, or for adding clarification to certain parts of it.
// Linalg generic op for casting topk idx output tensor elements from i32 to
// result idx tensor element type.
auto castedIdxLinalg = rewriter.create<linalg::GenericOp>(
Technically you can simplify this by not bothering with linalg::GenericOp. If I remember correctly, you can just do arith::IndexCastOp on tensors.
Hey @rsuderman, thanks for the comments!
I tried using arith::IndexCastOp directly on tensors, but I get the following error:
error: failed to legalize operation 'arith.index_cast'
note: see current operation: %14 = "arith.index_cast"(%13) : (tensor<3xi32>) -> tensor<3xindex>
error: Module does not conform to the linalg-on-tensors backend contract. See dialect conversion legality information above.
Basically, casting from a tensor with i32 element type to a tensor of the same shape with index element type fails. The docs state that the operation can be used for conversion between scalar or vector integers/indexes, but the operation is also marked as Tensorizable (as all arith operations are), so it should work for tensors too. Do you have an idea what could have produced this error?
@rsuderman The problem here is that operations from the arith
dialect will not be legalized by the torch-verify-linalg-on-tensors-backend-contract
pass. The implementation of this pass adds arith
as a legal dialect only for scalar ops:
auto isLegalScalarOp = [&](Operation *op) {
// We recognize basic scalar ops by them having the trait "Elementwise",
// even though we don't expect them to operate on tensors.
return scalarConverter.isLegal(op) &&
op->hasTrait<OpTrait::Elementwise>();
};
...
target.addDynamicallyLegalDialect<arith::ArithDialect>(isLegalScalarOp);
When I add the following to the pass:
target.addDynamicallyLegalDialect<arith::ArithDialect>(opHasLegalTypes);
the error disappears. opHasLegalTypes is used in this pass for every dialect that works on tensors. So I guess the only way to avoid linalg.generic is to modify the pass, but I don't know what other side effects that could have, so I don't think it's the best idea for me to do this.
Is it alright if I stick with linalg.generic here?
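For reference, the linalg.generic form of the cast discussed here would look roughly like the following IR (the SSA names, static `3x` shape, and `tensor.empty` init are illustrative, taken from the shape in the error message above):

```mlir
#map = affine_map<(d0) -> (d0)>

// Element-wise i32 -> index cast expressed as a linalg.generic, instead of
// a tensor-level arith.index_cast that the backend contract rejects.
%init = tensor.empty() : tensor<3xindex>
%cast = linalg.generic
    {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
    ins(%idx : tensor<3xi32>) outs(%init : tensor<3xindex>) {
^bb0(%in: i32, %out: index):
  // The scalar arith.index_cast inside the generic body is legal, since the
  // pass accepts arith ops that are elementwise and operate on scalars.
  %0 = arith.index_cast %in : i32 to index
  linalg.yield %0 : index
} -> tensor<3xindex>
```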
// check if element type is float, int, or unsigned
bool isUnsigned = false;
if (!isa<mlir::FloatType>(inputElementType)) {
  if (isa<mlir::IntegerType>(inputElementType)) {
It makes more sense to do
if (!isa<mlir::IntegerType>(inputElementType)) {
  return rewriter.notifyMatchFailure(...);
}
// else condition
Closes nod-ai/SHARK-Turbine#620