In the ever-changing world of machine learning, keeping our code simple and easy to maintain is more important than ever.
That’s where TrlParser comes in—a smart extension of Hugging Face’s HfArgumentParser designed to simplify configuration management.
In this post, we’ll walk through how to use TrlParser with a YAML file that stores all the configuration parameters, making our training scripts cleaner and our coding life a whole lot easier.
TrlParser is a subclass of HfArgumentParser designed specifically for the Transformer Reinforcement Learning (TRL) library. It supports parsing command-line arguments backed by Python dataclasses, and it lets us provide hyperparameters in a YAML file (see the minimal sketch after the list below). With this approach, we benefit from:
- Centralized configuration: Maintain default values, hyperparameters, and environment variables in one neat YAML file.
- Organized code: Avoid clutter by separating configuration from code logic.
- Clarity: All your training parameters, model configurations, and environment variables are declared in one human-readable file.
- Flexibility: Easily modify parameters without diving into the code. Command-line overrides add an extra layer of flexibility.
- Neat Code: Your main script remains clean, focusing solely on the logic, while the configuration is managed externally.
- Reproducibility: Documenting your settings in YAML helps with versioning and reproducibility of experiments.
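To make that mapping concrete, here is a minimal sketch of how a dataclass field becomes both a YAML key and a command-line flag. The `ScriptArgs` dataclass and its fields are purely illustrative (they are not part of TRL); the full example later in the post uses the real TRL config classes.

```python
from dataclasses import dataclass

from trl import TrlParser


@dataclass
class ScriptArgs:
    # Each field can be set in the YAML file (dataset_id_or_path: ...)
    # or overridden on the command line (--dataset_id_or_path ...).
    dataset_id_or_path: str = "my-dataset"
    max_seq_length: int = 512


parser = TrlParser((ScriptArgs,))
(script_args,) = parser.parse_args_and_config()
print(script_args)
```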
The official TrlParser documentation has some room for improvement, but it's still worth checking out for more details.
Let’s see how you can use TrlParser with YAML in a real-world training script. Imagine you have a YAML file (config.yaml) that looks like this:
```yaml
# Dataset config
dataset_id_or_path: naklecha/minecraft-question-answer-700k

# Model config
model_name_or_path: openchat/openchat_3.5
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
use_liger: true
bf16: true
tf32: false
output_dir: runs

# SFT config
use_peft: true
load_in_4bit: true
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
lora_modules_to_save: ["lm_head", "embed_tokens"]
lora_r: 16
lora_alpha: 16
num_train_epochs: 1
per_device_train_batch_size: 4
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 2.0e-4
lr_scheduler_type: constant
warmup_ratio: 0.1
logging_strategy: steps
logging_steps: 5
report_to:
  - tensorboard
save_strategy: "epoch"
seed: 42
push_to_hub: false
hub_model_id: openchat_3.5_minecraft
hub_strategy: every_save
```
Below is the Python script, which uses three dataclasses: a custom DatasetConfig plus the built-in ModelConfig and SFTConfig from TRL. TrlParser loads all of these configurations from the YAML file, keeping the code neat and modular.
```python
from dataclasses import dataclass

from trl import TrlParser, ModelConfig, SFTConfig


@dataclass
class DatasetConfig:
    dataset_id_or_path: str


def main():
    parser = TrlParser((DatasetConfig, ModelConfig, SFTConfig))
    dataset_config, model_config, sft_config = parser.parse_args_and_config()

    print("DatasetConfig:")
    print(dataset_config)
    print("\n\nModelConfig:")
    print(model_config)
    print("\n\nSFTConfig:")
    print(sft_config)


if __name__ == '__main__':
    main()
```
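From here, the parsed objects can be wired straight into a trainer. The following is a rough sketch (not part of the original script) of how the three configs might feed an SFTTrainer run; the exact arguments accepted depend on your TRL version, so treat it as an outline rather than a drop-in implementation.

```python
from datasets import load_dataset
from trl import SFTTrainer, get_peft_config


def train(dataset_config, model_config, sft_config):
    # Load the dataset referenced by the custom DatasetConfig
    dataset = load_dataset(dataset_config.dataset_id_or_path, split="train")

    # SFTTrainer can build the model from its name, and get_peft_config
    # turns the LoRA-related fields of ModelConfig into a PEFT config
    trainer = SFTTrainer(
        model=model_config.model_name_or_path,
        args=sft_config,
        train_dataset=dataset,
        peft_config=get_peft_config(model_config),
    )
    trainer.train()
```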
To run the script, simply supply the YAML file using the --config flag:
```bash
python your_script.py --config config.yaml
```
You can also override specific parameters on the fly. For example:
```bash
python your_script.py --config config.yaml --num_train_epochs 3
```
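Any field from any of the registered dataclasses can be overridden this way, with the command-line value taking precedence over the one in the YAML file. A quick learning-rate experiment, for instance, might look like this (the overridden values here are just an example):

```bash
python your_script.py --config config.yaml \
    --learning_rate 1.0e-4 \
    --per_device_train_batch_size 8 \
    --output_dir runs/lr-experiment
```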
TrlParser, combined with YAML configuration, is a game-changer for fine-tuning and training scripts in the Hugging Face ecosystem. It keeps your code neat, makes parameter management super simple, and provides a flexible, reproducible way to handle configurations.
Happy fine-tuning, and enjoy writing neat, maintainable code!