TrlParser simplifies our fine-tuning code

In the ever-changing world of machine learning, keeping our code simple and easy to maintain is more important than ever.

That’s where TrlParser comes in—a smart extension of Hugging Face’s HfArgumentParser designed to simplify configuration management.

In this post, we’ll walk through how to use TrlParser with a YAML file that stores all the configuration parameters, making our training scripts cleaner and our coding life a whole lot easier.

TrlParser is a subclass of HfArgumentParser designed specifically for the Transformer Reinforcement Learning (TRL) library. It parses command-line arguments backed by Python dataclasses and also lets us provide hyperparameters in a YAML file. With this approach, we benefit from:

  • Centralized configuration: Maintain default values, hyperparameters, and environment variables in one neat YAML file.
  • Organized code: Avoid clutter by separating configuration from code logic.
  • Clarity: All your training parameters, model configurations, and environment variables are declared in one human-readable file.
  • Flexibility: Easily modify parameters without diving into the code. Command-line overrides add an extra layer of flexibility.
  • Neat code: Your main script remains clean, focusing solely on the logic, while the configuration is managed externally.
  • Reproducibility: Documenting your settings in YAML helps with versioning and reproducibility of experiments.
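Before diving into the full example further down, here is a minimal sketch of the core pattern: a single dataclass whose fields TrlParser fills from a YAML config and/or the command line. The field names and defaults here are hypothetical, chosen only for illustration.

from dataclasses import dataclass
from trl import TrlParser

@dataclass
class ScriptArguments:
  # Hypothetical fields; the defaults apply when neither the YAML
  # config nor the command line provides a value.
  dataset_id_or_path: str = "my-dataset"
  max_seq_length: int = 2048

parser = TrlParser(ScriptArguments)
(script_args,) = parser.parse_args_and_config()
print(script_args)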

The official TrlParser documentation has room for improvement, but it is still worth checking out for more detail.

Let’s see how you can use TrlParser with YAML in a real-world training script. Imagine you have a YAML file (config.yaml) that looks like this:

# Dataset config
dataset_id_or_path: naklecha/minecraft-question-answer-700k

# Model config
model_name_or_path: openchat/openchat_3.5
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
use_liger: true
bf16: true
tf32: false
output_dir: runs

# SFT config
use_peft: true
load_in_4bit: true
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
lora_modules_to_save: ["lm_head", "embed_tokens"]
lora_r: 16
lora_alpha: 16
num_train_epochs: 1
per_device_train_batch_size: 4
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 2.0e-4
lr_scheduler_type: constant
warmup_ratio: 0.1
logging_strategy: steps
logging_steps: 5
report_to:
- tensorboard
save_strategy: "epoch"
seed: 42
push_to_hub: false
hub_model_id: openchat_3.5_minecraft
hub_strategy: every_save

 

Below is the Python script, which uses three dataclasses: a custom DatasetConfig, and the built-in ModelConfig and SFTConfig from TRL. TrlParser loads all of these configurations from the YAML file, keeping the code neat and modular.

from trl import TrlParser, ModelConfig, SFTConfig
from dataclasses import dataclass

# Custom dataclass holding the dataset-related settings from the YAML file
@dataclass
class DatasetConfig:
  dataset_id_or_path: str

def main():
  # Parse the YAML config (and any command-line overrides) into the three dataclasses
  parser = TrlParser((DatasetConfig, ModelConfig, SFTConfig))
  dataset_config, model_config, sft_config = parser.parse_args_and_config()

  print("DatasetConfig:")
  print(dataset_config)

  print("\n\nModelConfig:")
  print(model_config)

  print("\n\nSFTConfig:")
  print(sft_config)

if __name__ == '__main__':
  main()
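From here, a natural next step is to feed the parsed configs into an SFTTrainer. The snippet below is only a sketch: it assumes the dataset has a "train" split in a format SFTTrainer can consume, and for brevity it skips wiring ModelConfig's loading options (dtype, attention implementation, quantization) into the model. It uses TRL's get_peft_config helper, which builds a LoraConfig from the ModelConfig fields.

from datasets import load_dataset
from trl import SFTTrainer, get_peft_config

# Continuing inside main(), after parser.parse_args_and_config() above.
# Assumes the dataset exposes a "train" split ready for SFT.
dataset = load_dataset(dataset_config.dataset_id_or_path, split="train")

trainer = SFTTrainer(
  model=model_config.model_name_or_path,      # SFTTrainer can load the model from its name
  args=sft_config,                             # all SFT hyperparameters from the YAML
  train_dataset=dataset,
  peft_config=get_peft_config(model_config),   # LoRA settings from the ModelConfig fields
)
trainer.train()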

To run the script, simply supply the YAML file using the --config flag:

python your_script.py --config config.yaml

You can also override specific parameters on the fly. For example:

python your_script.py --config config.yaml --num_train_epochs 3 
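Overrides are not limited to a single flag: any field from the dataclasses can be set this way, and command-line values take precedence over the YAML file. For example:

python your_script.py --config config.yaml --learning_rate 1.0e-4 --per_device_train_batch_size 8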

TrlParser, combined with YAML configuration, is a game-changer for fine-tuning and training scripts in the Hugging Face ecosystem. It keeps your code neat, makes parameter management super simple, and provides a flexible, reproducible way to handle configurations.

Happy fine-tuning, and enjoy writing neat, maintainable code!
