
Update test_linear_nvfuser #357

Open · wants to merge 3 commits into main

Conversation

IvanYashchuk (Collaborator):

Thunder's nvFuser executor got support for prims.linear in #318. This PR updates the test

  • to use the sample generator from OpInfos
  • to reject linear execution if the inputs are not bf16 or fp16
  • to skip the test if the nvFuser version is not new enough

A rough sketch of how these pieces fit together is below.
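This is not the PR's actual code, only an illustrative sketch of how the three bullets could combine; get_opinfo, nvfuser_version, their import paths, and the version threshold are assumptions:

import pytest
import torch

# Assumed helpers; the real test may source these differently.
from thunder.tests.opinfos import get_opinfo
from thunder.executors.nvfuserex import nvfuser_version

@pytest.mark.skipif(
    nvfuser_version() is None or nvfuser_version() < "0.0.21",  # hypothetical minimum version
    reason="nvFuser is missing or too old to execute prims.linear",
)
def test_linear_nvfuser():
    linear_opinfo = get_opinfo("linear")
    # Samples come from the OpInfo generator, per the first bullet.
    for sample in linear_opinfo.sample_inputs("cuda", torch.float16):
        ...  # trace through Thunder, execute with the nvFuser executor, compare with PyTorch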

Comment on lines +869 to 870
at_least_one_tested = False

Contributor:

What does that mean?

Collaborator (Author):

There's a for loop over the sample inputs, and inside it a continue statement skips inputs that are not 2D. If someone later modifies the sample input generator for linear so that it no longer produces a 2D sample, this test would silently stop testing anything. With an assert on "at least one sample input tested", there would be a failure instead.
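A minimal sketch of that guard, with names assumed from the surrounding diff and thread (linear_opinfo and the 2D check are illustrative):

at_least_one_tested = False
for sample in linear_opinfo.sample_inputs("cuda", torch.float16):
    a, *_ = sample.args
    if a.ndim != 2:
        continue  # non-2D samples are skipped by this test
    at_least_one_tested = True
    ...  # execute the sample via the nvFuser executor and check the result
# Fails loudly if the generator ever stops producing 2D samples.
assert at_least_one_tested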

Contributor:

Or we could just create a single sample and append to the cases?

Collaborator (Author):

Yes, that's a viable option.

# dtype, as all tensors should have the same dtype which is checked by
# linear_meta.
if a.dtype not in (dtypes.float16, dtypes.bfloat16):
    return False
Collaborator:

Wasn't aware there's such a restriction.

cc'ing @Priya2698: is this intended behavior? I thought we are just dispatching to ATen now and don't see why it's rejecting full precision.

Collaborator:

At the moment, we do have that restriction, since we use fusedMultiplySum and MmaOp, which only accept bf16/fp16.

Collaborator:

This restriction will, however, be lifted with the new IR nodes.

Collaborator (@jjsjann123, May 3, 2024):

Got'ya.

Thanks @IvanYashchuk for catching this. @Priya2698, remember to do a version bump when that restriction is lifted, so we can guard the proper behavior with nvfuser_version here.
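A hedged sketch of that guard; _linear_check, the "0.2.0" threshold, and the comparison are assumptions, with only nvfuser_version and the dtype check taken from this thread:

def _linear_check(a, w, bias) -> bool:
    # Hypothetical version at which the new IR nodes lift the dtype restriction;
    # the real threshold would come from the version bump requested above.
    if nvfuser_version() < "0.2.0":
        # Older nvFuser lowers linear through fusedMultiplySum/MmaOp,
        # which only accept bf16/fp16.
        return a.dtype in (dtypes.float16, dtypes.bfloat16)
    return True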

Collaborator (@Priya2698) left a comment:

Thanks for creating this example. I'll update the matmul test similarly once we have complete support.
