trainPPO
Functions
- collator(data) – Data collator function for grouping data batches without padding.
- build_dataset(config, tokenizer, data_name) – Builds and tokenizes the dataset for training based on the specified data name.
- main(lam_list, value_list, model_name, data_name, save_path, learning_rate, batch_size, mini_batch_size, nepoch) – Main function to train a model with PPO (Proximal Policy Optimization) based on user-defined parameters.
Module Contents
- trainPPO.collator(data)
Data collator function for grouping data batches without padding. PPOTrainer handles padding internally based on the tokenizer settings.
- Parameters:
data (list) – List of data samples.
- Returns:
A dictionary with collated data grouped by each key in the input.
- Return type:
dict
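Based on the description above, the collator can be sketched as the usual TRL-style key-grouping collator (a minimal reconstruction; the actual source may differ):

```python
def collator(data):
    """Group a list of sample dicts into a single dict of lists.

    No padding is applied here: PPOTrainer pads internally according to
    the tokenizer settings, so the collator only regroups the batch by key.
    """
    return {key: [sample[key] for sample in data] for key in data[0]}
```

For example, `collator([{"query": "a"}, {"query": "b"}])` returns `{"query": ["a", "b"]}`.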
- trainPPO.build_dataset(config, tokenizer, data_name)
Builds and tokenizes the dataset for training based on the specified data name.
- Parameters:
config (PPOConfig) – Configuration for PPO training.
tokenizer (AutoTokenizer) – Tokenizer used to process and encode text data.
data_name (str) – Name of the dataset, either “Imdb” or “Anthropic-harmless”.
- Returns:
A Hugging Face Dataset object with tokenized prompts for training.
- Return type:
Dataset
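The core of build_dataset can be sketched as follows. This is a hypothetical reconstruction following the common TRL pattern: each raw text is tokenized, truncated to a short prompt, and stored with both its token ids (`input_ids`) and its decoded text (`query`), the two fields PPOTrainer expects per sample. `build_prompts` and `ToyTokenizer` are illustrative names; the toy tokenizer merely stands in for the real `AutoTokenizer` interface.

```python
def build_prompts(tokenizer, raw_texts, max_prompt_tokens=8):
    """Hypothetical sketch of build_dataset's tokenization step: truncate
    each text to a short prompt and keep token ids plus decoded text."""
    samples = []
    for text in raw_texts:
        input_ids = tokenizer.encode(text)[:max_prompt_tokens]
        samples.append({"input_ids": input_ids,
                        "query": tokenizer.decode(input_ids)})
    return samples

class ToyTokenizer:
    """Stand-in exposing the encode()/decode() interface assumed above."""
    def encode(self, text):
        return text.split()
    def decode(self, token_ids):
        return " ".join(token_ids)
```

In trainPPO.py the tokenizer would be a Hugging Face AutoTokenizer, the raw texts would come from the “Imdb” or “Anthropic-harmless” data, and the resulting samples would typically be wrapped in a Hugging Face Dataset (e.g. via `Dataset.from_list`) before being handed to PPO training.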
- trainPPO.main(lam_list, value_list, model_name, data_name, save_path, learning_rate=1e-06, batch_size=20, mini_batch_size=2, nepoch=1)
Main function to train a model with PPO (Proximal Policy Optimization) based on user-defined parameters.
- Parameters:
lam_list (list of float) – Lambda weights for the values being aligned, one per value in value_list.
value_list (str) – Comma-separated string of the values to align, or “all” for all supported values.
model_name (str) – Name of the model to use (e.g., “opt-1.3b”).
data_name (str) – Name of the dataset to use (“Imdb” or “Anthropic-harmless”).
save_path (str) – Path to save the trained model.
learning_rate (float) – Learning rate for PPO training. Defaults to 1e-6.
batch_size (int) – Total batch size for training. Defaults to 20.
mini_batch_size (int) – Mini-batch size for each PPO optimization step. Defaults to 2.
nepoch (int) – Number of training epochs. Defaults to 1.
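How lam_list enters training is not spelled out here. A natural reading, consistent with supplying one weight per aligned value, is that the per-value reward scores are combined into a single scalar PPO reward as a lambda-weighted sum. The sketch below is an assumption (the hypothetical `combined_reward` does not appear in the source, and the actual combination in trainPPO.py may differ):

```python
def combined_reward(value_scores, lam_list):
    """Hypothetical: collapse per-value reward scores into one scalar PPO
    reward via a lambda-weighted sum. Assumes one lambda per value score."""
    if len(value_scores) != len(lam_list):
        raise ValueError("need exactly one lambda per value score")
    return sum(lam * score for lam, score in zip(lam_list, value_scores))
```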
- Example command-line usage:
$ python trainPPO.py --model_name="opt-1.3b" --data_name="Imdb" --value_list="all" --lam_list="0.241,0.077,0.117,0.033,0.070,0.065" --learning_rate=1e-4