Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jax2tf] Fix native lowering for modules that don't use some inputs. #13603

Merged
merged 1 commit into from Dec 12, 2022

Commits on Dec 12, 2022

  1. [jax2tf] Fix native lowering for modules that don't use some inputs.

    We must drop the corresponding actual arguments when invoking the function
    and we must also be careful to compute the dim_args_specs based only
    on the kept inputs.
    
    As a side benefit, we now allow dimension variables that occur in
    more complex polynomials, as long as they also occur as trivial
    monomials somewhere in the input shapes for the kept arguments.
    gnecula committed Dec 12, 2022
    Configuration menu
    Copy the full SHA
    2f9dd04 View commit details
    Browse the repository at this point in the history