gendata
Classes
Generates text continuations based on existing prompts, either using a pretrained model or an aligned model with Monte Carlo sampling.
Module Contents
- class gendata.TextGeneration(basemodel_name, data_name, save_directory='results')
Generates text continuations based on existing prompts, either using a pretrained model or an aligned model with Monte Carlo sampling.
- device
The device for model inference, typically 'cuda' if available.
- Type:
str
- basemodel_name
The name of the base model for text generation.
- Type:
str
- data_name
The source of the prompts, supporting "Anthropic-harmless" and "Imdb".
- Type:
str
- file_path
Path to save the generated output in JSON format.
- Type:
str
- top_k
Number of highest probability vocabulary tokens to keep for generation.
- Type:
int
- max_new_tokens
Maximum number of new tokens to generate.
- Type:
int
- generation_config
Configuration settings for text generation.
- Type:
GenerationConfig
Example usage:

# Generate text directly from the original model
python gendata.py generate_from_original_model

# Generate text with Monte Carlo sampling from an aligned model
python gendata.py generate_from_MC_aligned_model --lam_list=-0.5 --value_list="humor" --MC_nsamples=50
- device = 'cuda'
- basemodel_name
- data_name
- file_path
- top_k = 50
- max_new_tokens = 50
- generation_config
- generate_from_original_model(batch_size=32)
Generates text continuations directly from the original model using the predefined generation configuration.
- Parameters:
batch_size (int, optional) – Number of prompts processed per batch. Defaults to 32.
- Raises:
ValueError – If an unsupported data_name is provided.
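The batching behavior of this method can be sketched as follows. This is a minimal illustration only: the `generate_in_batches` function and its `generate_batch` callback are hypothetical stand-ins, not the module's internals, which operate on the actual model and its GenerationConfig.

```python
# Sketch of batched generation: slice the prompt list into fixed-size
# batches and collect the continuations in order. The generate_batch
# callback stands in for a real model forward pass (an assumption).
def generate_in_batches(prompts, generate_batch, batch_size=32):
    outputs = []
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i:i + batch_size]  # the last batch may be smaller
        outputs.extend(generate_batch(batch))
    return outputs

# Toy usage with a stand-in "model" that appends an ellipsis to each prompt
results = generate_in_batches(
    ["a", "b", "c"],
    generate_batch=lambda batch: [p + " ..." for p in batch],
    batch_size=2,
)
```

Prompts are processed in order, so the i-th output always corresponds to the i-th prompt regardless of batch size.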
- generate_from_MC_aligned_model(lam_list, value_list, MC_nsamples=32, start_index=0, end_index=None)
Samples multiple continuations from each prompt using Monte Carlo sampling and lambda-weighted rewards.
- Parameters:
lam_list (Union[List[float], float]) – Lambda weights for aligning generation with specific rewards.
value_list (Union[List[str], str]) – Values to align the generated text with.
MC_nsamples (int, optional) – Number of Monte Carlo samples per prompt. Defaults to 32.
start_index (int, optional) – Start index of the prompts to process. Defaults to 0.
end_index (Optional[int], optional) – End index of the prompts to process. Defaults to None.
- Raises:
ValueError – If an unsupported data_name is provided.
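The selection step behind this method can be sketched as best-of-N sampling under a lambda-weighted reward: draw MC_nsamples candidate continuations per prompt, score each with a weighted sum of per-value rewards, and keep the highest-scoring one. The sampler and reward functions below are illustrative assumptions, not the module's actual models or reward heads.

```python
import random

# Sketch of Monte Carlo aligned sampling. sample_fn stands in for the
# base model's sampler; reward_fns stand in for per-value reward models.
def mc_aligned_sample(prompt, sample_fn, reward_fns, lam_list, MC_nsamples=32):
    candidates = [sample_fn(prompt) for _ in range(MC_nsamples)]

    def score(text):
        # Lambda-weighted sum of rewards for each aligned value
        return sum(lam * r(text) for lam, r in zip(lam_list, reward_fns))

    return max(candidates, key=score)

# Toy usage: the "reward" is text length, and a negative lambda
# penalizes it, so the shortest candidate wins.
random.seed(0)
best = mc_aligned_sample(
    "prompt",
    sample_fn=lambda p: p + " " + "x" * random.randint(1, 10),
    reward_fns=[len],
    lam_list=[-0.5],
    MC_nsamples=8,
)
```

Each lambda scales the influence of one value's reward on candidate selection, so the sign and magnitude of the entries in lam_list directly control which continuations survive the Monte Carlo filter.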