Skip to content

Commit c1927a4

Browse files
committed
resolving bug of get_selected_index
1 parent d9b8463 commit c1927a4

File tree

3 files changed

+24
-20
lines changed

3 files changed

+24
-20
lines changed

documentation.ipynb

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@
133133
"output_type": "stream",
134134
"name": "stdout",
135135
"text": [
136-
"\nTraining started for batch: 1\n- number of soft clauses: 93\n- number of Boolean variables: 157\n- number of hard and soft clauses: 863\n\n\nBatch tarining complete\n- number of literals in the rule: 2\n- number of training errors: 3 out of 49\n\nTraining started for batch: 2\n- number of soft clauses: 95\n- number of Boolean variables: 161\n- number of hard and soft clauses: 890\n\n\nBatch tarining complete\n- number of literals in the rule: 4\n- number of training errors: 1 out of 51\n"
136+
"(100, 11)\n\nTraining started for batch: 1\n- number of soft clauses: 93\n- number of Boolean variables: 157\n- number of hard and soft clauses: 863\n\n\nBatch tarining complete\n- number of literals in the rule: 2\n- number of training errors: 3 out of 49\n\nTraining started for batch: 2\n- number of soft clauses: 95\n- number of Boolean variables: 161\n- number of hard and soft clauses: 890\n\n\nBatch tarining complete\n- number of literals in the rule: 4\n- number of training errors: 1 out of 51\n(100, 11)\n"
137137
]
138138
}
139139
],
@@ -150,7 +150,7 @@
150150
},
151151
{
152152
"cell_type": "code",
153-
"execution_count": 6,
153+
"execution_count": 8,
154154
"metadata": {
155155
"tags": []
156156
},
@@ -168,8 +168,7 @@
168168
"print(classification_report(y_train, model.predict(X_train), target_names=['0','1']))\n",
169169
"print()\n",
170170
"print(\"test report: \")\n",
171-
"print(classification_report(y_test, model.predict(X_test), target_names=['0','1']))\n",
172-
"\n"
171+
"print(classification_report(y_test, model.predict(X_test), target_names=['0','1']))\n"
173172
]
174173
},
175174
{
@@ -181,7 +180,7 @@
181180
},
182181
{
183182
"cell_type": "code",
184-
"execution_count": 7,
183+
"execution_count": 9,
185184
"metadata": {
186185
"tags": []
187186
},
@@ -212,7 +211,7 @@
212211
},
213212
{
214213
"cell_type": "code",
215-
"execution_count": 8,
214+
"execution_count": 10,
216215
"metadata": {
217216
"tags": []
218217
},
@@ -223,7 +222,7 @@
223222
},
224223
{
225224
"cell_type": "code",
226-
"execution_count": 9,
225+
"execution_count": 13,
227226
"metadata": {
228227
"tags": []
229228
},
@@ -232,7 +231,7 @@
232231
"output_type": "stream",
233232
"name": "stdout",
234233
"text": [
235-
"training report: \n precision recall f1-score support\n\n 0 0.93 0.98 0.96 65\n 1 0.97 0.86 0.91 35\n\n accuracy 0.94 100\n macro avg 0.95 0.92 0.93 100\nweighted avg 0.94 0.94 0.94 100\n\n\ntest report: \n precision recall f1-score support\n\n 0 0.97 1.00 0.99 35\n 1 1.00 0.93 0.97 15\n\n accuracy 0.98 50\n macro avg 0.99 0.97 0.98 50\nweighted avg 0.98 0.98 0.98 50\n\n\n( not sepal length >= 7.05 AND petal width = (0.8 - 1.75) ) OR \n( petal length = (2.45 - 4.75))\n\n[[-2, 9], [6]]\n"
234+
"training report: \n precision recall f1-score support\n\n 0 0.93 0.98 0.96 65\n 1 0.97 0.86 0.91 35\n\n accuracy 0.94 100\n macro avg 0.95 0.92 0.93 100\nweighted avg 0.94 0.94 0.94 100\n\n\ntest report: \n precision recall f1-score support\n\n 0 0.97 1.00 0.99 35\n 1 1.00 0.93 0.97 15\n\n accuracy 0.98 50\n macro avg 0.99 0.97 0.98 50\nweighted avg 0.98 0.98 0.98 50\n\n\nRule:->\n( not sepal length >= 7.05 AND petal width = (0.8 - 1.75) ) OR \n( petal length = (2.45 - 4.75))\n\nOriginal features:\n['sepal length < 5.45', 'sepal length = (5.45 - 7.05)', 'sepal length >= 7.05', 'sepal width < 2.95', 'sepal width >= 2.95', 'petal length < 2.45', 'petal length = (2.45 - 4.75)', 'petal length >= 4.75', 'petal width < 0.8', 'petal width = (0.8 - 1.75)', 'petal width >= 1.75']\n\nIn the learned rule, show original index in the feature list with phase (1: original, -1: complemented)\n[[(2, -1), (9, 1)], [(6, 1)]]\n"
236235
]
237236
}
238237
],
@@ -246,9 +245,11 @@
246245
"print(\"test report: \")\n",
247246
"print(classification_report(y_test, model.predict(X_test), target_names=['0','1']))\n",
248247
"\n",
249-
"print()\n",
248+
"print(\"\\nRule:->\")\n",
250249
"print(model.get_rule(features))\n",
251-
"print()\n",
250+
"print(\"\\nOriginal features:\")\n",
251+
"print(features)\n",
252+
"print(\"\\nIn the learned rule, show original index in the feature list with phase (1: original, -1: complemented)\")\n",
252253
"print(model.get_selected_column_index())"
253254
]
254255
},
@@ -266,7 +267,7 @@
266267
},
267268
{
268269
"cell_type": "code",
269-
"execution_count": 10,
270+
"execution_count": 14,
270271
"metadata": {},
271272
"outputs": [],
272273
"source": [
@@ -275,7 +276,7 @@
275276
},
276277
{
277278
"cell_type": "code",
278-
"execution_count": 11,
279+
"execution_count": 15,
279280
"metadata": {
280281
"tags": []
281282
},
@@ -284,7 +285,7 @@
284285
"output_type": "stream",
285286
"name": "stdout",
286287
"text": [
287-
"training report: \n precision recall f1-score support\n\n 0 0.96 0.98 0.97 65\n 1 0.97 0.91 0.94 35\n\n accuracy 0.96 100\n macro avg 0.96 0.95 0.96 100\nweighted avg 0.96 0.96 0.96 100\n\n\ntest report: \n precision recall f1-score support\n\n 0 0.97 1.00 0.99 35\n 1 1.00 0.93 0.97 15\n\n accuracy 0.98 50\n macro avg 0.99 0.97 0.98 50\nweighted avg 0.98 0.98 0.98 50\n\n"
288+
"training report: \n precision recall f1-score support\n\n 0 0.91 0.31 0.46 65\n 1 0.42 0.94 0.58 35\n\n accuracy 0.53 100\n macro avg 0.67 0.63 0.52 100\nweighted avg 0.74 0.53 0.50 100\n\n\ntest report: \n precision recall f1-score support\n\n 0 0.90 0.26 0.40 35\n 1 0.35 0.93 0.51 15\n\n accuracy 0.46 50\n macro avg 0.62 0.60 0.45 50\nweighted avg 0.73 0.46 0.43 50\n\n"
288289
]
289290
}
290291
],
@@ -310,7 +311,7 @@
310311
},
311312
{
312313
"cell_type": "code",
313-
"execution_count": 12,
314+
"execution_count": 16,
314315
"metadata": {
315316
"tags": []
316317
},
@@ -319,7 +320,7 @@
319320
"output_type": "stream",
320321
"name": "stdout",
321322
"text": [
322-
"Learned rule is: \n\nAn Iris flower is predicted as Iris Versicolor if\n[ ( sepal width >= 2.95 + petal length = (2.45 - 4.75) + petal width = (0.8 - 1.75) + not sepal length >= 7.05 )>= 3 ] +\n[ ( )>= 0 ] >= 2\n\nThrehosld on clause: 2\nThreshold on literals: (this is a list where the entries denote threholds on literals on all clauses)\n[3, 0]\n"
323+
"Learned rule is: \n\nAn Iris flower is predicted as Iris Versicolor if\n[ ( petal width = (0.8 - 1.75) + not sepal length >= 7.05 )>= 1 ] +\n[ ( sepal width >= 2.95 + petal length = (2.45 - 4.75) )>= 1 ] >= 2\n\nThrehosld on clause: 2\nThreshold on literals: (this is a list where the entries denote threholds on literals on all clauses)\n[1, 1]\n"
323324
]
324325
}
325326
],
@@ -329,7 +330,7 @@
329330
"print(\"An Iris flower is predicted as Iris Versicolor if\")\n",
330331
"print(rule)\n",
331332
"print(\"\\nThrehosld on clause:\", model.get_threshold_clause())\n",
332-
"print(\"Threshold on literals: (this is a list where the entries denote threholds on literals on all clauses)\")\n",
333+
"print(\"Threshold on literals: (this is a list where entries denote threholds on literals on all clauses)\")\n",
333334
"print(model.get_threshold_literal())"
334335
]
335336
},

pyrulelearn/imli.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,14 @@ def get_selected_column_index(self):
134134
for index_list in temp:
135135
each_level_index = []
136136
for index in index_list:
137+
phase = 1
137138
actual_feature_len = int(self.numFeatures/2)
138139
if(self.ruleType == "DNF"):
139140
index = index - actual_feature_len if index >= actual_feature_len else index + actual_feature_len
140141
if(index >= actual_feature_len):
141-
index = -1 * (index - actual_feature_len)
142-
each_level_index.append(index)
142+
index = index - actual_feature_len
143+
phase = -1
144+
each_level_index.append((index, phase))
143145
result.append(each_level_index)
144146

145147
return result
@@ -775,6 +777,7 @@ def __generateWcnfFile(self, AMatrix, yVector, xSize, WCNFFile,
775777
def get_rule(self, features):
776778

777779
if(2 * len(features) == self.numFeatures):
780+
features = [str(feature) for feature in features]
778781
features += ["not " + str(feature) for feature in features]
779782

780783
assert len(features) == self.numFeatures

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
setup(
1111
name = 'pyrulelearn',
1212
packages = ['pyrulelearn'],
13-
version = 'v1.0.8',
13+
version = 'v1.0.9',
1414
license='MIT',
1515
description = 'This library can be used to generate interpretable classification rules expressed as CNF/DNF and relaxed-CNF',
1616
long_description=long_description,
1717
long_description_content_type='text/markdown',
1818
author = 'Bishwamittra Ghosh',
1919
author_email = 'bishwamittra.ghosh@gmail.com',
2020
url = 'https://github.com/meelgroup/MLIC',
21-
download_url = 'https://github.com/meelgroup/MLIC/archive/v1.0.7.tar.gz',
21+
download_url = 'https://github.com/meelgroup/MLIC/archive/v1.0.8.tar.gz',
2222
keywords = ['Classification Rules', 'Interpretable Rules', 'CNF Classification Rules', 'DNF Classification Rules','MaxSAT-based Rule Learning'], # Keywords that define your package best
2323
classifiers=[
2424
'Development Status :: 3 - Alpha',

0 commit comments

Comments
 (0)