Skip to content

Commit 6527a0a

Browse files
committed
Fix dimension handling in WasserstainLIME2
1 parent dc70ea0 commit 6527a0a

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

xwhy/smile_tabular.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,36 +28,41 @@ def WasserstainLIME2(X_input, model, num_perturb = 500, L_num_perturb = 100, ker
2828

2929
X_input = (X_input - np.mean(X_input,axis=0)) / np.std(X_input,axis=0) #Standarization of data
3030

31-
X_lime = np.random.normal(0,1,size=(num_perturb,X_input.shape[0]))
32-
33-
Xi2 = np.zeros((L_num_perturb,X_input.shape[0]))
34-
35-
for jj in range(X_input.shape[0]):
36-
Xi2[:,jj] = X_input[jj] + np.random.normal(0,0.05,L_num_perturb)
31+
# number of features for the single input instance
32+
n_features = X_input.shape[1]
33+
34+
# generate random perturbations around the standardized input
35+
X_lime = np.random.normal(0, 1, size=(num_perturb, n_features))
36+
37+
# create local perturbations for computing the Wasserstein distances
38+
Xi2 = np.zeros((L_num_perturb, n_features))
39+
40+
for jj in range(n_features):
41+
Xi2[:, jj] = X_input[0, jj] + np.random.normal(0, 0.05, L_num_perturb)
3742

3843
y_lime2 = np.zeros((num_perturb,1))
3944
WD = np.zeros((num_perturb,1))
4045
weights2 = np.zeros((num_perturb,1))
4146

4247
for ind, ii in enumerate(X_lime):
43-
48+
4449
df2 = pd.DataFrame()
45-
46-
for jj in range(X_input.shape[0]):
47-
temp1 = ii[jj] + np.random.normal(0,0.3,L_num_perturb)
50+
51+
for jj in range(n_features):
52+
temp1 = ii[jj] + np.random.normal(0, 0.3, L_num_perturb)
4853
df2[len(df2.columns)] = temp1
4954

5055
temp3 = model.predict(df2.to_numpy())
5156

5257
y_lime2[ind] = np.mean(temp3) # For classification: np.argmax(np.bincount(temp3))
5358

54-
WD1 = np.zeros((X_input.shape[0],1))
59+
WD1 = np.zeros((n_features, 1))
5560

5661
df2 = df2.to_numpy()
57-
58-
for kk in range(X_input.shape[0]):
62+
63+
for kk in range(n_features):
5964
#print( df2.shape)
60-
WD1[kk] = Wasserstein_Dist(Xi2[:,kk], df2[:,kk])
65+
WD1[kk] = Wasserstein_Dist(Xi2[:, kk], df2[:, kk])
6166

6267
#print(WD1)
6368
#print(ind)

0 commit comments

Comments
 (0)