Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[jax2tf] Fix native lowering for modules that don't use some inputs.
We must drop the corresponding actual arguments when invoking the function and we must also be careful to compute the dim_args_specs based only on the kept inputs. As a side benefit, we now allow dimension variables that occur in more complex polynomials, as long as they also occur as trivial monomials somewhere in the input shapes for the kept arguments.
- Loading branch information
Showing
3 changed files
with
150 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters