Add ability to compare multiple models in Captum Insights #551

Closed · wants to merge 1 commit.
163 changes: 100 additions & 63 deletions captum/insights/attr_vis/app.py
@@ -3,6 +3,7 @@
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Iterable,
     List,
@@ -76,7 +77,7 @@ def _get_context():


 VisualizationOutput = namedtuple(
-    "VisualizationOutput", "feature_outputs actual predicted active_index"
+    "VisualizationOutput", "feature_outputs actual predicted active_index model_index"
 )
 Contribution = namedtuple("Contribution", "name percent")
 SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label")
@@ -149,11 +150,8 @@ def __init__(
         r"""
         Args:

-            models (torch.nn.module): PyTorch module (model) for attribution
+            models (torch.nn.module): One or more PyTorch modules (models) for attribution
                         visualization.
-                        We plan to support visualizing and comparing multiple models
-                        in the future, but currently this supports only a single
-                        model.
             classes (list of string): List of strings corresponding to the names of
                         classes for classification.
             features (list of BaseFeature): List of BaseFeatures, which correspond
@@ -195,6 +193,7 @@ class scores.
         self.classes = classes
         self.features = features
         self.dataset = dataset
+        self.models = models
         self.attribution_calculation = AttributionCalculation(
             models, classes, features, score_func, use_label_for_attr
         )
@@ -203,13 +202,21 @@ class scores.
         self._dataset_iter = iter(dataset)

     def _calculate_attribution_from_cache(
-        self, index: int, target: Optional[Tensor]
+        self, input_index: int, model_index: int, target: Optional[Tensor]
     ) -> Optional[VisualizationOutput]:
-        c = self._outputs[index][1]
-        return self._calculate_vis_output(
-            c.inputs, c.additional_forward_args, c.label, torch.tensor(target)
+        c = self._outputs[input_index][1]
+        result = self._calculate_vis_output(
+            c.inputs,
+            c.additional_forward_args,
+            c.label,
+            torch.tensor(target),
+            model_index,
         )

+        if not result:
+            return None
+        return result[0]
+
     def _update_config(self, settings):
         self._config = FilterConfig(
             attribution_method=settings["attribution_method"],
@@ -344,67 +351,97 @@ def _should_keep_prediction(
         return True

     def _calculate_vis_output(
-        self, inputs, additional_forward_args, label, target=None
-    ) -> Optional[VisualizationOutput]:
-        actual_label_output = None
-        if label is not None and len(label) > 0:
-            label_index = int(label[0])
-            actual_label_output = OutputScore(
-                score=100, index=label_index, label=self.classes[label_index]
-            )
-
-        (
-            predicted_scores,
-            baselines,
-            transformed_inputs,
-        ) = self.attribution_calculation.calculate_predicted_scores(
-            inputs, additional_forward_args
-        )
+        self,
+        inputs,
+        additional_forward_args,
+        label,
+        target=None,
+        single_model_index=None,
+    ) -> Optional[List[VisualizationOutput]]:
+        # Use all models, unless the user wants to render data for a particular one
+        models_used = (
+            [self.models[single_model_index]]
+            if single_model_index is not None
+            else self.models
+        )
+        results = []
+        for model_index, model in enumerate(models_used):
+            # Get list of model visualizations for each input
+            actual_label_output = None
+            if label is not None and len(label) > 0:
+                label_index = int(label[0])
+                actual_label_output = OutputScore(
+                    score=100, index=label_index, label=self.classes[label_index]
+                )
+
+            (
+                predicted_scores,
+                baselines,
+                transformed_inputs,
+            ) = self.attribution_calculation.calculate_predicted_scores(
+                inputs, additional_forward_args, model
+            )

-        # Filter based on UI configuration
-        if actual_label_output is None or not self._should_keep_prediction(
-            predicted_scores, actual_label_output
-        ):
-            return None
-
-        if target is None:
-            target = predicted_scores[0].index if len(predicted_scores) > 0 else None
-
-        # attributions are given per input*
-        # inputs given to the model are described via `self.features`
-        #
-        # *an input contains multiple features that represent it
-        # e.g. all the pixels that describe an image is an input
-
-        attrs_per_input_feature = self.attribution_calculation.calculate_attribution(
-            baselines,
-            transformed_inputs,
-            additional_forward_args,
-            target,
-            self._config.attribution_method,
-            self._config.attribution_arguments,
-        )
+            # Filter based on UI configuration
+            if actual_label_output is None or not self._should_keep_prediction(
+                predicted_scores, actual_label_output
+            ):
+                continue
+
+            if target is None:
+                target = (
+                    predicted_scores[0].index if len(predicted_scores) > 0 else None
+                )
+
+            # attributions are given per input*
+            # inputs given to the model are described via `self.features`
+            #
+            # *an input contains multiple features that represent it
+            # e.g. all the pixels that describe an image is an input
+
+            attrs_per_input_feature = (
+                self.attribution_calculation.calculate_attribution(
+                    baselines,
+                    transformed_inputs,
+                    additional_forward_args,
+                    target,
+                    self._config.attribution_method,
+                    self._config.attribution_arguments,
+                    model,
+                )
+            )

-        net_contrib = self.attribution_calculation.calculate_net_contrib(
-            attrs_per_input_feature
-        )
+            net_contrib = self.attribution_calculation.calculate_net_contrib(
+                attrs_per_input_feature
+            )

-        # the features per input given
-        features_per_input = [
-            feature.visualize(attr, data, contrib)
-            for feature, attr, data, contrib in zip(
-                self.features, attrs_per_input_feature, inputs, net_contrib
-            )
-        ]
-
-        return VisualizationOutput(
-            feature_outputs=features_per_input,
-            actual=actual_label_output,
-            predicted=predicted_scores,
-            active_index=target if target is not None else actual_label_output.index,
-        )
+            # the features per input given
+            features_per_input = [
+                feature.visualize(attr, data, contrib)
+                for feature, attr, data, contrib in zip(
+                    self.features, attrs_per_input_feature, inputs, net_contrib
+                )
+            ]
+
+            results.append(
+                VisualizationOutput(
+                    feature_outputs=features_per_input,
+                    actual=actual_label_output,
+                    predicted=predicted_scores,
+                    active_index=target
+                    if target is not None
+                    else actual_label_output.index,
+                    # Even if we only iterated over one model, the index should be fixed
+                    # to show the index the model would have had in the list
+                    model_index=single_model_index
+                    if single_model_index is not None
+                    else model_index,
+                )
+            )
+
+        return results if results else None

-    def _get_outputs(self) -> List[Tuple[VisualizationOutput, SampleCache]]:
+    def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
         batch_data = next(self._dataset_iter)
         vis_outputs = []
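With this change, `AttributionVisualizer` accepts a list of models and renders one attribution column per model. A minimal usage sketch follows; the toy models, class names, and dataset are illustrative stand-ins, not part of this PR:

```python
import torch
import torch.nn as nn

from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature

# Two small stand-in classifiers; real use would pass trained models.
def make_model() -> nn.Module:
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

def dataset():
    # One random batch; a real app would wrap a torch DataLoader.
    yield Batch(inputs=torch.rand(4, 3, 32, 32), labels=torch.randint(0, 10, (4,)))

visualizer = AttributionVisualizer(
    models=[make_model(), make_model()],  # one visualization column per model
    classes=[str(i) for i in range(10)],
    features=[ImageFeature("Photo", baseline_transforms=[], input_transforms=[])],
    dataset=dataset(),
)
visualizer.render()  # serves the Insights UI, e.g. inside a notebook
```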
69 changes: 38 additions & 31 deletions captum/insights/attr_vis/attribution_calculation.py
@@ -41,42 +41,49 @@ def __init__(
         self.features = features
         self.score_func = score_func
         self.use_label_for_attr = use_label_for_attr
+        self.baseline_cache: dict = {}
+        self.transformed_input_cache: dict = {}

     def calculate_predicted_scores(
-        self, inputs, additional_forward_args
+        self, inputs, additional_forward_args, model
     ) -> Tuple[
         List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...]
     ]:
-        net = self.models[0]  # TODO process multiple models
-
-        # initialize baselines
-        baseline_transforms_len = 1  # todo support multiple baselines
-        baselines: List[List[Optional[Tensor]]] = [
-            [None] * len(self.features) for _ in range(baseline_transforms_len)
-        ]
-        transformed_inputs = list(inputs)
-
-        for feature_i, feature in enumerate(self.features):
-            transformed_inputs[feature_i] = self._transform(
-                feature.input_transforms, transformed_inputs[feature_i], True
-            )
-            for baseline_i in range(baseline_transforms_len):
-                if baseline_i > len(feature.baseline_transforms) - 1:
-                    baselines[baseline_i][feature_i] = torch.zeros_like(
-                        transformed_inputs[feature_i]
-                    )
-                else:
-                    baselines[baseline_i][feature_i] = self._transform(
-                        [feature.baseline_transforms[baseline_i]],
-                        transformed_inputs[feature_i],
-                        True,
-                    )
-
-        baselines = cast(List[List[Tensor]], baselines)
-        baselines_group = [tuple(b) for b in baselines]
+        # Check to see if these inputs already have cached baselines and transformed inputs
+        hashableInputs = tuple(inputs)
+        if hashableInputs in self.baseline_cache:
+            baselines_group = self.baseline_cache[hashableInputs]
+            transformed_inputs = self.transformed_input_cache[hashableInputs]
+        else:
+            # Initialize baselines
+            baseline_transforms_len = 1  # todo support multiple baselines
+            baselines: List[List[Optional[Tensor]]] = [
+                [None] * len(self.features) for _ in range(baseline_transforms_len)
+            ]
+            transformed_inputs = list(inputs)
+            for feature_i, feature in enumerate(self.features):
+                transformed_inputs[feature_i] = self._transform(
+                    feature.input_transforms, transformed_inputs[feature_i], True
+                )
+                for baseline_i in range(baseline_transforms_len):
+                    if baseline_i > len(feature.baseline_transforms) - 1:
+                        baselines[baseline_i][feature_i] = torch.zeros_like(
+                            transformed_inputs[feature_i]
+                        )
+                    else:
+                        baselines[baseline_i][feature_i] = self._transform(
+                            [feature.baseline_transforms[baseline_i]],
+                            transformed_inputs[feature_i],
+                            True,
+                        )

+            baselines = cast(List[List[Optional[Tensor]]], baselines)
+            baselines_group = [tuple(b) for b in baselines]
+            self.baseline_cache[hashableInputs] = baselines_group
+            self.transformed_input_cache[hashableInputs] = transformed_inputs

         outputs = _run_forward(
-            net,
+            model,
             tuple(transformed_inputs),
             additional_forward_args=additional_forward_args,
         )
@@ -105,10 +112,10 @@ def calculate_attribution(
         label: Optional[Union[Tensor]],
         attribution_method_name: str,
         attribution_arguments: Dict,
+        model: Module,
     ) -> Tuple[Tensor, ...]:
-        net = self.models[0]
         attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[attribution_method_name]
-        attribution_method = attribution_cls(net)
+        attribution_method = attribution_cls(model)
         param_config = ATTRIBUTION_METHOD_CONFIG[attribution_method_name]
         if param_config.post_process:
             for k, v in attribution_arguments.items():
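The new cache keys on `tuple(inputs)`. Since `torch.Tensor` hashes by object identity, a lookup hits only when the very same tensor objects are passed again, which is exactly what happens when `_calculate_vis_output` loops several models over one batch. A standalone sketch of the pattern, with hypothetical names:

```python
import torch

# Cache keyed on a tuple of input tensors. Tensors hash by id(), so the
# cache hits when the same batch objects are reused across models, and
# misses (correctly) for a freshly loaded batch with equal values.
baseline_cache: dict = {}

def get_baselines(inputs):
    key = tuple(inputs)  # hashable: each tensor hashes by identity
    if key in baseline_cache:
        return baseline_cache[key]
    baselines = tuple(torch.zeros_like(t) for t in inputs)  # e.g. zero baselines
    baseline_cache[key] = baselines
    return baselines

batch = (torch.rand(4, 3), torch.rand(4, 5))
a = get_baselines(batch)  # computed once
b = get_baselines(batch)  # cache hit: same tensor objects
assert a is b
```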
34 changes: 31 additions & 3 deletions captum/insights/attr_vis/frontend/src/App.module.css
@@ -58,8 +58,7 @@
   padding: 12px 8px;
 }

-.filter-panel__column__title,
-.panel__column__title {
+.filter-panel__column__title {
   font-weight: bold;
   color: #1c1e21;
   padding-bottom: 12px;
@@ -164,12 +163,19 @@
   padding: 24px;
   background: white;
   border-radius: 8px;
-  display: flex;
   box-shadow: 0px 3px 6px 0px rgba(0, 0, 0, 0.18);
   transition: opacity 0.2s; /* for loading */
   overflow-y: scroll;
 }

+.panel__column__title {
+  font-weight: 700;
+  border-bottom: 2px solid #c1c1c1;
+  color: #1c1e21;
+  padding-bottom: 2px;
+  margin-bottom: 15px;
+}
+
 .panel--loading {
   opacity: 0.5;
   pointer-events: none; /* disables all interactions inside panel */
@@ -346,3 +352,25 @@
     transform: rotate(360deg);
   }
 }
+
+.visualization-container {
+  display: flex;
+}
+
+.model-number {
+  display: block;
+  height: 2em;
+  font-size: 16px;
+  font-weight: 800;
+}
+
+.model-number-spacer {
+  display: block;
+  height: 2em;
+}
+
+.model-separator {
+  width: 100%;
+  border-bottom: 2px solid #c1c1c1;
+  margin: 10px 0px;
+}