
YAML configuration for mlx_lm.lora #503

Merged · 19 commits into ml-explore:main · Mar 8, 2024
Conversation

chimezie (Contributor):

Changes the LoRA tuning module to use a YAML configuration, with defaults taken from the original command-line parameters.

Supersedes #235

awni (Member) left a comment:

Thanks for getting this started! I suggest a few changes (which I'm happy to help with):

  1. Keep support for the existing command line arguments. It can be nice to run with just the CLI and it's useful to be back-compatible here IMO.
  2. Allow a config arg for the yaml.
  3. I think a nice behavior is to overwrite flags from the config with the CLI (so prefer the command line to the same parameter set in the config). Makes it easy to experiment without needing to update the config. Another option which I think is ok is to simply disallow setting both (if the config is provided then you can't also provide the same flag on the CLI).

> The main command is `mlx_lm.lora`. The argument is a YAML file with the training parameters in the following format: […]

awni (Member) commented on Mar 3, 2024:

Let's keep the LORA.md as it was. Instead of putting this here, let's put an example config in the examples directory (call it lora_config.yaml). https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples

awni (Member) added:

And maybe link to the example config in this readme and mention how to use it.

chimezie (Contributor, Author) replied:

Ok. I reverted LORA.md, added the example YAML, and added a mention of it in LORA.md.
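
For reference, a config in that spirit might look like the following sketch; the keys mirror the long-form CLI flags, but the exact contents of the merged lora_config.yaml may differ:

```yaml
# Hypothetical sketch of a lora_config.yaml; values are illustrative,
# not necessarily those in the merged example.
model: "mistralai/Mistral-7B-v0.1"  # model path or HF repo (assumed value)
train: true                  # run LoRA fine-tuning
data: "data/"                # directory containing train/valid/test .jsonl
batch_size: 4
iters: 1000
learning_rate: 1e-5          # needs the float resolver discussed below
steps_per_eval: 200
adapter_file: "adapters.npz" # matches CONFIG_DEFAULTS discussed below
```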

```python
from mlx.utils import tree_flatten

from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import linear_to_lora_layers
from .utils import load

yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
```
awni (Member) asked:

What is the purpose of this?
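
For context: by default, PyYAML's SafeLoader resolves scientific-notation scalars such as `1e-5` as strings rather than floats, because its stock float resolver requires a "." in e-notation. Widening the implicit float resolver is the common workaround. A sketch of that pattern, not necessarily the exact arguments used in this PR:

```python
import re

import yaml

yaml_loader = yaml.SafeLoader
# Without this, "learning_rate: 1e-5" would load as the string "1e-5".
# The widened regex lets plain scientific notation resolve to float.
yaml_loader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        r"""^(?:
          [-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        | [-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        | \.[0-9_]+(?:[eE][-+][0-9]+)?
        | [-+]?\.(?:inf|Inf|INF)
        | \.(?:nan|NaN|NAN)
        )$""",
        re.X,
    ),
    list("-+0123456789."),
)

assert isinstance(yaml.load("lr: 1e-5", Loader=yaml_loader)["lr"], float)
```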

awni (Member) commented on Mar 6, 2024:

@chimezie do you plan to come back to this? Let me know if I can help here!

chimezie (Contributor, Author) commented on Mar 6, 2024:

Yes, I was working on it a few hours ago and intended to finish it tomorrow (with some questions).

chimezie (Contributor, Author) commented on Mar 6, 2024:

> 1. Keep support for the existing command line arguments. It can be nice to run with just the CLI and it's useful to be back-compatible here IMO.

I agree. I have updated this.

> 2. Allow a config arg for the yaml.

There was already a positional argument for the config (it is now an optional positional argument). Is it fine as it is, or did you want it as an optional argument like (for example) -c/--config?

> 3. I think a nice behavior is to overwrite flags from the config with the CLI (so prefer the command line to the same parameter set in the config). Makes it easy to experiment without needing to update the config. Another option which I think is ok is to simply disallow setting both (if the config is provided then you can't also provide the same flag on the CLI).

Ok. I went with the first option, so the CLI flags take precedence.

awni (Member) commented on Mar 6, 2024:

Thank you for all the updates!

> There was already a positional argument for the config (it is now an optional positional argument). Is it fine as it is, or did you want it as an optional argument like (for example) -c/--config?

I think it would be better as a flag (`-c/--config`) because it is more consistent with all the other arguments.
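
A minimal sketch of such a flag in lora.py's parser; the help text and the surrounding parser setup are assumed, not taken from the PR:

```python
import argparse

parser = argparse.ArgumentParser(description="LoRA fine-tuning with MLX.")
# Optional flag replacing the earlier positional config argument.
parser.add_argument(
    "-c",
    "--config",
    type=str,
    help="Path to a YAML file with the training parameters",
)
# ... the remaining training flags (--model, --iters, etc.) would go here ...
args = parser.parse_args()
```

With a flag rather than a positional argument, an invocation like `mlx_lm.lora -c lora_config.yaml --learning-rate 2e-5` reads consistently with the other options and keeps the CLI-overrides-YAML semantics visible.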

chimezie (Contributor, Author) commented on Mar 6, 2024:

Ok. I have updated the config option accordingly. I also ended up having to remove the defaults from the CLI argument definitions because they were overriding the YAML-based options due to their taking precedence. Currently, the order of precedence is:

  1. CLI options (if specified)
  2. YAML options (if specified)
  3. Defaults in CONFIG_DEFAULTS

It is ready for your review.

chimezie (Contributor, Author) commented on Mar 6, 2024:

One tricky issue came up afterwards.

Apparently, even if you define options this way:

```python
import argparse

parser = argparse.ArgumentParser(description="")
parser.add_argument("--sum", help="", required=False)
args = parser.parse_args()

if __name__ == "__main__":
    print(args.__dict__)
```

The argument dict will still have an entry for the 'sum' option, mapped to None. So, as it stands, if -c/--config is not provided, all the defaults will be None, even for CLI options that were not specified. What I had in mind was the following, though I'm reluctant for no better reason than a feeling that argparse should exclude options that were not specified on the CLI if they are defined as not required:

```python
if args.config:
    # [..]
else:
    args.__dict__.update(
        {
            arg: CONFIG_DEFAULTS[arg]
            for arg, value in args.__dict__.items()
            if value is None and arg != "config"
        }
    )
```

Thoughts?
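
Judging by the follow-up commits listed below ("Prevent None values of unspecified CLI options from overwriting values from CONFIG_DEFAULTS"), the resolution was essentially the complementary filter: start from the defaults, layer the YAML values on top, then apply only the CLI values that are not None. A sketch of that merge order, with an assumed helper name and an abbreviated CONFIG_DEFAULTS:

```python
import yaml

# Abbreviated; the real CONFIG_DEFAULTS covers every training option.
CONFIG_DEFAULTS = {
    "adapter_file": "adapters.npz",
    "steps_per_eval": 200,
}

def resolve_options(args) -> dict:
    # resolve_options is a hypothetical helper illustrating the precedence.
    options = dict(CONFIG_DEFAULTS)  # 3. lowest precedence: defaults
    if args.config:
        with open(args.config) as f:
            options.update(yaml.safe_load(f) or {})  # 2. YAML config values
    # 1. highest precedence: CLI flags the user actually passed.
    # argparse maps unspecified optional flags to None (see above),
    # so None values must not clobber YAML or default values.
    options.update(
        {k: v for k, v in vars(args).items() if v is not None and k != "config"}
    )
    return options
```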

```yaml
# Number of training steps between validations.
"steps_per_eval": 200

# Load path to resume training with the given adapter weights.
```
Another contributor commented:

Can you just set it to adapter.npz to make sure it works out of the box?

chimezie (Contributor, Author) replied:

That file is just an example of the format of the YAML configuration, and mlx_lm/lora.py pulls its defaults from the CONFIG_DEFAULTS defined there:

```python
CONFIG_DEFAULTS = {
    "adapter_file": "adapters.npz",
    # [...]
}
```

Since 'adapters.npz' is already the default for adapter_file there, if --adapter-file is not specified on the command line or in the configuration (as "adapter_file"), then adapters.npz will be the default.

awni (Member) left a comment:

Thanks for the additions!

awni merged commit 8c2cf66 into ml-explore:main on Mar 8, 2024, with 3 checks passed.
devonthomas35 pushed a commit to devonthomas35/mlx-examples that referenced this pull request on Mar 11, 2024:
* Convert mlx_lm.lora to use YAML configuration

* pre-commit run fixes

* Fix loading of config file

* Remove invalid YAML from doc

* Update command-line options and YAML parameter overriding, per feedback in ml-explore#503

* Minor wording change

* Positional argument

* Moved config to a (-c/--config) flag

* Removed CLI option defaults (since CLI options take precedence and their defaults are in CONFIG_DEFAULTS)

* pre-commit format updates

* Fix handling of CLI option defaults

* Prevent None values of unspecified CLI options from overwriting values from CONFIG_DEFAULTS

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
chimezie deleted the yaml-config branch on March 11, 2024 at 15:37.