# viz_arm_means_single_exp.py
# Visualize arm means for a single experiment.
import argparse
import numpy as np
import matplotlib.pyplot as plt
import os
import matplotlib.ticker as ticker
## Command-line interface
parser = argparse.ArgumentParser(
    description="Visualize arm means for a single experiment."
)
# Input locations (both mandatory).
parser.add_argument(
    "--prolific_results_folder", type=str, required=True,
    help="Path to the folder containing Prolific data.",
)
parser.add_argument(
    "--npz_file", type=str, required=True,
    help="Path to the .npz file containing arm means.",
)
# What to draw and how to frame it.
parser.add_argument(
    "--ci", action="store_true",
    help="Whether to plot confidence intervals or std.",
)
parser.add_argument(
    "--iter", type=int, default=None,
    help="Iteration to visualize (default = n_steps).",
)
parser.add_argument(
    "--ylim", type=float, nargs=2, default=None,
    help="Y-axis limits for the plot (optional).",
)
parser.add_argument(
    "--min-arm", type=int, required=True,
    help="Minimum arm value for the x-axis.",
)
parser.add_argument(
    "--bn-pos", type=int, default=0,
    help=" x-axis to place the label b=n",
)
parser.add_argument(
    "--b0-pos", type=int, default=0,
    help=" x-axis to place the label b=0",
)
parser.add_argument(
    "--nyticks", type=int, default=None,
    help="Number of y-ticks to display (optional).",
)
parser.add_argument(
    "--add-random-baseline", action="store_true",
    help="Whether to add a random baseline to the plot.",
)

## Reading arguments
args = parser.parse_args()
prolific_results_folder = args.prolific_results_folder
npz_file = args.npz_file
ci = args.ci
ITER = args.iter
ylim = args.ylim
nyticks = args.nyticks
random_baseline = args.add_random_baseline
min_arm = args.min_arm
# The two label anchors are clamped so that neither falls below min_arm.
b0_pos = max(args.b0_pos, min_arm)
bn_pos = max(args.bn_pos, min_arm)
## Creating paths and loading results
EXP_RES_PATH = os.path.join(prolific_results_folder, 'bandit', npz_file)
# Output folder = results path without its extension. splitext (rather than
# chopping four characters) keeps the name intact when the suffix is not
# exactly ".npz".
SAVE_PATH = os.path.splitext(EXP_RES_PATH)[0]
# Third to fifth dash-separated tokens of the results file name.
EXP_DETAILS = '-'.join(os.path.basename(EXP_RES_PATH).split('-')[2:5])

# Load results
# NOTE(review): allow_pickle=True lets the archive run arbitrary code while it
# is unpickled -- only point this script at result files you produced yourself.
results = np.load(EXP_RES_PATH, allow_pickle=True)

## Reading parameters from results
# Number of simulations and horizon stored with the experiment
N_SIMS = int(results['num_sims'])
N_STEPS = int(results['horizon'])

# Resolve the iteration to visualize: default to the horizon, reject values
# below 1 (they would give an empty slice or a wrap-around index further down),
# and cap values beyond the horizon.
if ITER is None:
    ITER = N_STEPS
elif ITER < 1:
    raise ValueError(f"--iter must be at least 1, got {ITER}.")
elif ITER > N_STEPS:
    print(f"Warning: ITER ({ITER}) is greater than N_STEPS ({N_STEPS}). Setting ITER to N_STEPS.")
    ITER = N_STEPS

# Statistics about the arms up to that iteration.
# 'arm_means' is indexed here as [simulation, iteration, arm]: axis 0 is the
# one averaged over, axis 1 is cut at ITER, and the last axis is cut at min_arm.
all_arm_means = results['arm_means'][:, :ITER]
arms_means = np.array(all_arm_means[:, :, min_arm:])
arm_means_avg = np.mean(arms_means, axis=0)  # iterations x arms: mean over axis 0
arm_means_std = np.std(arms_means, axis=0)   # iterations x arms: std over axis 0

# Column 0 on its own (the plotting code labels it "b=0").
arm_means_algorithm = np.array(all_arm_means[:, :, 0])
arm_means_algorithm_avg = np.mean(arm_means_algorithm, axis=0)  # one value per iteration
arm_means_algorithm_std = np.std(arm_means_algorithm, axis=0)   # one value per iteration

# Arm ids, keeping only those not lower than min_arm
arms = results['arms']
arms = arms[arms >= min_arm]
# The ids are filtered by value while the means above are sliced by position.
# The two only line up when every kept id corresponds to a kept column
# (presumably ids equal their column index -- confirm), so fail early with a
# clear message instead of an obscure shape error at plotting time.
if len(arms) != arms_means.shape[-1]:
    raise ValueError(
        f"{len(arms)} arm ids >= {min_arm} but {arms_means.shape[-1]} "
        f"arm-mean columns from index {min_arm}; cannot pair ids with means."
    )

# Optimal arm: highest simulation-averaged mean at the selected iteration
final_arm_means_avg_over_sim = arm_means_avg[ITER - 1]
id_opt_arm = np.argmax(final_arm_means_avg_over_sim)
opt_arm = arms[id_opt_arm]
opt_util = final_arm_means_avg_over_sim[id_opt_arm]
if random_baseline:
    # NOTE(review): the baseline archive name is hard-coded; it must exist
    # under <prolific_results_folder>/bandit for this branch to work.
    baseline_path = os.path.join(
        prolific_results_folder,
        'bandit',
        'bandit-results-h1000-s100-RANDOM_BASELINE_ALL-20250617-161344.npz',
    )
    random_results = np.load(baseline_path, allow_pickle=True)

    # Use ITER for the baseline too, unless the baseline's own horizon is shorter.
    baseline_horizon = random_results['horizon']
    if ITER is not None and ITER <= baseline_horizon:
        iter_rand = ITER
    else:
        iter_rand = baseline_horizon

    # Same statistics as for the main experiment: axis 0 is reduced, axis 1 is
    # cut at iter_rand, the last axis is cut at min_arm.
    baseline_means = random_results['arm_means'][:, :iter_rand]
    random_arm_means = np.array(baseline_means[:, :, min_arm:])
    random_arm_means_avg = random_arm_means.mean(axis=0)
    random_arm_means_std = random_arm_means.std(axis=0)

    # Column 0 on its own, reduced the same way.
    random_arm_means_algorithm = np.array(baseline_means[:, :, 0])
    random_arm_means_algorithm_avg = random_arm_means_algorithm.mean(axis=0)
    random_arm_means_algorithm_std = random_arm_means_algorithm.std(axis=0)
##*#################
##* PLOTTING FIGURE
##*#################
def _band_halfwidth(std, n_sims, use_ci):
    """Return the half-width of the shaded band around a mean curve.

    With ``use_ci`` the band is the normal-approximation 95% confidence
    interval of the mean (1.96 * std / sqrt(n_sims)); otherwise it is the
    plain standard deviation.
    """
    if use_ci:
        return 1.96 * std / np.sqrt(n_sims)
    return std


def _default_ytick_count(y_low, y_high):
    """Pick a y-tick count from the axis span when the user gave none.

    One tick per 0.01 of span (both ends included), halved once that exceeds
    eight ticks, and never fewer than two so both limits get a tick.
    """
    count = int(np.round(abs(y_high - y_low) * 100)) + 1
    if count > 8:
        count = int(np.ceil(count / 2))
    return max(count, 2)


cmap = plt.get_cmap('Set2')
# NOTE(review): 'science' is not a built-in matplotlib style; presumably it is
# provided by the SciencePlots package -- confirm it is installed/registered.
plt.style.use('science')
f, ax = plt.subplots(figsize=(5, 3))
tt = ITER - 1  # zero-based row of the iteration being visualized

# Mean reward per arm with its uncertainty band. The band honors --ci; the
# flag used to be overwritten by the band array itself and so had no effect.
mean_at_tt = arm_means_avg[tt, :]
band = _band_halfwidth(arm_means_std[tt, :], N_SIMS, ci)
ax.plot(arms, mean_at_tt, marker='o', color=cmap(0), label='Mean Reward', alpha=1)
ax.fill_between(arms, mean_at_tt - band, mean_at_tt + band,
                color=cmap(0), alpha=0.2, linewidth=0)

# Highlight the best arm and point an arrow at it.
ax.scatter(opt_arm, opt_util, color=cmap(4), zorder=5, s=70, edgecolors='black', alpha=1)
ax.annotate("Optimal",
            xy=(opt_arm - 0.2, opt_util + 0.0005),
            xytext=(opt_arm - 3, opt_util + 0.005),
            arrowprops=dict(facecolor='black', width=0.5, headwidth=4, headlength=4),
            fontsize=12, color='black', ha='center')

# Horizontal reference lines: column 0 ("b=0") and the last plotted arm ("b=n").
ax.axhline(arm_means_algorithm_avg[tt], color=cmap(1), linestyle='--', zorder=0)
ax.text(b0_pos, arm_means_algorithm_avg[tt] + 0.0005, "$b=0$",
        color='black', fontsize=12, ha='left', va='bottom')
ax.axhline(arm_means_avg[tt, -1], color=cmap(1), linestyle='--', zorder=0)
ax.text(bn_pos, arm_means_avg[tt, -1] + 0.0005, "$b=n$",
        color='black', fontsize=12, ha='left', va='bottom')

if random_baseline:
    # The baseline may hold fewer iterations than the experiment, so clamp the
    # row index instead of reusing tt blindly (which could run out of bounds).
    tt_rand = min(tt, random_arm_means_avg.shape[0] - 1)
    rand_mean = random_arm_means_avg[tt_rand, :]
    rand_band = _band_halfwidth(random_arm_means_std[tt_rand, :],
                                int(random_results['num_sims']), ci)
    ax.plot(arms, rand_mean, marker='o', color=cmap(2), label='Random Baseline', alpha=.4)
    ax.fill_between(arms, rand_mean - rand_band, rand_mean + rand_band,
                    color=cmap(2), alpha=0.1, linewidth=0)

if ylim is not None:
    ax.set_ylim(ylim[0], ylim[1])
    # Fall back to a span-derived tick count when --nyticks is omitted;
    # previously the computed count was discarded and linspace got num=None.
    n_yticks = nyticks if nyticks is not None else _default_ytick_count(ylim[0], ylim[1])
    yticks = np.linspace(ylim[0], ylim[1], num=n_yticks)
    ax.set_yticks(yticks)

ax.set_xlabel("$b$", fontsize=14)
ax.set_ylabel("Expected Utility", fontsize=14)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# One major x tick per unit, no minor x ticks.
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.xaxis.set_minor_locator(ticker.NullLocator())
ax.tick_params(axis='x', which='major', bottom=True, length=5, labelsize=14)
ax.tick_params(top=False, right=False, which='both')
ax.tick_params(axis='y', which='major', left=True, length=6, labelsize=14)
ax.tick_params(axis='y', which='minor', left=True, length=2, labelsize=0)  # labelsize=0 hides minor labels
plt.tight_layout()
## Saving the figure
# exist_ok avoids the check-then-create race (and the error it would raise)
# when the output folder already exists or appears between the two calls.
os.makedirs(SAVE_PATH, exist_ok=True)
# The file name records the iteration that was visualized.
plt.savefig(os.path.join(SAVE_PATH, f'arms-means-h{ITER}.pdf'),
            bbox_inches='tight', pad_inches=0.1, dpi=300)