
Commit 8cff8b2

Ability to compare multiple models in Insights

1 parent 03f89a5 commit 8cff8b2

22 files changed: +533 −344 lines

captum/insights/attr_vis/app.py

Lines changed: 95 additions & 63 deletions
@@ -76,7 +76,7 @@ def _get_context():
 
 
 VisualizationOutput = namedtuple(
-    "VisualizationOutput", "feature_outputs actual predicted active_index"
+    "VisualizationOutput", "feature_outputs actual predicted active_index model_index"
 )
 Contribution = namedtuple("Contribution", "name percent")
 SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label")
@@ -149,11 +149,8 @@ def __init__(
         r"""
         Args:
 
-            models (torch.nn.module): PyTorch module (model) for attribution
+            models (torch.nn.module): One or more PyTorch modules (models) for attribution
                         visualization.
-                        We plan to support visualizing and comparing multiple models
-                        in the future, but currently this supports only a single
-                        model.
             classes (list of string): List of strings corresponding to the names of
                         classes for classification.
             features (list of BaseFeature): List of BaseFeatures, which correspond
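
With this change the `models` argument accepts a list of several modules. A minimal usage sketch of the new capability; the checkpoint paths, class names, feature arguments, and dataset below are illustrative placeholders (not part of this commit), and the import paths simply follow the file layout shown in this diff:

    import torch
    from captum.insights import AttributionVisualizer
    from captum.insights.attr_vis.features import ImageFeature

    # Hypothetical setup: two checkpoints of the same architecture to compare.
    model_a = torch.load("model_a.pt")
    model_b = torch.load("model_b.pt")

    visualizer = AttributionVisualizer(
        models=[model_a, model_b],  # previously limited to a single model
        classes=["cat", "dog"],  # illustrative class names
        features=[ImageFeature("Photo", baseline_transforms=[], input_transforms=[])],
        dataset=my_dataset,  # an iterable of Batch objects, defined elsewhere
    )
    visualizer.render()
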
@@ -195,6 +192,7 @@ class scores.
         self.classes = classes
         self.features = features
         self.dataset = dataset
+        self.models = models
         self.attribution_calculation = AttributionCalculation(
             models, classes, features, score_func, use_label_for_attr
         )
@@ -203,12 +201,16 @@ class scores.
         self._dataset_iter = iter(dataset)
 
     def _calculate_attribution_from_cache(
-        self, index: int, target: Optional[Tensor]
+        self, input_index: int, model_index: int, target: Optional[Tensor]
     ) -> Optional[VisualizationOutput]:
-        c = self._outputs[index][1]
+        c = self._outputs[input_index][1]
         return self._calculate_vis_output(
-            c.inputs, c.additional_forward_args, c.label, torch.tensor(target)
-        )
+            c.inputs,
+            c.additional_forward_args,
+            c.label,
+            torch.tensor(target),
+            model_index,
+        )[0]
 
     def _update_config(self, settings):
         self._config = FilterConfig(
@@ -344,67 +346,97 @@ def _should_keep_prediction(
         return True
 
     def _calculate_vis_output(
-        self, inputs, additional_forward_args, label, target=None
-    ) -> Optional[VisualizationOutput]:
-        actual_label_output = None
-        if label is not None and len(label) > 0:
-            label_index = int(label[0])
-            actual_label_output = OutputScore(
-                score=100, index=label_index, label=self.classes[label_index]
-            )
-
-        (
-            predicted_scores,
-            baselines,
-            transformed_inputs,
-        ) = self.attribution_calculation.calculate_predicted_scores(
-            inputs, additional_forward_args
+        self,
+        inputs,
+        additional_forward_args,
+        label,
+        target=None,
+        single_model_index=None,
+    ) -> Optional[List[VisualizationOutput]]:
+        # Use all models, unless the user wants to render data for a particular one
+        models_used = (
+            [self.models[single_model_index]]
+            if single_model_index is not None
+            else self.models
         )
+        results = []
+        for model_index, model in enumerate(models_used):
+            # Get list of model visualizations for each input
+            actual_label_output = None
+            if label is not None and len(label) > 0:
+                label_index = int(label[0])
+                actual_label_output = OutputScore(
+                    score=100, index=label_index, label=self.classes[label_index]
+                )
+
+            (
+                predicted_scores,
+                baselines,
+                transformed_inputs,
+            ) = self.attribution_calculation.calculate_predicted_scores(
+                inputs, additional_forward_args, model
+            )
 
-        # Filter based on UI configuration
-        if actual_label_output is None or not self._should_keep_prediction(
-            predicted_scores, actual_label_output
-        ):
-            return None
-
-        if target is None:
-            target = predicted_scores[0].index if len(predicted_scores) > 0 else None
-
-        # attributions are given per input*
-        # inputs given to the model are described via `self.features`
-        #
-        # *an input contains multiple features that represent it
-        # e.g. all the pixels that describe an image is an input
-
-        attrs_per_input_feature = self.attribution_calculation.calculate_attribution(
-            baselines,
-            transformed_inputs,
-            additional_forward_args,
-            target,
-            self._config.attribution_method,
-            self._config.attribution_arguments,
-        )
+            # Filter based on UI configuration
+            if actual_label_output is None or not self._should_keep_prediction(
+                predicted_scores, actual_label_output
+            ):
+                continue
+
+            if target is None:
+                target = (
+                    predicted_scores[0].index if len(predicted_scores) > 0 else None
+                )
+
+            # attributions are given per input*
+            # inputs given to the model are described via `self.features`
+            #
+            # *an input contains multiple features that represent it
+            # e.g. all the pixels that describe an image is an input
+
+            attrs_per_input_feature = (
+                self.attribution_calculation.calculate_attribution(
+                    baselines,
+                    transformed_inputs,
+                    additional_forward_args,
+                    target,
+                    self._config.attribution_method,
+                    self._config.attribution_arguments,
+                    model,
+                )
+            )
 
-        net_contrib = self.attribution_calculation.calculate_net_contrib(
-            attrs_per_input_feature
-        )
+            net_contrib = self.attribution_calculation.calculate_net_contrib(
+                attrs_per_input_feature
+            )
 
-        # the features per input given
-        features_per_input = [
-            feature.visualize(attr, data, contrib)
-            for feature, attr, data, contrib in zip(
-                self.features, attrs_per_input_feature, inputs, net_contrib
+            # the features per input given
+            features_per_input = [
+                feature.visualize(attr, data, contrib)
+                for feature, attr, data, contrib in zip(
+                    self.features, attrs_per_input_feature, inputs, net_contrib
+                )
+            ]
+
+            results.append(
+                VisualizationOutput(
+                    feature_outputs=features_per_input,
+                    actual=actual_label_output,
+                    predicted=predicted_scores,
+                    active_index=target
+                    if target is not None
+                    else actual_label_output.index,
+                    # Even if we only iterated over one model, the index should be fixed
+                    # to show the index the model would have had in the list
+                    model_index=single_model_index
+                    if single_model_index is not None
+                    else model_index,
+                )
             )
-        ]
 
-        return VisualizationOutput(
-            feature_outputs=features_per_input,
-            actual=actual_label_output,
-            predicted=predicted_scores,
-            active_index=target if target is not None else actual_label_output.index,
-        )
+        return results if results else None
 
-    def _get_outputs(self) -> List[Tuple[VisualizationOutput, SampleCache]]:
+    def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
         batch_data = next(self._dataset_iter)
         vis_outputs = []
 
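The hunk above is the core of the feature: `_calculate_vis_output` now loops over the selected models, yields one `VisualizationOutput` per model (tagged with its `model_index`), and returns `None` only when every model's prediction was filtered out. A rough sketch of the shape a caller sees per sample, continuing the sketch above; the call itself is illustrative, with `c` standing for a `SampleCache` entry:

    outputs = visualizer._calculate_vis_output(
        c.inputs, c.additional_forward_args, c.label
    )
    if outputs is not None:
        for vis in outputs:  # one entry per model being compared
            print(vis.model_index, vis.actual.label, vis.predicted[0].label)
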
captum/insights/attr_vis/attribution_calculation.py

Lines changed: 40 additions & 29 deletions
@@ -41,39 +41,46 @@ def __init__(
         self.features = features
         self.score_func = score_func
         self.use_label_for_attr = use_label_for_attr
+        self.baseline_cache = {}
+        self.transformed_input_cache = {}
 
     def calculate_predicted_scores(
-        self, inputs, additional_forward_args
+        self, inputs, additional_forward_args, model
     ) -> Tuple[
         List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...]
     ]:
-        net = self.models[0]  # TODO process multiple models
-
-        # initialize baselines
-        baseline_transforms_len = 1  # todo support multiple baselines
-        baselines: List[List[Optional[Tensor]]] = [
-            [None] * len(self.features) for _ in range(baseline_transforms_len)
-        ]
-        transformed_inputs = list(inputs)
-
-        for feature_i, feature in enumerate(self.features):
-            transformed_inputs[feature_i] = self._transform(
-                feature.input_transforms, transformed_inputs[feature_i], True
-            )
-            for baseline_i in range(baseline_transforms_len):
-                if baseline_i > len(feature.baseline_transforms) - 1:
-                    baselines[baseline_i][feature_i] = torch.zeros_like(
-                        transformed_inputs[feature_i]
-                    )
-                else:
-                    baselines[baseline_i][feature_i] = self._transform(
-                        [feature.baseline_transforms[baseline_i]],
-                        transformed_inputs[feature_i],
-                        True,
-                    )
-
-        baselines = cast(List[List[Tensor]], baselines)
-        baselines_group = [tuple(b) for b in baselines]
+        net = model
+
+        # Check to see if these inputs already have cached baselines and transformed inputs
+        hashableInputs = tuple(inputs)
+        if hashableInputs in self.baseline_cache:
+            baselines_group = self.baseline_cache[hashableInputs]
+            transformed_inputs = self.transformed_input_cache[hashableInputs]
+        else:
+            # Initialize baselines
+            baseline_transforms_len = 1  # todo support multiple baselines
+            baselines: List[List[Optional[Tensor]]] = [
+                [None] * len(self.features) for _ in range(baseline_transforms_len)
+            ]
+            transformed_inputs = list(inputs)
+            for feature_i, feature in enumerate(self.features):
+                transformed_inputs[feature_i] = self._transform(
+                    feature.input_transforms, transformed_inputs[feature_i], True
+                )
+                for baseline_i in range(baseline_transforms_len):
+                    if baseline_i > len(feature.baseline_transforms) - 1:
+                        baselines[baseline_i][feature_i] = torch.zeros_like(
+                            transformed_inputs[feature_i]
+                        )
+                    else:
+                        baselines[baseline_i][feature_i] = self._transform(
+                            [feature.baseline_transforms[baseline_i]],
+                            transformed_inputs[feature_i],
+                            True,
+                        )
+
+            baselines = cast(List[List[Tensor]], baselines)
+            baselines_group = [tuple(b) for b in baselines]
 
         outputs = _run_forward(
             net,
@@ -95,6 +102,9 @@ def calculate_predicted_scores(
 
         predicted_scores = self._get_labels_from_scores(scores, predicted)
 
+        self.baseline_cache[hashableInputs] = baselines_group
+        self.transformed_input_cache[hashableInputs] = transformed_inputs
+
         return predicted_scores, baselines_group, tuple(transformed_inputs)
 
     def calculate_attribution(
@@ -105,8 +115,9 @@ def calculate_attribution(
         label: Optional[Union[Tensor]],
         attribution_method_name: str,
         attribution_arguments: Dict,
+        model: Module,
     ) -> Tuple[Tensor, ...]:
-        net = self.models[0]
+        net = model
         attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[attribution_method_name]
         attribution_method = attribution_cls(net)
         param_config = ATTRIBUTION_METHOD_CONFIG[attribution_method_name]
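
The new `baseline_cache` and `transformed_input_cache` dictionaries offset the cost of the per-model loop: `calculate_predicted_scores` now runs once per model on the same batch, and keying on `tuple(inputs)` works because `torch.Tensor` hashes by object identity, so repeated calls with the same tensor objects reuse the baselines and transformed inputs computed on the first pass. A standalone sketch of the same memoization pattern; the names are illustrative, not from the commit:

    import torch

    _cache = {}

    def transform_once(inputs):
        # Tensors hash by identity, so a tuple of the same tensor objects
        # is a stable dictionary key across the per-model loop.
        key = tuple(inputs)
        if key not in _cache:
            _cache[key] = tuple(t.float() / 255 for t in inputs)  # stand-in transform
        return _cache[key]

    batch = (torch.randint(0, 256, (2, 3)),)
    assert transform_once(batch) is transform_once(batch)  # second call is cached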

captum/insights/attr_vis/frontend/src/App.module.css

Lines changed: 31 additions & 3 deletions
@@ -58,8 +58,7 @@
   padding: 12px 8px;
 }
 
-.filter-panel__column__title,
-.panel__column__title {
+.filter-panel__column__title {
   font-weight: bold;
   color: #1c1e21;
   padding-bottom: 12px;
@@ -164,12 +163,19 @@
   padding: 24px;
   background: white;
   border-radius: 8px;
-  display: flex;
   box-shadow: 0px 3px 6px 0px rgba(0, 0, 0, 0.18);
   transition: opacity 0.2s; /* for loading */
   overflow-y: scroll;
 }
 
+.panel__column__title {
+  font-weight: 700;
+  border-bottom: 2px solid #c1c1c1;
+  color: #1c1e21;
+  padding-bottom: 2px;
+  margin-bottom: 15px;
+}
+
 .panel--loading {
   opacity: 0.5;
   pointer-events: none; /* disables all interactions inside panel */
@@ -346,3 +352,25 @@
     transform: rotate(360deg);
   }
 }
+
+.visualization-container {
+  display: flex;
+}
+
+.model-number {
+  display: block;
+  height: 2em;
+  font-size: 16px;
+  font-weight: 800;
+}
+
+.model-number-spacer {
+  display: block;
+  height: 2em;
+}
+
+.model-separator {
+  width: 100%;
+  border-bottom: 2px solid #c1c1c1;
+  margin: 10px 0px;
+}
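On the front end, the `display: flex` rule moves off the panel and onto the new `.visualization-container` wrapper, which appears to be what lays the per-model visualizations out side by side, with `.model-number` labeling each model's column and `.model-separator` ruling off each block.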