Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Install Jax for GPU acceleration #64

Merged
merged 2 commits into from May 23, 2022
Merged

Install Jax for GPU acceleration #64

merged 2 commits into from May 23, 2022

Conversation

tushuhei
Copy link
Member

No description provided.

@tushuhei
Copy link
Member Author

This update may bring >10x boost in model training time.

Test code

$ python scripts/load_knbc.py
$ python scripts/encode_data.py source.txt
$ time python scripts/train.py encoded_data.txt --iter=50

Test environment

  • 4 core CPU 2.2 GHz
  • 26GB RAM

Results

Current
real 4m50.838s
user 4m12.218s
sys 2m26.869s

NumPy / CPU
real 3m11.276s
user 2m46.266s
sys 2m13.110s

Jax / CPU
real 1m19.230s
user 1m0.116s
sys 1m0.292s

Jax / GPU
real 0m20.113s
user 0m16.787s
sys 0m4.564s

Execution time (2)

@tushuhei tushuhei merged commit 53ba808 into main May 23, 2022
@tushuhei tushuhei deleted the jax2 branch December 14, 2022 02:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant