Skip to content

Commit 7d73473

Browse files
authored
KernelShap Refactoring (#207)
* change: rank_by_importance is now a standalone function * update api and defaults names * Delete tree_shap.py Added this file by mistake. * change: updated categorical_names type signature in KernelShap constructor * doc: improved documentation to KernelShap constructor * change: updated background dataset default size and documentation * feat: added model type argument to KernelShap argument * updated doc to _update_metadata * change: categorical_names and feature_names default to the types that are returned if no user input is provided as opposed to None. * change: ensure expected_value passed to build_explanation is always a list * change: renamed model_type argument to task * doc: fixed docs typos * fix: typos in logging code * change: changed implementation of sum_categories and extended it to work with 3D * change: refactored KernelShap so that output summarisation happens in build_explanation. Factored out summarisation warnings in a separate method * change: improved sum_categories readability and simplified testing code * change: metadata update for result summarisation moved to build_explanation * doc: improved docs to build_explanation * change: updated imports and function names in KernelShap examples * change: updated algorithm overview for KernelShap * fix: linting errors * change: implemented review suggestions * docs: improved documentation to the explain method; changed section ordering in KernelShap docs and fixed typos * change: modified file names for KernelShap implementation to prepare for TreeShap addition * change: Renamed tests that will also be implemented for TreeExplainer * change: Updated fixture names * change: Improved clarity in Theoretical Overview section * change: Fixed redundant warning when summarise_result=False * doc: Improved clarity in Theoretical Overview section * fix: Removed stray files from PR * change: Improved typing for categorical variables start indices and dimensions * Update source in KernelShap documentaion * 
doc: Improved docstrings and fixed typos in docs * fix: Bug in expected value casting to array * fix: Escaped _ in constructor docstring
1 parent 99850c6 commit 7d73473

File tree

11 files changed

+774
-562
lines changed

11 files changed

+774
-562
lines changed

alibi/api/defaults.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,22 +89,23 @@
8989
"""
9090

9191
# KernelSHAP
92-
DEFAULT_META_SHAP = {
92+
DEFAULT_META_KERNEL_SHAP = {
9393
"name": None,
9494
"type": ["blackbox"],
95+
"task": None,
9596
"explanations": ["local", "global"],
9697
"params": {}
9798
} # type: dict
9899
"""
99100
Default KernelSHAP metadata.
100101
"""
101102

102-
DEFAULT_DATA_SHAP = {
103+
DEFAULT_DATA_KERNEL_SHAP = {
103104
"shap_values": [],
104105
"expected_value": [],
105106
"link": 'identity',
106-
"categorical_names": None,
107-
"feature_names": None,
107+
"categorical_names": {},
108+
"feature_names": [],
108109
"raw": {
109110
"raw_prediction": None,
110111
"prediction": None,

alibi/explainers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .cem import CEM
99
from .cfproto import CounterFactualProto
1010
from .counterfactual import CounterFactual
11-
from .kernel_shap import KernelShap
11+
from .shap_wrappers import KernelShap
1212

1313
__all__ = ["AnchorTabular",
1414
"DistributedAnchorTabular",

alibi/explainers/kernel_shap.py renamed to alibi/explainers/shap_wrappers.py

Lines changed: 379 additions & 248 deletions
Large diffs are not rendered by default.

alibi/explainers/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def at_adult_explainer(get_adult_dataset, rf_classifier, request):
179179

180180

181181
@pytest.fixture
182-
def mock_ks_explainer(request):
182+
def mock_kernel_shap_explainer(request):
183183
"""
184184
Instantiates a KernelShap explainer with a mock predictor.
185185
"""

alibi/explainers/tests/test_kernel_shap.py renamed to alibi/explainers/tests/test_shap_wrappers.py

Lines changed: 198 additions & 122 deletions
Large diffs are not rendered by default.

doc/source/methods/KernelSHAP.ipynb

Lines changed: 131 additions & 125 deletions
Large diffs are not rendered by default.

doc/source/overview/algorithms.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ instance that would result in a different prediction). [Documentation](../method
4444

4545
**Counterfactual instances**: generate counterfactual examples using a simple loss function. [Documentation](../methods/CF.ipynb), [image classification](../examples/cf_mnist.ipynb).
4646

47-
**Kernel Shapley Additive Explanation (SHAP)**: attribute the change of a model output with respect to a given baseline (e.g., average over a training set) to each of the model features. This is achieved for each feature in turn, by averaging the difference in the model output observed when excluding a feature from the input. The exclusion of a feature is achieved by replacing it with values from the background dataset. [Documentation](../methods/KernelSHAP.ipynb), [continuous data](../examples/kernel_shap_wine_intro.ipynb), [more continous_data](../examples/kernel_shap_wine_lr.ipynb), [categorical data](../examples/kernel_shap_adult_lr.ipynb).
47+
**Kernel Shapley Additive Explanation (KernelSHAP)**: attribute the change of a model output with respect to a given baseline (e.g., average over a training set) to each of the input features. This is achieved for each feature in turn, by averaging the difference in the model output observed when the feature whose contribution is to be estimated is part of a group of "present" input features and the value observed when the feature is excluded from said group. The features that are not "present" (i.e., are missing) are replaced with values from a background dataset. This algorithm can be used to explain regression models. [Documentation](../methods/KernelSHAP.ipynb), [continuous data](../examples/kernel_shap_wine_intro.ipynb), [more continuous data](../examples/kernel_shap_wine_lr.ipynb), [categorical data](../examples/kernel_shap_adult_lr.ipynb).
4848

4949
**Prototype Counterfactuals**: generate counterfactuals guided by nearest class prototypes other than the class predicted on the original instance. It can use both an encoder or k-d trees to define the prototypes. This method can speed up the search, especially for black box models, and create interpretable counterfactuals. [Documentation](../methods/CFProto.ipynb), [tabular example](../examples/cfproto_housing.nblink), [tabular example with categorical features](../examples/cfproto_cat_adult_ohe.ipynb), [image classification](../examples/cfproto_mnist.ipynb).
5050

examples/kernel_shap_adult_categorical_preproc.ipynb

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,8 @@
7676
"import matplotlib.pyplot as plt\n",
7777
"import numpy as np\n",
7878
"import pandas as pd\n",
79-
"import seaborn as sns\n",
8079
"\n",
81-
"from alibi.explainers.kernel_shap import KernelShap\n",
80+
"from alibi.explainers import KernelShap\n",
8281
"from alibi.datasets import fetch_adult\n",
8382
"from scipy.special import logit\n",
8483
"from sklearn.compose import ColumnTransformer\n",
@@ -969,7 +968,7 @@
969968
"name": "python",
970969
"nbconvert_exporter": "python",
971970
"pygments_lexer": "ipython3",
972-
"version": "3.7.3"
971+
"version": "3.7.6"
973972
}
974973
},
975974
"nbformat": 4,

examples/kernel_shap_adult_lr.ipynb

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,8 @@
7676
"import matplotlib.pyplot as plt\n",
7777
"import numpy as np\n",
7878
"import pandas as pd\n",
79-
"import seaborn as sns\n",
8079
"\n",
81-
"from alibi.explainers.kernel_shap import KernelShap\n",
80+
"from alibi.explainers import KernelShap\n",
8281
"from alibi.datasets import fetch_adult\n",
8382
"from scipy.special import logit\n",
8483
"from sklearn.compose import ColumnTransformer\n",
@@ -683,7 +682,7 @@
683682
"metadata": {},
684683
"outputs": [],
685684
"source": [
686-
"def extract_importances(class_idx, beta, feature_names, intercepts=None):\n",
685+
"def get_importance(class_idx, beta, feature_names, intercepts=None):\n",
687686
" \"\"\"\n",
688687
" Retrieve and sort abs magnitude of coefficients from model.\n",
689688
" \"\"\"\n",
@@ -702,7 +701,7 @@
702701
"\n",
703702
" return feat_imp, feat_names\n",
704703
"\n",
705-
"def plot_importances(feat_imp, feat_names, class_idx, **kwargs):\n",
704+
"def plot_importance(feat_imp, feat_names, class_idx, **kwargs):\n",
706705
" \"\"\"\n",
707706
" Create a horizontal barchart of feature effects, sorted by their magnitude.\n",
708707
" \"\"\"\n",
@@ -778,10 +777,10 @@
778777
"metadata": {},
779778
"outputs": [],
780779
"source": [
781-
"feat_imp, srt_feat_names = extract_importances(class_idx, \n",
782-
" all_coef, \n",
783-
" perm_feat_names,\n",
784-
" )"
780+
"feat_imp, srt_feat_names = get_importance(class_idx, \n",
781+
" all_coef, \n",
782+
" perm_feat_names,\n",
783+
" )"
785784
]
786785
},
787786
{
@@ -837,13 +836,13 @@
837836
}
838837
],
839838
"source": [
840-
"_, class_0_fig = plot_importances(feat_imp, \n",
841-
" srt_feat_names, \n",
842-
" class_idx,\n",
843-
" left_x=-2.5,\n",
844-
" right_x=3.7,\n",
845-
" eps_factor=12 # controls text distance from end of bar\n",
846-
" )"
839+
"_, class_0_fig = plot_importance(feat_imp, \n",
840+
" srt_feat_names, \n",
841+
" class_idx,\n",
842+
" left_x=-2.5,\n",
843+
" right_x=3.7,\n",
844+
" eps_factor=12 # controls text distance from end of bar\n",
845+
" )"
847846
]
848847
},
849848
{
@@ -2463,7 +2462,7 @@
24632462
"name": "python",
24642463
"nbconvert_exporter": "python",
24652464
"pygments_lexer": "ipython3",
2466-
"version": "3.7.3"
2465+
"version": "3.7.6"
24672466
}
24682467
},
24692468
"nbformat": 4,

examples/kernel_shap_wine_intro.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@
7575
"\n",
7676
"import matplotlib.pyplot as plt\n",
7777
"import numpy as np\n",
78-
"import seaborn as sns\n",
7978
"\n",
8079
"from alibi.explainers import KernelShap\n",
8180
"from sklearn import svm\n",
@@ -535,7 +534,9 @@
535534
{
536535
"cell_type": "code",
537536
"execution_count": 16,
538-
"metadata": {},
537+
"metadata": {
538+
"scrolled": true
539+
},
539540
"outputs": [
540541
{
541542
"data": {
@@ -1195,7 +1196,7 @@
11951196
"name": "python",
11961197
"nbconvert_exporter": "python",
11971198
"pygments_lexer": "ipython3",
1198-
"version": "3.7.3"
1199+
"version": "3.7.6"
11991200
}
12001201
},
12011202
"nbformat": 4,

examples/kernel_shap_wine_lr.ipynb

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@
7575
"\n",
7676
"import matplotlib.pyplot as plt\n",
7777
"import numpy as np\n",
78-
"import seaborn as sns\n",
7978
"\n",
8079
"from alibi.explainers import KernelShap\n",
8180
"from scipy.special import logit\n",
@@ -332,7 +331,7 @@
332331
" \n",
333332
" return np.all(arr[:-1] <= arr[1:])\n",
334333
"\n",
335-
"def extract_importances(class_idx, beta, feature_names, intercepts=None):\n",
334+
"def get_importance(class_idx, beta, feature_names, intercepts=None):\n",
336335
" \"\"\"\n",
337336
" Retrieve and sort abs magnitude of coefficients from model.\n",
338337
" \"\"\"\n",
@@ -353,7 +352,7 @@
353352
"\n",
354353
" return feat_imp, feat_names\n",
355354
"\n",
356-
"def plot_importances(feat_imp, feat_names, **kwargs):\n",
355+
"def plot_importance(feat_imp, feat_names, **kwargs):\n",
357356
" \"\"\"\n",
358357
" Create a horizontal barchart of feature effects, sorted by their magnitude.\n",
359358
" \"\"\"\n",
@@ -413,10 +412,10 @@
413412
"outputs": [],
414413
"source": [
415414
"class_idx = 0\n",
416-
"feat_imp, feat_names = extract_importances(class_idx, \n",
417-
" beta, \n",
418-
" feature_names,\n",
419-
" )"
415+
"feat_imp, feat_names = get_importance(class_idx, \n",
416+
" beta, \n",
417+
" feature_names,\n",
418+
" )"
420419
]
421420
},
422421
{
@@ -438,13 +437,13 @@
438437
}
439438
],
440439
"source": [
441-
"_, class_0_fig = plot_importances(feat_imp, \n",
442-
" feat_names, \n",
443-
" left_x=-1.,\n",
444-
" right_x=1.25,\n",
445-
" xlabel = \"Feature effects (class {})\".format(class_idx),\n",
446-
" ylabel = \"Features\"\n",
447-
" )"
440+
"_, class_0_fig = plot_importance(feat_imp, \n",
441+
" feat_names, \n",
442+
" left_x=-1.,\n",
443+
" right_x=1.25,\n",
444+
" xlabel = \"Feature effects (class {})\".format(class_idx),\n",
445+
" ylabel = \"Features\"\n",
446+
" )"
448447
]
449448
},
450449
{
@@ -487,10 +486,10 @@
487486
"metadata": {},
488487
"outputs": [],
489488
"source": [
490-
"feat_imp, feat_names = extract_importances(1, # class_idx \n",
491-
" beta, \n",
492-
" feature_names,\n",
493-
" )"
489+
"feat_imp, feat_names = get_importance(1, # class_idx \n",
490+
" beta, \n",
491+
" feature_names,\n",
492+
" )"
494493
]
495494
},
496495
{
@@ -519,14 +518,14 @@
519518
}
520519
],
521520
"source": [
522-
"_, class_1_fig = plot_importances(feat_imp, \n",
523-
" feat_names, \n",
524-
" left_x=-1.5,\n",
525-
" right_x=1,\n",
526-
" eps_factor = 5, # controls text distance from end of bar for negative examples\n",
527-
" xlabel = \"Feature effects (class {})\".format(1),\n",
528-
" ylabel = \"Features\"\n",
529-
" )"
521+
"_, class_1_fig = plot_importance(feat_imp, \n",
522+
" feat_names, \n",
523+
" left_x=-1.5,\n",
524+
" right_x=1,\n",
525+
" eps_factor = 5, # controls text distance from end of bar for negative examples\n",
526+
" xlabel = \"Feature effects (class {})\".format(1),\n",
527+
" ylabel = \"Features\"\n",
528+
" )"
530529
]
531530
},
532531
{
@@ -535,10 +534,10 @@
535534
"metadata": {},
536535
"outputs": [],
537536
"source": [
538-
"feat_imp, feat_names = extract_importances(2, # class_idx\n",
539-
" beta, \n",
540-
" feature_names,\n",
541-
" )"
537+
"feat_imp, feat_names = get_importance(2, # class_idx\n",
538+
" beta, \n",
539+
" feature_names,\n",
540+
" )"
542541
]
543542
},
544543
{
@@ -567,14 +566,14 @@
567566
}
568567
],
569568
"source": [
570-
"_, class_2_fig = plot_importances(feat_imp, \n",
571-
" feat_names, \n",
572-
" left_x=-1.25,\n",
573-
" right_x=1,\n",
574-
" xlabel = \"Feature effects (class {})\".format(2),\n",
575-
" ylabel = \"Features\"\n",
569+
"_, class_2_fig = plot_importance(feat_imp, \n",
570+
" feat_names, \n",
571+
" left_x=-1.25,\n",
572+
" right_x=1,\n",
573+
" xlabel = \"Feature effects (class {})\".format(2),\n",
574+
" ylabel = \"Features\"\n",
576575
"# eps_factor = 5.\n",
577-
" )"
576+
" )"
578577
]
579578
},
580579
{
@@ -938,7 +937,7 @@
938937
"name": "python",
939938
"nbconvert_exporter": "python",
940939
"pygments_lexer": "ipython3",
941-
"version": "3.7.3"
940+
"version": "3.7.6"
942941
}
943942
},
944943
"nbformat": 4,

0 commit comments

Comments
 (0)