`axolotl` streamlines the fine-tuning of AI models, offering support for multiple configurations and architectures. Furthermore, `axolotl` provides a set of YAML examples for almost all kinds of LLMs, such as the LLaMA2, Gemma, and LLaMA3 families, Jamba, and so on. It's recommended to navigate through the examples to get a sense of the role of each parameter, and to adjust them for your specific use case. It is also worth checking out all config options, each with a brief description, in this doc.

The example below replicates FSDP+QLoRA on LLaMA3 70B, except that here we use LLaMA3 8B. You can see the `config.yaml`.
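For orientation, a `config.yaml` for QLoRA fine-tuning of LLaMA3 8B might look roughly like the following. This is a minimal sketch, not the exact config used here: the dataset, hyperparameters, and output path are illustrative placeholders, so consult the axolotl examples for complete, tested configs.

```yaml
base_model: meta-llama/Meta-Llama-3-8B

load_in_4bit: true          # QLoRA: load the base model quantized to 4 bits
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true    # attach LoRA adapters to all linear layers

datasets:                   # placeholder dataset
  - path: tatsu-lab/alpaca
    type: alpaca

sequence_len: 4096
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
learning_rate: 0.0002
optimizer: adamw_torch
lr_scheduler: cosine

output_dir: ./outputs/qlora-llama3-8b
```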
Running `axolotl` with `dstack` is very straightforward. First, define the `train.dstack.yaml` task configuration file as follows:
```yaml
type: task

image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.1

env:
  - HUGGING_FACE_HUB_TOKEN
  - WANDB_API_KEY
commands:
  - accelerate launch -m axolotl.cli.train config.yaml
ports:
  - 6006

resources:
  gpu:
    memory: 24GB..
    count: 2
```
Note: feel free to adjust `resources` to specify the resources you require.
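For instance, `resources` also accepts open-ended ranges (such as the `24GB..` used above, meaning 24 GB or more). A hedged sketch, where the particular GPU names and values are purely illustrative:

```yaml
resources:
  gpu:
    memory: 40GB..   # at least 40 GB of GPU memory per GPU
    count: 2..4      # between 2 and 4 GPUs
  shm_size: 16GB     # shared memory, useful for multi-GPU training
```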
We are using the official Docker image provided by the Axolotl team (`winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.1`). If you want to see other images, check their official repo. Note that `dstack` requires the CUDA driver to be 12.1 or higher.
To run the task, use the following command:
```shell
HUGGING_FACE_HUB_TOKEN=<...> \
WANDB_API_KEY=<...> \
dstack run . -f examples/fine-tuning/axolotl/train.dstack.yaml
```
To push the final fine-tuned model to the Hugging Face Hub, set `hub_model_id` in `config.yaml`.
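For example, a minimal fragment might look like this, where the repo name is a placeholder you would replace with your own:

```yaml
# Push the trained model to the Hugging Face Hub (repo name is a placeholder)
hub_model_id: your-username/llama3-8b-qlora
```

The `HUGGING_FACE_HUB_TOKEN` set in the task's `env` is what authorizes the upload.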
If you'd like to build `axolotl` from source (e.g., if you intend to modify its source code), follow its installation guide. Example:
```yaml
type: task

python: 3.11

env:
  - HUGGING_FACE_HUB_TOKEN
  - WANDB_API_KEY
commands:
  - conda install cuda
  - pip3 install torch torchvision torchaudio
  - git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
  - cd axolotl
  - pip3 install packaging
  - pip3 install -e '.[flash-attn,deepspeed]'
  - accelerate launch -m axolotl.cli.train ../config.yaml
ports:
  - 6006

resources:
  gpu:
    memory: 24GB..
    count: 2
```