# This module implements sampling from generative models and post-processing of datasets
import pandas as pd
import numpy as np
import torch
import random
from tools.generative_model_estimation import fit_model

def round_func(val, data_info):
    """Rounds a generated value to the required number of decimal places
    :param val: value
    :param data_info: number of decimal places to round to, or None to skip rounding
    :return val: rounded value
    """
    if data_info is not None:
        val = round(val, data_info)
    return val

def lim_func(val, data_info_high, data_info_low):
    """Clips a generated value to the given limits
    :param val: value
    :param data_info_high: upper limit, or None for no upper limit
    :param data_info_low: lower limit, or None for no lower limit
    :return val: clipped value
    """
    if data_info_high is not None and val > data_info_high:
        val = data_info_high
    if data_info_low is not None and val < data_info_low:
        val = data_info_low
    return val

def post_processing(sampled_data, data):
    """Applies limiting and rounding to every column of a synthetic dataset
    :param sampled_data: synthetic dataset
    :param data: sampling parameters with per-column "lim" and "round" entries
    :return sampled_data: post-processed synthetic dataset
    """
    for col in sampled_data.columns:
        sampled_data[col] = sampled_data[col].apply(lim_func, args=(data["lim"][col]["high"], data["lim"][col]["low"]))
        sampled_data[col] = sampled_data[col].apply(round_func, args=[data["round"][col]])
    return sampled_data
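
# Illustrative sketch of the expected sampling-parameter structure (the column
# name and values below are assumptions, not part of the pipeline): a dataset
# with a single column "age" clipped to [18, 90] and rounded to integers would
# use
#   data = {"lim": {"age": {"high": 90, "low": 18}},
#           "round": {"age": 0}}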

def simple_sample_procedure(model, sample_len, seed_list, cols, scaling, data):
    """Simple sampling function
    :param model: generative model
    :param sample_len: number of samples in synthetic dataset
    :param seed_list: list of seeds for sampling
    :param cols: dataset columns
    :param scaling: scaler for synthetic dataset
    :param data: sampling parameters
    :return sampled_data_list: list of synthetic datasets
    """
    sampled_data_list = []
    for seed in seed_list:
        sampled_data = model.sample(n_samples=sample_len, random_state=seed)
        if scaling:
            sampled_data = scaling.inverse_transform(sampled_data)
        sampled_data = pd.DataFrame(sampled_data, columns=cols)
        sampled_data = post_processing(sampled_data, data)
        sampled_data_list.append(sampled_data)
    return sampled_data_list

def sample_stats(kde, size, seed):
    """Sampling from statsmodels' KDE
    :param kde: KDE
    :param size: number of samples in synthetic dataset
    :param seed: seed
    :return sampled_data: generated data
    """
    # Smoothed bootstrap: draw training points at random, then add Gaussian
    # noise with a diagonal covariance built from the per-dimension bandwidths
    rng = np.random.RandomState(seed)
    n, d = kde.data.shape
    indices = rng.randint(0, n, size)
    cov = np.diag(kde.bw) ** 2
    means = kde.data[indices, :]
    norm = rng.multivariate_normal(np.zeros(d), cov, size)
    sampled_data = means + norm
    return sampled_data

def simple_sample_stats_procedure(model, sample_len, seed_list, cols, scaling, data):
    """Sampling synthetic datasets from statsmodels' KDE
    :param model: generative model
    :param sample_len: number of samples in synthetic dataset
    :param seed_list: list of seeds for sampling
    :param cols: dataset columns
    :param scaling: scaler for synthetic dataset
    :param data: sampling parameters
    :return sampled_data_list: list of synthetic datasets
    """
    sampled_data_list = []
    for seed in seed_list:
        sampled_data = sample_stats(model, sample_len, seed)
        if scaling:
            sampled_data = scaling.inverse_transform(sampled_data)
        sampled_data = pd.DataFrame(sampled_data, columns=cols)
        if data:
            sampled_data = post_processing(sampled_data, data)
        sampled_data_list.append(sampled_data)
    return sampled_data_list

def gmm_sample_procedure(model, sample_len, cols, scaling, num_samples, data):
    """Sampling from GMM
    :param model: generative model
    :param sample_len: number of samples in synthetic dataset
    :param cols: dataset columns
    :param scaling: scaler for synthetic dataset
    :param num_samples: number of synthetic datasets
    :param data: sampling parameters
    :return sampled_data_list: list of synthetic datasets
    """
    sampled_data_list = []
    # Draw all samples in one call and split them into num_samples datasets
    all_samples = model.sample(sample_len * num_samples)[0]
    for i in range(num_samples):
        sampled_data = all_samples[(i * sample_len):((i + 1) * sample_len)]
        if scaling:
            sampled_data = scaling.inverse_transform(sampled_data)
        sampled_data = pd.DataFrame(sampled_data, columns=cols)
        sampled_data = post_processing(sampled_data, data)
        sampled_data_list.append(sampled_data)
    return sampled_data_list
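
# Note: unlike the seed-based procedures above, the GMM path draws one large
# batch and slices it into equally sized, non-overlapping datasets, presumably
# because scikit-learn's GaussianMixture.sample does not accept a per-call
# random seed; reproducibility here comes from the fitted estimator itself.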

def sample_sdv_procedure(model, sample_len, seed_list, cols, scaling, data):
    """Sampling from SDV library model
    :param model: generative model
    :param sample_len: number of samples in synthetic dataset
    :param seed_list: list of seeds for sampling
    :param cols: dataset columns
    :param scaling: scaler for synthetic dataset
    :param data: sampling parameters
    :return sampled_data_list: list of synthetic datasets
    """
    sampled_data_list = []
    for seed in seed_list:
        # Seed NumPy and PyTorch globally so that model.sample(...) is reproducible
        np.random.seed(seed)
        torch.manual_seed(seed)
        sampled_data = model.sample(sample_len)
        if scaling:
            sampled_data = scaling.inverse_transform(sampled_data)
        sampled_data = pd.DataFrame(sampled_data, columns=cols)
        sampled_data = post_processing(sampled_data, data)
        sampled_data_list.append(sampled_data)
    return sampled_data_list

def get_sampled_data(model, sample_len, seed_list, method, cols, scaling, data):
    """Calls the sampling function matching the generative model name
    :param model: generative model
    :param sample_len: number of samples in synthetic dataset
    :param seed_list: list of seeds for sampling
    :param method: generative model name
    :param cols: dataset columns
    :param scaling: scaler for synthetic dataset
    :param data: sampling parameters
    :return sampled_data_list: list of synthetic datasets
    """
    if method in ["sklearn_kde", "awkde"]:
        sampled_data_list = simple_sample_procedure(model, sample_len, seed_list, cols, scaling, data)
    elif method in ["kde_cv_ml", "kde_cv_ls"]:
        sampled_data_list = simple_sample_stats_procedure(model, sample_len, seed_list, cols, scaling, data)
    elif method in ["gmm", "bayesian_gmm"]:
        sampled_data_list = gmm_sample_procedure(model, sample_len, cols, scaling, len(seed_list), data)
    elif method in ["ctgan", "copula", "copulagan", "tvae"]:
        sampled_data_list = sample_sdv_procedure(model, sample_len, seed_list, cols, scaling, data)
    else:
        raise ValueError(f"Unknown generative model name: {method}")
    return sampled_data_list

def get_sample_for_pairwise_plot(gen_model, ind_data, ds):
    """Chooses the generative model run with the lowest ROC AUC and samples a dataset from it
    :param gen_model: generative algorithm name
    :param ind_data: dataset similarity indicators for the generative algorithm
    :param ds: dataset parameters
    :return sampled_data: synthetic dataset
    """
    # Find the iteration whose classifier two-sample test ROC AUC is lowest
    best_iter = 0
    min_roc_auc = 1
    for iter_num in ind_data:
        min_roc_auc_iter = min(ind_data[iter_num]["c2st_roc_auc"])
        if min_roc_auc_iter < min_roc_auc:
            min_roc_auc = min_roc_auc_iter
            best_iter = iter_num
    sample_num = np.argmin(ind_data[best_iter]["c2st_roc_auc"])
    # Reproduce the seeds used in the experiment: the first 50 serve as fitting
    # (CV) seeds, the full list as sampling seeds
    random.seed(42)
    seed_val = random.sample(list(range(100000)), 100)
    seed_val_cv = seed_val[:50][best_iter]
    seed_val_sampling = seed_val[sample_num]
    fitted_model = fit_model(gen_model, ds["data_scaled"], seed_val_cv)
    sampled_data = get_sampled_data(fitted_model, ds["len"], [seed_val_sampling], gen_model, ds["cols"], ds["scaler"], ds)[0]
    return sampled_data
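
# A minimal usage sketch (illustrative only): the column name, limits, rounding
# and the choice of a GMM below are assumptions, not part of the pipeline, and
# fit_model is assumed to accept the same model names as get_sampled_data.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    demo = pd.DataFrame({"x": rng.normal(size=500)})
    params = {"lim": {"x": {"high": 3.0, "low": -3.0}}, "round": {"x": 2}}
    fitted = fit_model("gmm", demo.values, 42)
    synthetic_sets = get_sampled_data(fitted, sample_len=len(demo), seed_list=[0, 1],
                                      method="gmm", cols=demo.columns, scaling=None,
                                      data=params)
    print(synthetic_sets[0].head())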