
Are there any plans to support graph transformations (e.g. vmap, gradient, jvp, vjp) similar to JAX? #6088

Open
bmacadam-sfu opened this issue Apr 15, 2024 · 4 comments

@bmacadam-sfu

Question

I have been running into an issue: I want to deploy a classifier-guided diffusion model with ONNX, but there is currently no easy way to export the gradient of an ONNX graph.

  1. PyTorch's ONNX exporters can't handle gradients.
  2. It seems like ONNX's Gradient operator only works in training mode.

I have been digging through the ONNX spec, and I'm confused about why this isn't possible via a graph-to-graph transformation (at least for a reasonably large portion of ONNX operators). That would make it much easier for the ONNX exporter to support FuncTorch transforms (I'm mostly interested in jvp/vjp, but vmap would also be nice). At the very least, I could export the ONNX graph of the gradient manually and compose it into my model using something like ONNX GraphSurgeon.
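For concreteness, here is a minimal sketch of the computation I'd like to export. The classifier here is a stand-in, and torch.func.grad is just one way to express the guidance term; none of this is tied to a specific model:

```python
import torch

class Classifier(torch.nn.Module):
    # Stand-in for a classifier returning log p(y | x) per example.
    def forward(self, x):
        return -x.pow(2).sum(dim=-1)

classifier = Classifier()

# Classifier guidance needs grad_x log p(y | x) inside the deployed graph.
def guidance(x):
    return torch.func.grad(lambda x: classifier(x).sum())(x)

x = torch.randn(2, 4)
print(guidance(x))  # works eagerly ...

# ... but there is no direct way to export this gradient as part of an
# ONNX graph:
# torch.onnx.export(...)  # the exporter has no representation for it
```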

The generate_artifacts function here makes me think that some of this functionality already exists, so it might simply be a matter of exposing the functionality to your users?

Further information

  • Relevant Area: model usage

  • Is this issue related to a specific model?
    No


bmacadam-sfu added the question label on Apr 15, 2024
justinchuby added the enhancement label on Apr 15, 2024
justinchuby self-assigned this on Apr 15, 2024
@justinchuby (Contributor)

Very intriguing question. Do you have any suggestions or thoughts on how this can be done?

@bmacadam-sfu (Author)

So the most direct approach can probably be found here, which gives an algorithm for computing the gradient of a program represented as a graph. I've been playing around with a two-stage approach: first construct the jvp-and-value, then use eval/coeval (there are some details here, where they discuss self-duality). The trick is that coeval isn't strictly computable in the way you need to use it here: you can think of it as the map sending a to a * I_n, identifying R^{n \times n} with R^n \otimes R^n. So you end up with another graph-rewriting problem where you try to eliminate the coeval operation. If coeval can be eliminated, you can construct the reverse derivative in ONNX; if it can't, then the program as written cannot be exported to ONNX.
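For what it's worth, JAX's reverse mode is built in essentially this two-stage way (forward-mode linearization followed by transposition of the resulting linear map), and it exposes both stages directly. A minimal sketch, using jax.linearize and jax.linear_transpose on a toy function (f is purely illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x  # toy program standing in for an ONNX graph

x = jnp.arange(1.0, 4.0)

# Stage 1: the jvp-and-value. jax.linearize returns f(x) together with the
# linear map v |-> Df(x) v (the jvp at x).
y, f_jvp = jax.linearize(f, x)

# Stage 2: transpose the linear map to obtain the reverse derivative (vjp).
# JAX's transpose rules play roughly the role of the coeval-elimination
# rewriting described above.
f_vjp = jax.linear_transpose(f_jvp, x)

w = jnp.ones_like(y)
(cotangent,) = f_vjp(w)

# Sanity check against jax.vjp.
expected = jax.vjp(f, x)[1](w)[0]
assert jnp.allclose(cotangent, expected)
```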

@gramalingam (Contributor)

A couple of comments:

  • onnxruntime-training does have an implementation that generates the gradient/backward-propagation graph for an input ONNX model.
  • It is not an implementation of ONNX's Gradient operator (which, as far as I know, is only a proposal/spec; I don't know of an implementation, though one based on onnxruntime's should be possible).

@gramalingam (Contributor)

> The generate_artifacts function here makes me think that some of this functionality already exists, so it might simply be a matter of exposing the functionality to your users?

I think this is correct, except that this functionality is in onnxruntime, not in onnx. If you have suggestions on a better way to expose this functionality to users, it might be useful to raise that in the onnxruntime repo.
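For reference, a minimal sketch of the onnxruntime-training entry point being discussed. The model path and parameter names are hypothetical placeholders:

```python
import onnx
from onnxruntime.training import artifacts

# "model.onnx" and the initializer names below are hypothetical.
model = onnx.load("model.onnx")

artifacts.generate_artifacts(
    model,
    requires_grad=["fc.weight", "fc.bias"],  # initializers to differentiate
    frozen_params=[],
    loss=artifacts.LossType.MSELoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="artifacts",
)

# artifacts/training_model.onnx now contains the forward graph plus the
# generated backward-propagation subgraph.
```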
