gendata

Classes

TextGeneration

Generates text continuations based on existing prompts, either using a pretrained model or an aligned model with Monte Carlo sampling.

Module Contents

class gendata.TextGeneration(basemodel_name, data_name, save_directory='results')

Generates text continuations based on existing prompts, either using a pretrained model or an aligned model with Monte Carlo sampling.

device

The device for model inference, typically ‘cuda’ if available.

Type:

str

basemodel_name

The name of the base model for text generation.

Type:

str

data_name

The source of the prompts, supporting “Anthropic-harmless” and “Imdb”.

Type:

str

file_path

Path to save the generated output in JSON format.

Type:

str

top_k

Number of highest-probability vocabulary tokens to keep for top-k filtering during generation.

Type:

int

max_new_tokens

Maximum number of new tokens to generate.

Type:

int

generation_config

Configuration settings for text generation.

Type:

GenerationConfig
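
A minimal sketch of how such a configuration is typically assembled from the top_k and max_new_tokens attributes, assuming the Hugging Face transformers GenerationConfig (the module's exact fields may differ):

from transformers import GenerationConfig

# Sketch only: sampling configuration mirroring the class defaults (top_k=50, max_new_tokens=50).
generation_config = GenerationConfig(
    do_sample=True,
    top_k=50,
    max_new_tokens=50,
)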

Example usage:

# Generate text directly from the original model
python gendata.py generate_from_original_model

# Generate text with Monte Carlo sampling from an aligned model
python gendata.py generate_from_MC_aligned_model --lam_list=-0.5 --value_list="humor" --MC_nsamples=50

device = 'cuda'
basemodel_name
data_name
file_path
top_k = 50
max_new_tokens = 50
generation_config
generate_from_original_model(batch_size=32)

Generates text continuations directly from the original model using the predefined generation configuration.

Parameters:

batch_size (int, optional) – Number of prompts processed per batch. Defaults to 32.

Raises:

ValueError – If an unsupported data_name is provided.
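
A hedged usage sketch based on the class signature above; the model name and batch size are illustrative, not values prescribed by the module:

from gendata import TextGeneration

# Sketch: generate continuations for the Anthropic-harmless prompts with a base model.
# "gpt2" is an illustrative model name, not necessarily one used by the project.
generator = TextGeneration(
    basemodel_name="gpt2",
    data_name="Anthropic-harmless",
    save_directory="results",
)
generator.generate_from_original_model(batch_size=16)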

generate_from_MC_aligned_model(lam_list, value_list, MC_nsamples=32, start_index=0, end_index=None)

Samples multiple continuations from each prompt using Monte Carlo sampling and lambda-weighted rewards.

Parameters:
  • lam_list (Union[List[float], float]) – Lambda weights for aligning generation with specific rewards.

  • value_list (Union[List[str], str]) – Values to align the generated text with.

  • MC_nsamples (int, optional) – Number of Monte Carlo samples per prompt. Defaults to 32.

  • start_index (int, optional) – Start index of the prompts to process. Defaults to 0.

  • end_index (Optional[int], optional) – End index of the prompts to process. Defaults to None.

Raises:

ValueError – If an unsupported data_name is provided.
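
A corresponding sketch for Monte Carlo sampling from an aligned model, mirroring the command-line example above; the base model name, lambda weight, and value are illustrative:

from gendata import TextGeneration

# Sketch: draw 50 Monte Carlo samples per prompt, aligning toward the "humor" value
# with a lambda weight of -0.5 (values mirror the CLI example; adjust as needed).
generator = TextGeneration(
    basemodel_name="gpt2",  # illustrative base model
    data_name="Anthropic-harmless",
    save_directory="results",
)
generator.generate_from_MC_aligned_model(
    lam_list=-0.5,
    value_list="humor",
    MC_nsamples=50,
    start_index=0,
    end_index=None,
)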