Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Install Jax for GPU acceleration #64

Merged
Merged 2 commits on May 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/py-unittest.yml
Expand Up @@ -26,3 +26,9 @@ jobs:
python -m pip install ".[dev]"
- name: Run unittest
run: pytest
- name: Install Jax
if: matrix.os != 'windows-latest'
run: pip install jax[cpu]
- name: Run unittest with Jax
if: matrix.os != 'windows-latest'
run: pytest
42 changes: 29 additions & 13 deletions scripts/train.py
Expand Up @@ -20,6 +20,14 @@
import numpy as np
import numpy.typing as npt

jax_installed = False
try:
import jax.numpy as jnp
from jax import device_put, jit
jax_installed = True
except ModuleNotFoundError:
import numpy as jnp

EPS = np.finfo(float).eps # type: np.floating[typing.Any]


Expand Down Expand Up @@ -82,7 +90,7 @@ def pred(phis: typing.Dict[int, float],
alphas: npt.NDArray[np.float64]
y: npt.NDArray[np.int64]

alphas = np.array(list(phis.values()))
alphas = jnp.array(list(phis.values()))
y = 2 * (X[:, list(phis.keys())]
== True) - 1 # noqa (cannot replace `==` with `is`)
return y.dot(alphas) > 0
Expand Down Expand Up @@ -160,39 +168,47 @@ def fit(X_train: npt.NDArray[np.bool_],
assert (X_test.shape[0] == Y_test.shape[0]
), 'Testing entries and labels should have the same number of items.'

if jax_installed:
X_train = device_put(X_train)
Y_train = device_put(Y_train)
X_test = device_put(X_test)
Y_test = device_put(Y_test)
N_train, M_train = X_train.shape
w = np.ones(N_train) / N_train

w = jnp.ones(N_train) / N_train
YX_train = Y_train[:, None] ^ X_train
for t in range(iters):
print('=== %s ===' % (t))
if chunk_size is None:
res: npt.NDArray[np.float64] = w.dot(Y_train[:, None] ^ X_train)
res: npt.NDArray[np.float64] = w.dot(YX_train)
else:
res = np.zeros(M_train)
for i in range(0, N_train, chunk_size):
Y_train_chunk = Y_train[i:i + chunk_size]
X_train_chunk = X_train[i:i + chunk_size]
YX_train_chunk = YX_train[i:i + chunk_size]
w_chunk = w[i:i + chunk_size]
res += w_chunk.dot(Y_train_chunk[:, None] ^ X_train_chunk)
err = 0.5 - np.abs(res - 0.5)
res += w_chunk.dot(YX_train_chunk)
err = 0.5 - jnp.abs(res - 0.5)
m_best = int(err.argmin())
pol_best = res[m_best] < 0.5
err_min = err[m_best]
print('min error:\t', err_min)
print('best tree:\t', m_best)
alpha = np.log((1 - err_min) / (err_min + EPS))
alpha = jnp.log((1 - err_min) / (err_min + EPS))
phis.setdefault(m_best, 0)
phis[m_best] += alpha if pol_best else -alpha
miss = Y_train ^ X_train[:, m_best]
miss = YX_train[:, m_best]
if not pol_best:
miss = ~(miss)
w = w * np.exp(alpha * miss)
w = w * jnp.exp(alpha * miss)
w = w / w.sum()
with open(weights_filename, 'a') as f:
feature = features[m_best] if m_best < len(features) else 'BIAS'
f.write('%s\t%.3f\n' % (feature, alpha if pol_best else -alpha))
acc_train = (pred(phis, X_train) == Y_train).mean()
acc_test = (pred(phis, X_test) == Y_test).mean()
if jax_installed:
acc_train = (jit(pred)(phis, X_train) == Y_train).mean()
acc_test = (jit(pred)(phis, X_test) == Y_test).mean()
else:
acc_train = (pred(phis, X_train) == Y_train).mean()
acc_test = (pred(phis, X_test) == Y_test).mean()
print('training accuracy:\t', acc_train)
print('testing accuracy:\t', acc_test)
with open(log_filename, 'a') as f:
Expand Down