plot_reward_dist
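Attributes
file_path
Functions
plot_hist(json_file): Generate histograms for each reward type in a JSON file, visualizing reward distributions.
plot_weighted_unweighted_histograms(file_path, values_to_evaluate, values_to_align, lam, subplot_names, save_path): Plot weighted and unweighted histograms for reward distributions, showing the effects of MAP alignment.
plot_hist_positive(json_file): Plot a histogram for the 'positive' column in a JSON file, showing value distribution.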
Module Contents
- plot_reward_dist.plot_hist(json_file: str) → None
Generate histograms for each reward type in a JSON file, visualizing reward distributions.
This function reads a JSON file containing calculated rewards for various sentences, converts the data into a DataFrame, and plots histograms for each reward category. Each histogram represents the frequency distribution of rewards for a specific category.
- Parameters:
json_file (str) – Path to the JSON file containing reward data.
Example
>>> json_file = "results/opt1.3b-Anthropic-harmless.json"
>>> plot_hist(json_file)
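The read, convert, and plot flow described for plot_hist can be sketched as follows. This is a minimal illustration, not the module's actual source: the helper name sketch_plot_hist, the assumption that the JSON file holds a list of records with one numeric key per reward type, and the default bin count are all hypothetical.

```python
import json

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd


def sketch_plot_hist(json_file: str, bins: int = 30):
    """Plot one histogram per numeric column of a records-style JSON file.

    Hypothetical helper. Assumes the file is a list of dicts such as
    [{"text": "...", "humor": 0.1, "gpt2-helpful": -0.3}, ...].
    """
    with open(json_file) as f:
        df = pd.DataFrame(json.load(f))

    # Keep only numeric columns; text fields such as the sentence are skipped.
    reward_cols = df.select_dtypes(include="number").columns.tolist()
    if not reward_cols:
        raise ValueError("no numeric reward columns found")

    fig, axes = plt.subplots(
        1, len(reward_cols), figsize=(4 * len(reward_cols), 3), squeeze=False
    )
    for ax, col in zip(axes[0], reward_cols):
        ax.hist(df[col].dropna(), bins=bins)
        ax.set_title(col)
        ax.set_xlabel("reward")
        ax.set_ylabel("frequency")
    fig.tight_layout()
    return fig, reward_cols
```

Returning the figure instead of calling plt.show() leaves it to the caller to display or save the result.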
- plot_reward_dist.plot_weighted_unweighted_histograms(file_path: str, values_to_evaluate: list[str], values_to_align: list[str], lam: list[float], subplot_names: list[str], save_path: str) → None
Plot weighted and unweighted histograms for reward distributions, showing the effects of MAP alignment.
This function reads a JSON file with reward data, applies MAP alignment using specified lambda weights, and generates histograms for each reward type before and after alignment. The histograms allow comparison between original and MAP-aligned distributions for each reward type.
- Parameters:
file_path (str) – Path to the JSON file containing reward data.
values_to_evaluate (list[str]) – Reward types to evaluate and plot (e.g., ["humor", "gpt2-helpful"]).
values_to_align (list[str]) – Reward types to align based on the lambda weights.
lam (list[float]) – Lambda weights applied during MAP alignment; the example below passes one value per entry in values_to_align, in the same order.
subplot_names (list[str]) – Names to use for subplots, corresponding to each value in values_to_evaluate.
save_path (str) – Path to save the resulting PDF file of histograms.
Example
>>> file_path = "results/llama2_chat-Anthropic-harmless.json"
>>> values_to_evaluate = ["humor", "gpt2-helpful", "gpt2-harmless"]
>>> values_to_align = ["humor", "gpt2-helpful", "gpt2-harmless", "diversity", "coherence", "perplexity"]
>>> lam = [5.942, 2.432, 2.923, 0.006, 0.011, 0.147]
>>> subplot_names = ["Humor", "Helpfulness", "Harmlessness"]
>>> save_path = "results/fig_hist_llama2chat_80.pdf"
>>> plot_weighted_unweighted_histograms(file_path, values_to_evaluate, values_to_align, lam, subplot_names, save_path)
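To illustrate what comparing weighted and unweighted histograms can look like, the sketch below reweights each sample by exp(lam · r), normalizes the weights to sum to one, and bins the same values with and without those weights. The exponential form of the weights is an assumption about how the MAP alignment is applied here, not something this page states, and the names map_weights and weighted_and_unweighted_hist are hypothetical.

```python
import numpy as np


def map_weights(rewards: np.ndarray, lam: np.ndarray) -> np.ndarray:
    """Self-normalized weights proportional to exp(lam . r_j).

    rewards: shape (n_samples, n_aligned_values)
    lam:     shape (n_aligned_values,)
    Assumed form of the reweighting, not taken from the module's source.
    """
    logits = rewards @ lam
    logits = logits - logits.max()  # subtract the max for numerical stability
    w = np.exp(logits)
    return w / w.sum()


def weighted_and_unweighted_hist(values, weights, bins=20):
    """Return shared bin edges plus unweighted and weighted density histograms."""
    edges = np.histogram_bin_edges(values, bins=bins)
    unweighted, _ = np.histogram(values, bins=edges, density=True)
    weighted, _ = np.histogram(values, bins=edges, weights=weights, density=True)
    return edges, unweighted, weighted


# Synthetic data standing in for two reward columns.
rng = np.random.default_rng(0)
rewards = rng.normal(size=(500, 2))
lam = np.array([1.0, 0.0])  # tilt toward the first reward only

w = map_weights(rewards, lam)
edges, h_orig, h_aligned = weighted_and_unweighted_hist(rewards[:, 0], w)
```

Using the same bin edges for both histograms keeps the two curves directly comparable when they are drawn on one subplot.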
- plot_reward_dist.plot_hist_positive(json_file: str) → None
Plot a histogram for the ‘positive’ column in a JSON file, showing value distribution.
This function reads a JSON file with a ‘positive’ column, applies an exponential transformation to the values, and creates a histogram. It’s used to visualize the distribution of transformed ‘positive’ values across the dataset.
- Parameters:
json_file (str) – Path to the JSON file containing a ‘positive’ column.
Example
>>> json_file = "results/opt1.3b-positive_values.json"
>>> plot_hist_positive(json_file)
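The exponential transformation followed by binning can be sketched as below. This is an illustration under stated assumptions, not the module's source: the helper name exp_positive_histogram is hypothetical, and the file is assumed to be a list of records that each carry a numeric 'positive' key.

```python
import json

import numpy as np
import pandas as pd


def exp_positive_histogram(json_file: str, bins: int = 30):
    """Apply exp() to the 'positive' column and bin the transformed values.

    Hypothetical helper. Assumes a records-style JSON file such as
    [{"positive": -0.7}, {"positive": -1.2}, ...].
    """
    with open(json_file) as f:
        df = pd.DataFrame(json.load(f))

    # Exponential transformation of the raw 'positive' values.
    transformed = np.exp(df["positive"].to_numpy(dtype=float))

    # Histogram counts and bin edges; these can be passed to a plotting call.
    counts, edges = np.histogram(transformed, bins=bins)
    return transformed, counts, edges
```

The returned counts and edges can be drawn with any plotting library, or the transformed values can be handed directly to a histogram plotting function.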
- plot_reward_dist.file_path = 'results/llama2_chat-Anthropic-harmless.json'