diff --git a/captum/insights/attr_vis/app.py b/captum/insights/attr_vis/app.py index 4950757d48..0137710b76 100644 --- a/captum/insights/attr_vis/app.py +++ b/captum/insights/attr_vis/app.py @@ -3,6 +3,7 @@ from typing import ( Any, Callable, + cast, Dict, Iterable, List, @@ -76,7 +77,7 @@ def _get_context(): VisualizationOutput = namedtuple( - "VisualizationOutput", "feature_outputs actual predicted active_index" + "VisualizationOutput", "feature_outputs actual predicted active_index model_index" ) Contribution = namedtuple("Contribution", "name percent") SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label") @@ -149,11 +150,8 @@ def __init__( r""" Args: - models (torch.nn.module): PyTorch module (model) for attribution + models (torch.nn.module): One or more PyTorch modules (models) for attribution visualization. - We plan to support visualizing and comparing multiple models - in the future, but currently this supports only a single - model. classes (list of string): List of strings corresponding to the names of classes for classification. features (list of BaseFeature): List of BaseFeatures, which correspond @@ -195,6 +193,7 @@ class scores. self.classes = classes self.features = features self.dataset = dataset + self.models = models self.attribution_calculation = AttributionCalculation( models, classes, features, score_func, use_label_for_attr ) @@ -203,13 +202,21 @@ class scores. self._dataset_iter = iter(dataset) def _calculate_attribution_from_cache( - self, index: int, target: Optional[Tensor] + self, input_index: int, model_index: int, target: Optional[Tensor] ) -> Optional[VisualizationOutput]: - c = self._outputs[index][1] - return self._calculate_vis_output( - c.inputs, c.additional_forward_args, c.label, torch.tensor(target) + c = self._outputs[input_index][1] + result = self._calculate_vis_output( + c.inputs, + c.additional_forward_args, + c.label, + torch.tensor(target), + model_index, ) + if not result: + return None + return result[0] + def _update_config(self, settings): self._config = FilterConfig( attribution_method=settings["attribution_method"], @@ -344,67 +351,97 @@ def _should_keep_prediction( return True def _calculate_vis_output( - self, inputs, additional_forward_args, label, target=None - ) -> Optional[VisualizationOutput]: - actual_label_output = None - if label is not None and len(label) > 0: - label_index = int(label[0]) - actual_label_output = OutputScore( - score=100, index=label_index, label=self.classes[label_index] - ) - - ( - predicted_scores, - baselines, - transformed_inputs, - ) = self.attribution_calculation.calculate_predicted_scores( - inputs, additional_forward_args + self, + inputs, + additional_forward_args, + label, + target=None, + single_model_index=None, + ) -> Optional[List[VisualizationOutput]]: + # Use all models, unless the user wants to render data for a particular one + models_used = ( + [self.models[single_model_index]] + if single_model_index is not None + else self.models ) + results = [] + for model_index, model in enumerate(models_used): + # Get list of model visualizations for each input + actual_label_output = None + if label is not None and len(label) > 0: + label_index = int(label[0]) + actual_label_output = OutputScore( + score=100, index=label_index, label=self.classes[label_index] + ) + + ( + predicted_scores, + baselines, + transformed_inputs, + ) = self.attribution_calculation.calculate_predicted_scores( + inputs, additional_forward_args, model + ) - # Filter based on UI configuration - if actual_label_output is None or not self._should_keep_prediction( - predicted_scores, actual_label_output - ): - return None - - if target is None: - target = predicted_scores[0].index if len(predicted_scores) > 0 else None - - # attributions are given per input* - # inputs given to the model are described via `self.features` - # - # *an input contains multiple features that represent it - # e.g. all the pixels that describe an image is an input - - attrs_per_input_feature = self.attribution_calculation.calculate_attribution( - baselines, - transformed_inputs, - additional_forward_args, - target, - self._config.attribution_method, - self._config.attribution_arguments, - ) + # Filter based on UI configuration + if actual_label_output is None or not self._should_keep_prediction( + predicted_scores, actual_label_output + ): + continue + + if target is None: + target = ( + predicted_scores[0].index if len(predicted_scores) > 0 else None + ) + + # attributions are given per input* + # inputs given to the model are described via `self.features` + # + # *an input contains multiple features that represent it + # e.g. all the pixels that describe an image is an input + + attrs_per_input_feature = ( + self.attribution_calculation.calculate_attribution( + baselines, + transformed_inputs, + additional_forward_args, + target, + self._config.attribution_method, + self._config.attribution_arguments, + model, + ) + ) - net_contrib = self.attribution_calculation.calculate_net_contrib( - attrs_per_input_feature - ) + net_contrib = self.attribution_calculation.calculate_net_contrib( + attrs_per_input_feature + ) - # the features per input given - features_per_input = [ - feature.visualize(attr, data, contrib) - for feature, attr, data, contrib in zip( - self.features, attrs_per_input_feature, inputs, net_contrib + # the features per input given + features_per_input = [ + feature.visualize(attr, data, contrib) + for feature, attr, data, contrib in zip( + self.features, attrs_per_input_feature, inputs, net_contrib + ) + ] + + results.append( + VisualizationOutput( + feature_outputs=features_per_input, + actual=actual_label_output, + predicted=predicted_scores, + active_index=target + if target is not None + else actual_label_output.index, + # Even if we only iterated over one model, the index should be fixed + # to show the index the model would have had in the list + model_index=single_model_index + if single_model_index is not None + else model_index, + ) ) - ] - return VisualizationOutput( - feature_outputs=features_per_input, - actual=actual_label_output, - predicted=predicted_scores, - active_index=target if target is not None else actual_label_output.index, - ) + return results if results else None - def _get_outputs(self) -> List[Tuple[VisualizationOutput, SampleCache]]: + def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]: batch_data = next(self._dataset_iter) vis_outputs = [] diff --git a/captum/insights/attr_vis/attribution_calculation.py b/captum/insights/attr_vis/attribution_calculation.py index 373a1480d4..1e67124d6f 100644 --- a/captum/insights/attr_vis/attribution_calculation.py +++ b/captum/insights/attr_vis/attribution_calculation.py @@ -41,42 +41,49 @@ def __init__( self.features = features self.score_func = score_func self.use_label_for_attr = use_label_for_attr + self.baseline_cache: dict = {} + self.transformed_input_cache: dict = {} def calculate_predicted_scores( - self, inputs, additional_forward_args + self, inputs, additional_forward_args, model ) -> Tuple[ List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...] ]: - net = self.models[0] # TODO process multiple models - - # initialize baselines - baseline_transforms_len = 1 # todo support multiple baselines - baselines: List[List[Optional[Tensor]]] = [ - [None] * len(self.features) for _ in range(baseline_transforms_len) - ] - transformed_inputs = list(inputs) - - for feature_i, feature in enumerate(self.features): - transformed_inputs[feature_i] = self._transform( - feature.input_transforms, transformed_inputs[feature_i], True - ) - for baseline_i in range(baseline_transforms_len): - if baseline_i > len(feature.baseline_transforms) - 1: - baselines[baseline_i][feature_i] = torch.zeros_like( - transformed_inputs[feature_i] - ) - else: - baselines[baseline_i][feature_i] = self._transform( - [feature.baseline_transforms[baseline_i]], - transformed_inputs[feature_i], - True, - ) - - baselines = cast(List[List[Tensor]], baselines) - baselines_group = [tuple(b) for b in baselines] + # Check to see if these inputs already have caches baselines and transformed inputs + hashableInputs = tuple(inputs) + if hashableInputs in self.baseline_cache: + baselines_group = self.baseline_cache[hashableInputs] + transformed_inputs = self.transformed_input_cache[hashableInputs] + else: + # Initialize baselines + baseline_transforms_len = 1 # todo support multiple baselines + baselines: List[List[Optional[Tensor]]] = [ + [None] * len(self.features) for _ in range(baseline_transforms_len) + ] + transformed_inputs = list(inputs) + for feature_i, feature in enumerate(self.features): + transformed_inputs[feature_i] = self._transform( + feature.input_transforms, transformed_inputs[feature_i], True + ) + for baseline_i in range(baseline_transforms_len): + if baseline_i > len(feature.baseline_transforms) - 1: + baselines[baseline_i][feature_i] = torch.zeros_like( + transformed_inputs[feature_i] + ) + else: + baselines[baseline_i][feature_i] = self._transform( + [feature.baseline_transforms[baseline_i]], + transformed_inputs[feature_i], + True, + ) + + baselines = cast(List[List[Optional[Tensor]]], baselines) + baselines_group = [tuple(b) for b in baselines] + self.baseline_cache[hashableInputs] = baselines_group + self.transformed_input_cache[hashableInputs] = transformed_inputs outputs = _run_forward( - net, + model, tuple(transformed_inputs), additional_forward_args=additional_forward_args, ) @@ -105,10 +112,10 @@ def calculate_attribution( label: Optional[Union[Tensor]], attribution_method_name: str, attribution_arguments: Dict, + model: Module, ) -> Tuple[Tensor, ...]: - net = self.models[0] attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[attribution_method_name] - attribution_method = attribution_cls(net) + attribution_method = attribution_cls(model) param_config = ATTRIBUTION_METHOD_CONFIG[attribution_method_name] if param_config.post_process: for k, v in attribution_arguments.items(): diff --git a/captum/insights/attr_vis/frontend/src/App.module.css b/captum/insights/attr_vis/frontend/src/App.module.css index fd4f894d07..0c658f89bb 100644 --- a/captum/insights/attr_vis/frontend/src/App.module.css +++ b/captum/insights/attr_vis/frontend/src/App.module.css @@ -58,8 +58,7 @@ padding: 12px 8px; } -.filter-panel__column__title, -.panel__column__title { +.filter-panel__column__title { font-weight: bold; color: #1c1e21; padding-bottom: 12px; @@ -164,12 +163,19 @@ padding: 24px; background: white; border-radius: 8px; - display: flex; box-shadow: 0px 3px 6px 0px rgba(0, 0, 0, 0.18); transition: opacity 0.2s; /* for loading */ overflow-y: scroll; } +.panel__column__title { + font-weight: 700; + border-bottom: 2px solid #c1c1c1; + color: #1c1e21; + padding-bottom: 2px; + margin-bottom: 15px; +} + .panel--loading { opacity: 0.5; pointer-events: none; /* disables all interactions inside panel */ @@ -346,3 +352,25 @@ transform: rotate(360deg); } } + +.visualization-container { + display: flex; +} + +.model-number { + display: block; + height: 2em; + font-size: 16px; + font-weight: 800; +} + +.model-number-spacer { + display: block; + height: 2em; +} + +.model-separator { + width: 100%; + border-bottom: 2px solid #c1c1c1; + margin: 10px 0px; +} diff --git a/captum/insights/attr_vis/frontend/src/App.tsx b/captum/insights/attr_vis/frontend/src/App.tsx index b9bfc130ed..fb4ee1b275 100644 --- a/captum/insights/attr_vis/frontend/src/App.tsx +++ b/captum/insights/attr_vis/frontend/src/App.tsx @@ -5,14 +5,20 @@ import cx from "./utils/cx"; import Spinner from "./components/Spinner"; import FilterContainer from "./components/FilterContainer"; import Visualization from "./components/Visualization"; +import VisualizationGroupDisplay from "./components/VisualizationGroup"; import "./App.css"; -import { VisualizationOutput } from "./models/visualizationOutput"; +import { VisualizationGroup } from "./models/visualizationOutput"; import { FilterConfig } from "./models/filter"; interface VisualizationsProps { loading: boolean; - data: VisualizationOutput[]; - onTargetClick: (labelIndex: number, instance: number, callback: () => void) => void; + data: VisualizationGroup[]; + onTargetClick: ( + labelIndex: number, + inputIndex: number, + modelIndex: number, + callback: () => void + ) => void; } function Visualizations(props: VisualizationsProps) { @@ -41,11 +47,11 @@ function Visualizations(props: VisualizationsProps) { } return ( <div className={styles.viz}> - {props.data.map((v, i) => ( - <Visualization - data={v} - instance={i} + {props.data.map((vg, i) => ( + <VisualizationGroupDisplay + data={vg} key={i} + inputIndex={i} onTargetClick={props.onTargetClick} /> ))} @@ -57,9 +63,14 @@ interface AppBaseProps { fetchInit: () => void; fetchData: (filter_config: FilterConfig) => void; config: any; - data: VisualizationOutput[]; + data: VisualizationGroup[]; loading: boolean; - onTargetClick: (labelIndex: number, instance: number, callback: () => void) => void; + onTargetClick: ( + labelIndex: number, + inputIndex: number, + modelIndex: number, + callback: () => void + ) => void; } class AppBase extends React.Component<AppBaseProps> { diff --git a/captum/insights/attr_vis/frontend/src/WebApp.tsx b/captum/insights/attr_vis/frontend/src/WebApp.tsx index 57989c8c57..7390565ab6 100644 --- a/captum/insights/attr_vis/frontend/src/WebApp.tsx +++ b/captum/insights/attr_vis/frontend/src/WebApp.tsx @@ -1,12 +1,11 @@ import React from "react"; import AppBase from "./App"; -import { FilterConfig } from './models/filter'; -import { VisualizationOutput } from "./models/visualizationOutput"; +import { FilterConfig } from "./models/filter"; +import { VisualizationGroup } from "./models/visualizationOutput"; import { InsightsConfig } from "./models/insightsConfig"; - interface WebAppState { - data: VisualizationOutput[]; + data: VisualizationGroup[]; config: InsightsConfig; loading: boolean; } @@ -20,17 +19,17 @@ class WebApp extends React.Component<{}, WebAppState> { classes: [], methods: [], method_arguments: {}, - selected_method: "" + selected_method: "", }, - loading: false + loading: false, }; this._fetchInit(); } _fetchInit = () => { fetch("init") - .then(r => r.json()) - .then(r => this.setState({ config: r })); + .then((r) => r.json()) + .then((r) => this.setState({ config: r })); }; fetchData = (filter_config: FilterConfig) => { @@ -38,26 +37,31 @@ class WebApp extends React.Component<{}, WebAppState> { fetch("fetch", { method: "POST", headers: { - "Content-Type": "application/json" + "Content-Type": "application/json", }, - body: JSON.stringify(filter_config) + body: JSON.stringify(filter_config), }) - .then(response => response.json()) - .then(response => this.setState({ data: response, loading: false })); + .then((response) => response.json()) + .then((response) => this.setState({ data: response, loading: false })); }; - onTargetClick = (labelIndex: number, instance: number, callback: () => void) => { + onTargetClick = ( + labelIndex: number, + inputIndex: number, + modelIndex: number, + callback: () => void + ) => { fetch("attribute", { method: "POST", headers: { - "Content-Type": "application/json" + "Content-Type": "application/json", }, - body: JSON.stringify({ labelIndex, instance }) + body: JSON.stringify({ labelIndex, inputIndex, modelIndex }), }) - .then(response => response.json()) - .then(response => { + .then((response) => response.json()) + .then((response) => { const data = this.state.data ?? []; - data[instance] = response; + data[inputIndex][modelIndex] = response; this.setState({ data }); callback(); }); diff --git a/captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx b/captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx index e0652cb104..d169145bf2 100644 --- a/captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx +++ b/captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx @@ -10,18 +10,17 @@ interface ClassFilterProps { } function ClassFilter(props: ClassFilterProps) { - - const handleAddition = (newTag: { id: number | string, name: string }) => { + const handleAddition = (newTag: { id: number | string; name: string }) => { /** - * Need this type check as we expect tagId to be number while the `react-tag-autocomplete` has - * id as number | string. + * Need this type check as we expect tagId to be number while the `react-tag-autocomplete` has + * id as number | string. */ - if(typeof newTag.id === 'string') { + if (typeof newTag.id === "string") { throw Error("Invalid tag id received from ReactTags"); } else { - props.handleClassAdd({id: newTag.id, name: newTag.name}); + props.handleClassAdd({ id: newTag.id, name: newTag.name }); } - } + }; return ( <ReactTags diff --git a/captum/insights/attr_vis/frontend/src/components/Contributions.tsx b/captum/insights/attr_vis/frontend/src/components/Contributions.tsx index d3fc8c39ec..18c401547a 100644 --- a/captum/insights/attr_vis/frontend/src/components/Contributions.tsx +++ b/captum/insights/attr_vis/frontend/src/components/Contributions.tsx @@ -8,24 +8,28 @@ interface ContributionsProps { } function Contributions(props: ContributionsProps) { - return <>{props.feature_outputs.map((f) => { - // pad bar height so features with 0 contribution can still be seen - // in graph - const contribution = f.contribution * 100; - const bar_height = contribution > 10 ? contribution : contribution + 10; - return ( - <div className={styles["bar-chart__group"]}> - <div - className={styles["bar-chart__group__bar"]} - style={{ - height: bar_height + "px", - backgroundColor: calcHSLFromScore(contribution), - }} - /> - <div className={styles["bar-chart__group__title"]}>{f.name}</div> - </div> - ); - })}</> + return ( + <> + {props.feature_outputs.map((f) => { + // pad bar height so features with 0 contribution can still be seen + // in graph + const contribution = f.contribution * 100; + const bar_height = contribution > 10 ? contribution : contribution + 10; + return ( + <div className={styles["bar-chart__group"]}> + <div + className={styles["bar-chart__group__bar"]} + style={{ + height: bar_height + "px", + backgroundColor: calcHSLFromScore(contribution), + }} + /> + <div className={styles["bar-chart__group__title"]}>{f.name}</div> + </div> + ); + })} + </> + ); } export default Contributions; diff --git a/captum/insights/attr_vis/frontend/src/components/Feature.tsx b/captum/insights/attr_vis/frontend/src/components/Feature.tsx index 20f10be4b3..23bd352958 100644 --- a/captum/insights/attr_vis/frontend/src/components/Feature.tsx +++ b/captum/insights/attr_vis/frontend/src/components/Feature.tsx @@ -7,6 +7,7 @@ import { FeatureOutput } from "../models/visualizationOutput"; interface FeatureProps<T> { data: T; + hideHeaders?: boolean; } type ImageFeatureProps = FeatureProps<{ @@ -18,10 +19,13 @@ type ImageFeatureProps = FeatureProps<{ function ImageFeature(props: ImageFeatureProps) { return ( <> - <div className={styles["panel__column__title"]}> - {props.data.name} (Image) - </div> + {props.hideHeaders && ( + <div className={styles["panel__column__title"]}> + {props.data.name} (Image) + </div> + )} <div className={styles["panel__column__body"]}> + <div className={styles["model-number-spacer"]} /> <div className={styles.gallery}> <div className={styles["gallery__item"]}> <div className={styles["gallery__item__image"]}> @@ -73,10 +77,15 @@ function TextFeature(props: TextFeatureProps) { }); return ( <> - <div className={styles["panel__column__title"]}> - {props.data.name} (Text) + {props.hideHeaders && ( + <div className={styles["panel__column__title"]}> + {props.data.name} (Text) + </div> + )} + <div className={styles["panel__column__body"]}> + <div className={styles["model-number-spacer"]} /> + {color_words} </div> - <div className={styles["panel__column__body"]}>{color_words}</div> </> ); } @@ -123,13 +132,13 @@ function GeneralFeature(props: GeneralFeatureProps) { ); } -function Feature(props: {data: FeatureOutput}) { +function Feature(props: { data: FeatureOutput; hideHeaders: boolean }) { const data = props.data; switch (data.type) { case "image": - return <ImageFeature data={data} />; + return <ImageFeature data={data} hideHeaders={props.hideHeaders} />; case "text": - return <TextFeature data={data} />; + return <TextFeature data={data} hideHeaders={props.hideHeaders} />; case "general": return <GeneralFeature data={data} />; case "empty": diff --git a/captum/insights/attr_vis/frontend/src/components/Filter.tsx b/captum/insights/attr_vis/frontend/src/components/Filter.tsx index a9c8de28ee..25e235b6ec 100644 --- a/captum/insights/attr_vis/frontend/src/components/Filter.tsx +++ b/captum/insights/attr_vis/frontend/src/components/Filter.tsx @@ -3,11 +3,14 @@ import { StringArgument, EnumArgument, NumberArgument } from "./Arguments"; import cx from "../utils/cx"; import styles from "../App.module.css"; import ClassFilter from "./ClassFilter"; -import { MethodsArguments, ArgumentConfig, ArgumentType } from "../models/insightsConfig"; +import { + MethodsArguments, + ArgumentConfig, + ArgumentType, +} from "../models/insightsConfig"; import { TagClass } from "../models/filter"; import { UserInputField } from "../models/typeHelpers"; - interface FilterProps { prediction: string; selectedMethod: string; diff --git a/captum/insights/attr_vis/frontend/src/components/FilterContainer.tsx b/captum/insights/attr_vis/frontend/src/components/FilterContainer.tsx index 50ec05d329..52871ce67a 100644 --- a/captum/insights/attr_vis/frontend/src/components/FilterContainer.tsx +++ b/captum/insights/attr_vis/frontend/src/components/FilterContainer.tsx @@ -28,7 +28,10 @@ interface FilterContainerState { method_arguments: MethodsArguments; } -class FilterContainer extends React.Component<FilterContainerProps, FilterContainerState> { +class FilterContainer extends React.Component< + FilterContainerProps, + FilterContainerState +> { constructor(props: FilterContainerProps) { super(props); const suggested_classes = props.config.classes.map((c, classId) => ({ @@ -47,7 +50,10 @@ class FilterContainer extends React.Component<FilterContainerProps, FilterContai handleClassDelete = (classId: number) => { const classes = this.state.classes.slice(0); const removed_class = classes.splice(classId, 1); - const suggested_classes = [...this.state.suggested_classes, ...removed_class]; + const suggested_classes = [ + ...this.state.suggested_classes, + ...removed_class, + ]; this.setState({ classes, suggested_classes }); }; diff --git a/captum/insights/attr_vis/frontend/src/components/LabelButton.tsx b/captum/insights/attr_vis/frontend/src/components/LabelButton.tsx index 9c8a9aa113..68589be021 100644 --- a/captum/insights/attr_vis/frontend/src/components/LabelButton.tsx +++ b/captum/insights/attr_vis/frontend/src/components/LabelButton.tsx @@ -4,15 +4,20 @@ import styles from "../App.module.css"; interface LabelButtonProps { labelIndex: number; - instance: number; + inputIndex: number; + modelIndex: number; active: boolean; - onTargetClick: (labelIndex: number, instance: number) => void; + onTargetClick: ( + labelIndex: number, + inputIndex: number, + modelIndex: number + ) => void; } function LabelButton(props: React.PropsWithChildren<LabelButtonProps>) { const onClick = (e: React.MouseEvent<HTMLButtonElement>) => { e.preventDefault(); - props.onTargetClick(props.labelIndex, props.instance); + props.onTargetClick(props.labelIndex, props.inputIndex, props.modelIndex); }; return ( diff --git a/captum/insights/attr_vis/frontend/src/components/Visualization.tsx b/captum/insights/attr_vis/frontend/src/components/Visualization.tsx index c0539a4f36..c70b612611 100644 --- a/captum/insights/attr_vis/frontend/src/components/Visualization.tsx +++ b/captum/insights/attr_vis/frontend/src/components/Visualization.tsx @@ -10,14 +10,22 @@ import { VisualizationOutput } from "../models/visualizationOutput"; interface VisualizationProps { data: VisualizationOutput; instance: number; - onTargetClick: (labelIndex: number, instance: number, callback: () => void) => void; + onTargetClick: ( + labelIndex: number, + inputIndex: number, + modelIndex: number, + callback: () => void + ) => void; } interface VisualizationState { loading: boolean; } -class Visualization extends React.Component<VisualizationProps, VisualizationState> { +class Visualization extends React.Component< + VisualizationProps, + VisualizationState +> { constructor(props: VisualizationProps) { super(props); this.state = { @@ -25,17 +33,25 @@ class Visualization extends React.Component<VisualizationProps, VisualizationSta }; } - onTargetClick = (labelIndex: number, instance: number) => { + onTargetClick = ( + labelIndex: number, + inputIndex: number, + modelIndex: number + ) => { this.setState({ loading: true }); - this.props.onTargetClick(labelIndex, instance, () => + this.props.onTargetClick(labelIndex, inputIndex, modelIndex, () => this.setState({ loading: false }) ); }; + //TODO: Refactor the visualization table as a <table> instead of columns, in order to have cleaner styling render() { const data = this.props.data; + const isFirstInGroup = this.props.data.model_index == 0; console.log("visualization", data); - const features = data.feature_outputs.map((f) => <Feature data={f} />); + const features = data.feature_outputs.map((f) => ( + <Feature data={f} hideHeaders={isFirstInGroup} /> + )); return ( <> @@ -44,22 +60,23 @@ class Visualization extends React.Component<VisualizationProps, VisualizationSta <Spinner /> </div> )} - <div - className={cx({ - [styles.panel]: true, - [styles["panel--long"]]: true, - [styles["panel--loading"]]: this.state.loading, - })} - > + {!isFirstInGroup && <div className={styles["model-separator"]} />} + <div className={styles["visualization-container"]}> <div className={styles["panel__column"]}> - <div className={styles["panel__column__title"]}>Predicted</div> + {isFirstInGroup && ( + <div className={styles["panel__column__title"]}>Predicted</div> + )} <div className={styles["panel__column__body"]}> + <div className={styles["model-number"]}> + Model {data.model_index + 1} + </div> {data.predicted.map((p) => ( <div className={cx([styles.row, styles["row--padding"]])}> <LabelButton onTargetClick={this.onTargetClick} labelIndex={p.index} - instance={this.props.instance} + inputIndex={this.props.instance} + modelIndex={this.props.data.model_index} active={p.index === data.active_index} > {p.label} ({p.score.toFixed(3)}) @@ -69,13 +86,17 @@ class Visualization extends React.Component<VisualizationProps, VisualizationSta </div> </div> <div className={styles["panel__column"]}> - <div className={styles["panel__column__title"]}>Label</div> + {isFirstInGroup && ( + <div className={styles["panel__column__title"]}>Label</div> + )} <div className={styles["panel__column__body"]}> + <div className={styles["model-number-spacer"]} /> <div className={cx([styles.row, styles["row--padding"]])}> <LabelButton onTargetClick={this.onTargetClick} labelIndex={data.actual.index} - instance={this.props.instance} + inputIndex={this.props.instance} + modelIndex={this.props.data.model_index} active={data.actual.index === data.active_index} > {data.actual.label} @@ -84,8 +105,11 @@ class Visualization extends React.Component<VisualizationProps, VisualizationSta </div> </div> <div className={styles["panel__column"]}> - <div className={styles["panel__column__title"]}>Contribution</div> + {isFirstInGroup && ( + <div className={styles["panel__column__title"]}>Contribution</div> + )} <div className={styles["panel__column__body"]}> + <div className={styles["model-number-spacer"]} /> <div className={styles["bar-chart"]}> <Contributions feature_outputs={data.feature_outputs} /> </div> diff --git a/captum/insights/attr_vis/frontend/src/components/VisualizationGroup.tsx b/captum/insights/attr_vis/frontend/src/components/VisualizationGroup.tsx new file mode 100644 index 0000000000..023699413f --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/VisualizationGroup.tsx @@ -0,0 +1,38 @@ +import React from "react"; +import styles from "../App.module.css"; +import cx from "../utils/cx"; +import Visualization from "../components/Visualization"; +import { VisualizationGroup } from "../models/visualizationOutput"; + +interface VisualizationGroupDisplayProps { + inputIndex: number; + data: VisualizationGroup; + onTargetClick: ( + labelIndex: number, + inputIndex: number, + modelIndex: number, + callback: () => void + ) => void; +} + +function VisualizationGroupDisplay(props: VisualizationGroupDisplayProps) { + return ( + <div + className={cx({ + [styles.panel]: true, + [styles["panel--long"]]: true, + })} + > + {props.data.map((v, i) => ( + <Visualization + data={v} + instance={props.inputIndex} + onTargetClick={props.onTargetClick} + key={i} + /> + ))} + </div> + ); +} + +export default VisualizationGroupDisplay; diff --git a/captum/insights/attr_vis/frontend/src/components/plotly.module.d.ts b/captum/insights/attr_vis/frontend/src/components/plotly.module.d.ts index ff6718a43c..d12dc32f0e 100644 --- a/captum/insights/attr_vis/frontend/src/components/plotly.module.d.ts +++ b/captum/insights/attr_vis/frontend/src/components/plotly.module.d.ts @@ -1 +1 @@ -declare module 'plotly.js-basic-dist'; \ No newline at end of file +declare module "plotly.js-basic-dist"; diff --git a/captum/insights/attr_vis/frontend/src/models/filter.ts b/captum/insights/attr_vis/frontend/src/models/filter.ts index 9044bb2ba4..4002ccd44b 100644 --- a/captum/insights/attr_vis/frontend/src/models/filter.ts +++ b/captum/insights/attr_vis/frontend/src/models/filter.ts @@ -1,11 +1,11 @@ export interface FilterConfig { - attribution_method: string; - arguments: { [key: string]: any}; - prediction: string; - classes: string[]; + attribution_method: string; + arguments: { [key: string]: any }; + prediction: string; + classes: string[]; } export interface TagClass { - id: number; - name: string; + id: number; + name: string; } diff --git a/captum/insights/attr_vis/frontend/src/models/insightsConfig.ts b/captum/insights/attr_vis/frontend/src/models/insightsConfig.ts index 08687ff0c0..54afb547ee 100644 --- a/captum/insights/attr_vis/frontend/src/models/insightsConfig.ts +++ b/captum/insights/attr_vis/frontend/src/models/insightsConfig.ts @@ -1,30 +1,30 @@ export enum ArgumentType { - Number = "number", - Enum = "enum", - String = "string", - Boolean = "boolean" + Number = "number", + Enum = "enum", + String = "string", + Boolean = "boolean", } export type GenericArgumentConfig<T> = { - value: T; - limit: T[]; -} + value: T; + limit: T[]; +}; export type ArgumentConfig = - { type: ArgumentType.Number } & GenericArgumentConfig<number> | - { type: ArgumentType.Enum } & GenericArgumentConfig<string> | - { type: ArgumentType.String } & { value: string } | - { type: ArgumentType.Boolean } & { value: boolean } + | ({ type: ArgumentType.Number } & GenericArgumentConfig<number>) + | ({ type: ArgumentType.Enum } & GenericArgumentConfig<string>) + | ({ type: ArgumentType.String } & { value: string }) + | ({ type: ArgumentType.Boolean } & { value: boolean }); export interface MethodsArguments { - [method_name: string]: { - [arg_name: string]: ArgumentConfig; - } + [method_name: string]: { + [arg_name: string]: ArgumentConfig; + }; } export interface InsightsConfig { - classes: string[]; - methods: string[]; - method_arguments: MethodsArguments; - selected_method: string; -} \ No newline at end of file + classes: string[]; + methods: string[]; + method_arguments: MethodsArguments; + selected_method: string; +} diff --git a/captum/insights/attr_vis/frontend/src/models/typeHelpers.ts b/captum/insights/attr_vis/frontend/src/models/typeHelpers.ts index ad17002b97..b1386f36ef 100644 --- a/captum/insights/attr_vis/frontend/src/models/typeHelpers.ts +++ b/captum/insights/attr_vis/frontend/src/models/typeHelpers.ts @@ -1 +1 @@ -export type UserInputField = HTMLInputElement | HTMLSelectElement; \ No newline at end of file +export type UserInputField = HTMLInputElement | HTMLSelectElement; diff --git a/captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts b/captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts index 26fb46b94d..923d5e5438 100644 --- a/captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts +++ b/captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts @@ -1,30 +1,41 @@ interface OutputScore { - label: string; - index: number; - score: number; + label: string; + index: number; + score: number; } export enum FeatureType { - TEXT = "text", - IMAGE = "image", - GENERAL = "general", - EMPTY = "empty" + TEXT = "text", + IMAGE = "image", + GENERAL = "general", + EMPTY = "empty", } type GenericFeatureOutput<F extends FeatureType, T> = { - type: F, - name: string, - contribution: number + type: F; + name: string; + contribution: number; } & T; -export type FeatureOutput = GenericFeatureOutput<FeatureType.TEXT, { base: number[], modified: number[] }> - | GenericFeatureOutput<FeatureType.IMAGE, { base: string, modified: string }> - | GenericFeatureOutput<FeatureType.GENERAL, { base: number[], modified: number[]}> - | GenericFeatureOutput<FeatureType.EMPTY, {}> +export type FeatureOutput = + | GenericFeatureOutput< + FeatureType.TEXT, + { base: number[]; modified: number[] } + > + | GenericFeatureOutput<FeatureType.IMAGE, { base: string; modified: string }> + | GenericFeatureOutput< + FeatureType.GENERAL, + { base: number[]; modified: number[] } + > + | GenericFeatureOutput<FeatureType.EMPTY, {}>; export interface VisualizationOutput { - feature_outputs: FeatureOutput[]; - actual: OutputScore; - predicted: OutputScore[]; - active_index: number; + model_index: number; + feature_outputs: FeatureOutput[]; + actual: OutputScore; + predicted: OutputScore[]; + active_index: number; } + +//When multiple models are compared, visualizations are grouped together +export type VisualizationGroup = VisualizationOutput[]; diff --git a/captum/insights/attr_vis/frontend/tsconfig.json b/captum/insights/attr_vis/frontend/tsconfig.json index f2850b7161..af10394b4c 100644 --- a/captum/insights/attr_vis/frontend/tsconfig.json +++ b/captum/insights/attr_vis/frontend/tsconfig.json @@ -1,11 +1,7 @@ { "compilerOptions": { "target": "es5", - "lib": [ - "dom", - "dom.iterable", - "esnext" - ], + "lib": ["dom", "dom.iterable", "esnext"], "allowJs": true, "skipLibCheck": true, "esModuleInterop": true, @@ -19,7 +15,5 @@ "noEmit": true, "jsx": "react" }, - "include": [ - "src" - ] + "include": ["src"] } diff --git a/captum/insights/attr_vis/frontend/widget/src/extension.js b/captum/insights/attr_vis/frontend/widget/src/extension.js index 3ee405bd23..4a5214baeb 100644 --- a/captum/insights/attr_vis/frontend/widget/src/extension.js +++ b/captum/insights/attr_vis/frontend/widget/src/extension.js @@ -14,13 +14,13 @@ if (window.require) { window.require.config({ map: { "*": { - "jupyter-captum-insights": "nbextensions/jupyter-captum-insights/index" - } - } + "jupyter-captum-insights": "nbextensions/jupyter-captum-insights/index", + }, + }, }); } // Export the required load_ipython_extension module.exports = { - load_ipython_extension: function() {} + load_ipython_extension: function () {}, }; diff --git a/captum/insights/attr_vis/frontend/widget/webpack.config.js b/captum/insights/attr_vis/frontend/widget/webpack.config.js index 5937dc8a23..c292774dbc 100644 --- a/captum/insights/attr_vis/frontend/widget/webpack.config.js +++ b/captum/insights/attr_vis/frontend/widget/webpack.config.js @@ -1,118 +1,126 @@ -var path = require('path'); -var version = require('../package.json').version; +var path = require("path"); +var version = require("../package.json").version; // Custom webpack rules are generally the same for all webpack bundles, hence // stored in a separate local variable. var rules = [ - { test: /\.module.css$/, use: [ - 'style-loader', - { - loader: 'css-loader', - options: { - modules: true - } + { + test: /\.module.css$/, + use: [ + "style-loader", + { + loader: "css-loader", + options: { + modules: true, }, - ] - }, - { test: /^((?!\.module).)*.css$/, use: ['style-loader', 'css-loader'] }, - { - test: /\.(js|ts|tsx)$/, - exclude: /node_modules/, - loaders: 'babel-loader', - options: { - presets: ['@babel/preset-react', '@babel/preset-env', '@babel/preset-typescript'], - plugins: [ - "@babel/plugin-proposal-class-properties" - ], }, - } -] + ], + }, + { test: /^((?!\.module).)*.css$/, use: ["style-loader", "css-loader"] }, + { + test: /\.(js|ts|tsx)$/, + exclude: /node_modules/, + loaders: "babel-loader", + options: { + presets: [ + "@babel/preset-react", + "@babel/preset-env", + "@babel/preset-typescript", + ], + plugins: ["@babel/plugin-proposal-class-properties"], + }, + }, +]; -var extensions = ['.js', '.ts', '.tsx'] +var extensions = [".js", ".ts", ".tsx"]; module.exports = [ - {// Notebook extension - // - // This bundle only contains the part of the JavaScript that is run on - // load of the notebook. This section generally only performs - // some configuration for requirejs, and provides the legacy - // "load_ipython_extension" function which is required for any notebook - // extension. - // - mode: 'production', - entry: './src/extension.js', - output: { - filename: 'extension.js', - path: path.resolve(__dirname, '..', '..', 'widget', 'static'), - libraryTarget: 'amd' - }, - resolveLoader: { - modules: ['../node_modules'], - extensions: extensions - }, - resolve: { - modules: ['../node_modules'] - }, + { + // Notebook extension + // + // This bundle only contains the part of the JavaScript that is run on + // load of the notebook. This section generally only performs + // some configuration for requirejs, and provides the legacy + // "load_ipython_extension" function which is required for any notebook + // extension. + // + mode: "production", + entry: "./src/extension.js", + output: { + filename: "extension.js", + path: path.resolve(__dirname, "..", "..", "widget", "static"), + libraryTarget: "amd", }, - {// Bundle for the notebook containing the custom widget views and models - // - // This bundle contains the implementation for the custom widget views and - // custom widget. - // It must be an amd module - // - mode: 'production', - entry: './src/index.js', - output: { - filename: 'index.js', - path: path.resolve(__dirname, '..', '..', 'widget', 'static'), - libraryTarget: 'amd' - }, - devtool: 'source-map', - module: { - rules: rules, - }, - resolveLoader: { - modules: ['../node_modules'] - }, - resolve: { - modules: ['../node_modules'], - extensions: extensions - }, - externals: ['@jupyter-widgets/base'] + resolveLoader: { + modules: ["../node_modules"], + extensions: extensions, }, - {// Embeddable jupyter-captum-insights bundle - // - // This bundle is generally almost identical to the notebook bundle - // containing the custom widget views and models. - // - // The only difference is in the configuration of the webpack public path - // for the static assets. - // - // It will be automatically distributed by unpkg to work with the static - // widget embedder. - // - // The target bundle is always `dist/index.js`, which is the path required - // by the custom widget embedder. - // - mode: 'production', - entry: './src/embed.js', - output: { - filename: 'index.js', - path: path.resolve(__dirname, '..', '..', 'widget', 'dist'), - libraryTarget: 'amd', - publicPath: 'https://unpkg.com/jupyter-captum-insights@' + version + '/dist/' - }, - devtool: 'source-map', - module: { - rules: rules - }, - resolveLoader: { - modules: ['../node_modules'] - }, - resolve: { - modules: ['../node_modules'], - extensions: extensions - }, - externals: ['@jupyter-widgets/base'] - } + resolve: { + modules: ["../node_modules"], + }, + }, + { + // Bundle for the notebook containing the custom widget views and models + // + // This bundle contains the implementation for the custom widget views and + // custom widget. + // It must be an amd module + // + mode: "production", + entry: "./src/index.js", + output: { + filename: "index.js", + path: path.resolve(__dirname, "..", "..", "widget", "static"), + libraryTarget: "amd", + }, + devtool: "source-map", + module: { + rules: rules, + }, + resolveLoader: { + modules: ["../node_modules"], + }, + resolve: { + modules: ["../node_modules"], + extensions: extensions, + }, + externals: ["@jupyter-widgets/base"], + }, + { + // Embeddable jupyter-captum-insights bundle + // + // This bundle is generally almost identical to the notebook bundle + // containing the custom widget views and models. + // + // The only difference is in the configuration of the webpack public path + // for the static assets. + // + // It will be automatically distributed by unpkg to work with the static + // widget embedder. + // + // The target bundle is always `dist/index.js`, which is the path required + // by the custom widget embedder. + // + mode: "production", + entry: "./src/embed.js", + output: { + filename: "index.js", + path: path.resolve(__dirname, "..", "..", "widget", "dist"), + libraryTarget: "amd", + publicPath: + "https://unpkg.com/jupyter-captum-insights@" + version + "/dist/", + }, + devtool: "source-map", + module: { + rules: rules, + }, + resolveLoader: { + modules: ["../node_modules"], + }, + resolve: { + modules: ["../node_modules"], + extensions: extensions, + }, + externals: ["@jupyter-widgets/base"], + }, ]; diff --git a/captum/insights/attr_vis/server.py b/captum/insights/attr_vis/server.py index 15cee3906d..17bfd8e34f 100644 --- a/captum/insights/attr_vis/server.py +++ b/captum/insights/attr_vis/server.py @@ -42,7 +42,9 @@ def attribute(): r = request.get_json(force=True) return jsonify( namedtuple_to_dict( - visualizer._calculate_attribution_from_cache(r["instance"], r["labelIndex"]) + visualizer._calculate_attribution_from_cache( + r["inputIndex"], r["modelIndex"], r["labelIndex"] + ) ) ) diff --git a/tests/insights/test_contribution.py b/tests/insights/test_contribution.py index b9a4b01e71..0f04b59ddc 100644 --- a/tests/insights/test_contribution.py +++ b/tests/insights/test_contribution.py @@ -167,7 +167,7 @@ def test_one_feature(self): outputs = visualizer.visualize() for output in outputs: - total_contrib = sum(abs(f.contribution) for f in output.feature_outputs) + total_contrib = sum(abs(f.contribution) for f in output[0].feature_outputs) self.assertAlmostEqual(total_contrib, 1.0, places=6) def test_multi_features(self): @@ -210,7 +210,7 @@ def test_multi_features(self): outputs = visualizer.visualize() for output in outputs: - total_contrib = sum(abs(f.contribution) for f in output.feature_outputs) + total_contrib = sum(abs(f.contribution) for f in output[0].feature_outputs) self.assertAlmostEqual(total_contrib, 1.0, places=6) # TODO: add test for multiple models (related to TODO in captum/insights/api.py)