
mlir in tensorflow training? #23

Open
Cuttstage opened this issue Dec 9, 2021 · 5 comments

@Cuttstage

Hi, is it possible to use MLIR to generate LLVM IR during training, e.g., with GPU support?
I cannot find any code in TensorFlow that uses MLIR and then hands control back to the TensorFlow executor. Is MLIR therefore only useful for inference? I wonder whether TensorFlow could lower some ops through an IR pipeline and then merge the IR results back into the TensorFlow process?

@joker-eph
Contributor

MLIR is used in Grappler now, which is plugged into the executor. MLIR is also increasingly used to implement the XLA CPU/GPU compiler (which emits LLVM IR), and it is used for both inference and training (it powers JAX, for example).

You'd have to be more specific about what you're trying to achieve, though. So far, MLIR has been used for fairly low-level pieces of infrastructure inside the TensorFlow/XLA ecosystem.
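
For readers who want to see the MLIR that TensorFlow produces for a traced function, here is a minimal sketch. It assumes a recent TF 2.x release where the experimental `tf.mlir.experimental.convert_function` hook is available; the toy function is made up for illustration, and the exact output and pass pipeline may change since the API is experimental:

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
def square(x):
  return x * x

# Trace the function, then ask TensorFlow for its MLIR ("tf" dialect) text.
concrete = square.get_concrete_function()
print(tf.mlir.experimental.convert_function(concrete))
```

Printing the result shows the function as an MLIR module in the `tf` dialect, which is the form the MLIR-based passes mentioned above operate on.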

@joker-eph
Contributor

Seems like you have some misconfigured auto-reply here :)

@Cuttstage
Author

> MLIR is used in Grappler now, which is plugged into the executor. MLIR is also increasingly used to implement the XLA CPU/GPU compiler (which emits LLVM IR), and it is used for both inference and training (it powers JAX, for example).
>
> You'd have to be more specific about what you're trying to achieve, though. So far, MLIR has been used for fairly low-level pieces of infrastructure inside the TensorFlow/XLA ecosystem.

Thanks for your reply. I have found some code about tfg in the latest TensorFlow repo. I am new to MLIR and hope to use this feature in our training workloads. :)

@stellaraccident

stellaraccident commented Dec 10, 2021

MHLO (which is what this repo contains) is the native IR of JAX, which is used heavily for training (on CPU, GPU, and TPU via XLA). However, "training" can mean many things. For example, here is a prototype of a new API we are working on for saving off an entire JAX training program (in this case, a simple MNIST model) so it can be run offline via IREE on single CPU/GPU systems (which happens to include a large swath of mobile and embedded devices).

This is but one example of a training setup. If you are looking for distributed training, that is a more advanced topic. Also, the above is just a prototype: we are looking to finish it for everyone to use in the coming months.
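
For concreteness, here is a minimal sketch of "MHLO powering training": it assumes a recent JAX release where `jax.jit(...).lower(...)` and `compiler_ir(dialect="mhlo")` are available, and the loss, data, and step size are made up for illustration (this is not the prototype API described above):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
  pred = x @ w
  return jnp.mean((pred - y) ** 2)

def train_step(w, x, y):
  # jax.grad differentiates the loss w.r.t. w: one plain SGD step.
  g = jax.grad(loss)(w, x, y)
  return w - 0.1 * g

w = jnp.zeros((3,))
x = jnp.ones((8, 3))
y = jnp.ones((8,))

# Lower the jitted step and print the MHLO module it compiles through.
lowered = jax.jit(train_step).lower(w, x, y)
print(lowered.compiler_ir(dialect="mhlo"))
```

The printed module contains the whole training step, gradient computation included, as MHLO ops, which XLA (or IREE) can then compile for the target backend.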

@Cuttstage
Author

> MHLO (which is what this repo contains) is the native IR of JAX, which is used heavily for training (on CPU, GPU, and TPU via XLA). However, "training" can mean many things. For example, here is a prototype of a new API we are working on for saving off an entire JAX training program (in this case, a simple MNIST model) so it can be run offline via IREE on single CPU/GPU systems (which happens to include a large swath of mobile and embedded devices).
>
> This is but one example of a training setup. If you are looking for distributed training, that is a more advanced topic. Also, the above is just a prototype: we are looking to finish it for everyone to use in the coming months.

Thanks a lot. We will give it a try. But right now our production models run on TF, so moving from TF to JAX would be a significant undertaking.
