
Axolotl

axolotl streamlines the fine-tuning of AI models, offering support for multiple configurations and architectures.

Furthermore, axolotl provides a set of YAML examples for most popular LLMs, including the LLaMA 2, LLaMA 3, Gemma, and Jamba families. It's recommended to browse the examples to get a sense of the role of each parameter, then adjust them for your specific use case. It is also worth checking out the full list of config options, each with a brief description, in this doc.

The example below replicates the FSDP+QLoRA fine-tuning of LLaMA 3 70B, except that here we use LLaMA 3 8B. You can see the config.yaml.
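For orientation, a QLoRA config for LLaMA 3 8B in axolotl might look like the sketch below. This is an illustrative outline only, not the exact config.yaml from the repo: the model ID, dataset, and hyperparameters are assumptions you should adapt to your use case.

```yaml
# Illustrative sketch only -- not the repo's config.yaml.
base_model: meta-llama/Meta-Llama-3-8B

load_in_4bit: true          # QLoRA: load the base model in 4-bit precision
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true    # apply LoRA to all linear layers

datasets:
  - path: tatsu-lab/alpaca  # example dataset; swap in your own
    type: alpaca

sequence_len: 4096
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
learning_rate: 0.0002
optimizer: adamw_bnb_8bit
bf16: auto

output_dir: ./outputs/llama3-8b-qlora
```

The examples directory in the axolotl repo contains tested configs for each model family; prefer those as a starting point over writing one from scratch.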

Running with dstack

Running axolotl with dstack is very straightforward.

First, define the train.dstack.yaml task configuration file as follows:

type: task

image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.1

env:
  - HUGGING_FACE_HUB_TOKEN
  - WANDB_API_KEY

commands:
  - accelerate launch -m axolotl.cli.train config.yaml

ports:
  - 6006

resources:
  gpu:
    memory: 24GB..
    count: 2

Note

Feel free to adjust the resources section to match your hardware requirements.
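For instance, dstack accepts a range for GPU memory (`24GB..` means 24 GB or more), and you can also restrict the GPU model by name. A hedged sketch (the GPU names and ranges below are assumptions, not requirements of this example):

```yaml
# Alternative resources spec -- values are illustrative
resources:
  gpu:
    name: [A100, H100]    # accept either GPU model
    memory: 40GB..80GB    # any GPU with 40 to 80 GB of memory
    count: 2
```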

We are using the official Docker image provided by the Axolotl team (winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.1). If you want to use other images, check their official repo. Note that dstack requires a CUDA driver of version 12.1 or later.

To run the task, use the following command:

HUGGING_FACE_HUB_TOKEN=<...> \
WANDB_API_KEY=<...> \
dstack run . -f examples/fine-tuning/axolotl/train.dstack.yaml

To push the final fine-tuned model to Hugging Face Hub, set hub_model_id in config.yaml.
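For example, add a line like the following to config.yaml (the repo name here is a placeholder; use your own Hub username and repository):

```yaml
# Pushes the fine-tuned model to the Hugging Face Hub after training.
# "your-username/your-model-repo" is a placeholder -- replace it.
hub_model_id: your-username/your-model-repo
```

The HUGGING_FACE_HUB_TOKEN passed via env is what authorizes the push.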

Building axolotl from sources

If you'd like to build axolotl from source (e.g., if you intend to modify its source code), follow its installation guide.

Example:

type: task

python: 3.11

env:
  - HUGGING_FACE_HUB_TOKEN
  - WANDB_API_KEY

commands:
  - conda install cuda
  - pip3 install torch torchvision torchaudio

  - git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
  - cd axolotl

  - pip3 install packaging
  - pip3 install -e '.[flash-attn,deepspeed]'

  - accelerate launch -m axolotl.cli.train ../config.yaml

ports:
  - 6006

resources:
  gpu:
    memory: 24GB..
    count: 2