RAI: Rust ML framework with composable transformations like JAX.

License: Apache-2.0 (LICENSE-APACHE) or MIT (LICENSE-MIT)
cksac/rai

RAI

An ML framework for Rust with ergonomic APIs, lazy computation, and composable function transformations in the style of JAX.
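Lazy computation means that tensor operations record a computation graph and defer the arithmetic until a result is actually requested. The std-only sketch below illustrates that idea on scalars; it is not rai's implementation, and `Lazy`, `Node`, `eval_node`, and `write_dot` are hypothetical names. Its `eval` and `dot_graph` methods only echo the names used in the snippets below.

```rust
use std::fmt::Write;
use std::rc::Rc;

/// Hypothetical graph node: constructing one records an operation, nothing more.
enum Node {
    Const(f64),
    Sin(Rc<Node>),
    Mul(Rc<Node>, Rc<Node>),
}

/// Hypothetical lazy handle over a shared graph node.
struct Lazy(Rc<Node>);

impl Lazy {
    fn constant(v: f64) -> Self {
        Lazy(Rc::new(Node::Const(v)))
    }
    fn sin(&self) -> Self {
        Lazy(Rc::new(Node::Sin(Rc::clone(&self.0))))
    }
    fn mul(&self, other: &Lazy) -> Self {
        Lazy(Rc::new(Node::Mul(Rc::clone(&self.0), Rc::clone(&other.0))))
    }
    /// Walks the graph and does the arithmetic; the only place work happens.
    fn eval(&self) -> f64 {
        eval_node(&self.0)
    }
    /// Renders the recorded graph in Graphviz DOT syntax.
    fn dot_graph(&self) -> String {
        let mut out = String::from("digraph {\n");
        let mut next_id = 0;
        write_dot(&self.0, &mut out, &mut next_id);
        out.push_str("}\n");
        out
    }
}

fn eval_node(node: &Node) -> f64 {
    match node {
        Node::Const(v) => *v,
        Node::Sin(a) => eval_node(a).sin(),
        Node::Mul(a, b) => eval_node(a) * eval_node(b),
    }
}

/// Emits `node` and its inputs, returning the id assigned to `node`.
/// Shared inputs are emitted once per use; a real tracer would deduplicate them.
fn write_dot(node: &Node, out: &mut String, next_id: &mut usize) -> usize {
    let id = *next_id;
    *next_id += 1;
    let (label, inputs): (String, Vec<&Rc<Node>>) = match node {
        Node::Const(v) => (v.to_string(), vec![]),
        Node::Sin(a) => ("sin".to_string(), vec![a]),
        Node::Mul(a, b) => ("mul".to_string(), vec![a, b]),
    };
    writeln!(out, "  n{id} [label=\"{label}\"];").unwrap();
    for input in inputs {
        let input_id = write_dot(input, out, next_id);
        writeln!(out, "  n{input_id} -> n{id};").unwrap();
    }
    id
}

fn main() {
    let x = Lazy::constant(1.0);
    // Builds the graph for sin(x) * x; no arithmetic has happened yet.
    let y = x.sin().mul(&x);
    print!("{}", y.dot_graph());
    println!("value = {}", y.eval());
}
```

Deferring work this way is what lets a framework inspect or transform the whole graph before running it; the cost is that nothing is computed until something like `eval` is called.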

Installation

cargo add rai

Code snippets

Function transformations (jvp, vjp, grad, value_and_grad)

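use rai::{grad, Cpu, Tensor, F32};

// The function to differentiate: f(x) = sin(x)
fn f(x: &Tensor) -> Tensor {
    x.sin()
}

fn main() {
    // grad(grad(f)) is the second derivative of f, i.e. -sin(x)
    let grad_fn = grad(grad(f));
    // A one-element F32 tensor of ones on the CPU device
    let x = &Tensor::ones([1], F32, &Cpu);
    let grad = grad_fn(x);
    // Print the computation graph, then the resulting tensor
    println!("{}", grad.dot_graph());
    println!("{}", grad);
}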
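For a scalar input, `grad(grad(f))` with `f = sin` should produce the second derivative, -sin(x), which is about -0.8415 at x = 1. The sketch below checks that arithmetic with forward-mode differentiation on hyper-dual numbers, a swapped-in technique chosen because it fits in a few lines of std-only Rust. It is not how rai implements `grad`, and `HyperDual` and `derivatives` are hypothetical names.

```rust
/// Hypothetical hyper-dual number v + d1*e1 + d2*e2 + d12*e1*e2,
/// with e1^2 = e2^2 = 0. Seeding d1 = d2 = 1 makes d1 hold f'(x)
/// and d12 hold f''(x) after evaluating f.
#[derive(Clone, Copy, Debug)]
struct HyperDual {
    v: f64,
    d1: f64,
    d2: f64,
    d12: f64,
}

impl HyperDual {
    /// An input variable seeded along both perturbation directions.
    fn variable(x: f64) -> Self {
        HyperDual { v: x, d1: 1.0, d2: 1.0, d12: 0.0 }
    }

    /// sin(a + h) = sin(a) + cos(a) h - sin(a) h^2 / 2, where
    /// h = d1*e1 + d2*e2 + d12*e1*e2 and h^2 = 2 * d1 * d2 * e1*e2.
    fn sin(self) -> Self {
        let (s, c) = self.v.sin_cos();
        HyperDual {
            v: s,
            d1: self.d1 * c,
            d2: self.d2 * c,
            d12: self.d12 * c - self.d1 * self.d2 * s,
        }
    }
}

/// Returns (f(x), f'(x), f''(x)) for a function written over HyperDual.
fn derivatives(f: impl Fn(HyperDual) -> HyperDual, x: f64) -> (f64, f64, f64) {
    let y = f(HyperDual::variable(x));
    (y.v, y.d1, y.d12)
}

fn main() {
    let (y, dy, ddy) = derivatives(|x| x.sin(), 1.0);
    // prints: sin(1) = 0.8415, first = 0.5403, second = -0.8415
    println!("sin(1) = {y:.4}, first = {dy:.4}, second = {ddy:.4}");
}
```

The second derivative here is exact up to floating-point rounding rather than a finite-difference estimate, which makes it a convenient cross-check for the value printed by the rai snippet.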

NN Modules, Optimizer and loss functions

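// Imports omitted: TrainableModule, Aux, softmax_cross_entropy,
// value_and_grad, Optimizer and eval come from rai.

// Loss: mean softmax cross-entropy; the logits are returned as auxiliary output
fn loss_fn<M: TrainableModule<Input = Tensor, Output = Tensor>>(
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) -> (Tensor, Aux<Tensor>) {
    let logits = model.forward(input);
    let loss = softmax_cross_entropy(&logits, labels).mean(..);
    (loss, Aux(logits))
}

fn train_step<M: TrainableModule<Input = Tensor, Output = Tensor>, O: Optimizer>(
    optimizer: &mut O,
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) {
    // Transform loss_fn; calling the result yields ((loss, aux), gradients)
    let vg_fn = value_and_grad(loss_fn);
    let ((_loss, Aux(_logits)), (grads, ..)) = vg_fn((model, input, labels));
    // Compute updated parameters from the gradients
    let mut params = optimizer.step(&grads);
    // Force evaluation of the lazily computed parameters
    eval(&params);
    // Write the new parameters back into the model
    model.update_params(&mut params);
}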
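The `train_step` above follows a common pattern: evaluate the loss together with its gradients (plus auxiliary outputs), let the optimizer compute new parameters, then write them back into the model. The std-only sketch below walks through the same loop for a one-feature linear regression with mean squared error and hand-derived gradients. It does not use rai, and `Params`, `loss_and_grad`, `Sgd`, and `fit` are hypothetical names chosen for illustration.

```rust
/// Hypothetical stand-in for trainable parameters of y = w * x + b.
#[derive(Clone, Copy, Debug)]
struct Params {
    w: f64,
    b: f64,
}

/// Mean squared error with hand-derived gradients. Returns
/// ((loss, aux predictions), grads), mirroring the
/// ((loss, Aux(logits)), grads) shape destructured in train_step.
fn loss_and_grad(p: &Params, xs: &[f64], ys: &[f64]) -> ((f64, Vec<f64>), Params) {
    let n = xs.len() as f64;
    let preds: Vec<f64> = xs.iter().map(|x| p.w * x + p.b).collect();
    let (mut loss, mut gw, mut gb) = (0.0, 0.0, 0.0);
    for ((pred, x), y) in preds.iter().zip(xs).zip(ys) {
        let err = pred - y;
        loss += err * err / n;
        gw += 2.0 * err * x / n; // d(loss)/dw
        gb += 2.0 * err / n; // d(loss)/db
    }
    ((loss, preds), Params { w: gw, b: gb })
}

/// Plain gradient descent with learning rate `lr`.
struct Sgd {
    lr: f64,
}

impl Sgd {
    /// Returns new parameters instead of mutating in place.
    fn step(&self, p: &Params, g: &Params) -> Params {
        Params { w: p.w - self.lr * g.w, b: p.b - self.lr * g.b }
    }
}

/// One training step: loss and gradients, optimizer update, write-back.
fn train_step(opt: &Sgd, model: &mut Params, xs: &[f64], ys: &[f64]) -> f64 {
    let ((loss, _preds), grads) = loss_and_grad(model, xs, ys);
    let updated = opt.step(model, &grads);
    *model = updated;
    loss
}

/// Runs `steps` training steps from w = b = 0 and records each step's loss.
fn fit(xs: &[f64], ys: &[f64], lr: f64, steps: usize) -> (Params, Vec<f64>) {
    let opt = Sgd { lr };
    let mut model = Params { w: 0.0, b: 0.0 };
    let losses: Vec<f64> = (0..steps)
        .map(|_| train_step(&opt, &mut model, xs, ys))
        .collect();
    (model, losses)
}

fn main() {
    // Data generated from y = 2x + 1.
    let xs = [0.0, 1.0, 2.0, 3.0];
    let ys = [1.0, 3.0, 5.0, 7.0];
    let (model, losses) = fit(&xs, &ys, 0.05, 500);
    // The first recorded loss is 21.0 (all predictions start at 0).
    println!("first loss = {:.4}, last loss = {:.2e}", losses[0], losses[losses.len() - 1]);
    // Converges toward w = 2, b = 1.
    println!("w = {:.3}, b = {:.3}", model.w, model.b);
}
```

Returning fresh parameters from `step` and assigning them afterwards keeps the optimizer free of any reference to the model, which loosely parallels the `optimizer.step(&grads)` / `model.update_params(...)` split in the snippet above.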

Examples

  • linear_regression
    • cargo run --bin linear_regression --release
  • mnist
    • cargo run --bin mnist --release
    • cargo run --bin mnist --release --features=cuda
  • mnist-cnn
    • cargo run --bin mnist-cnn --release
    • cargo run --bin mnist-cnn --release --features=cuda
  • phi2
    • cargo run --bin phi2 --release
    • cargo run --bin phi2 --release --features=cuda
  • phi3
    • cargo run --bin phi3 --release
    • cargo run --bin phi3 --release --features=cuda
  • qwen2
    • cargo run --bin qwen2 --release
    • cargo run --bin qwen2 --release --features=cuda
  • gemma
    • accept the license agreement at https://huggingface.co/google/gemma-2b
    • pip install huggingface_hub
    • log in to Hugging Face: huggingface-cli login
    • cargo run --bin gemma --release
    • cargo run --bin gemma --release --features=cuda
  • vit
    • cargo run --bin vit --release
    • cargo run --bin vit --release --features=cuda

LICENSE

This project is licensed under either of

  • Apache License, Version 2.0 (LICENSE-APACHE)
  • MIT license (LICENSE-MIT)

at your option.
