diff --git a/captum/insights/attr_vis/app.py b/captum/insights/attr_vis/app.py
index 4950757d48..0137710b76 100644
--- a/captum/insights/attr_vis/app.py
+++ b/captum/insights/attr_vis/app.py
@@ -3,6 +3,7 @@
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Iterable,
     List,
@@ -76,7 +77,7 @@ def _get_context():
 
 
 VisualizationOutput = namedtuple(
-    "VisualizationOutput", "feature_outputs actual predicted active_index"
+    "VisualizationOutput", "feature_outputs actual predicted active_index model_index"
 )
 Contribution = namedtuple("Contribution", "name percent")
 SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label")
@@ -149,11 +150,8 @@ def __init__(
         r"""
         Args:
 
-            models (torch.nn.module): PyTorch module (model) for attribution
+            models (torch.nn.module): One or more PyTorch modules (models) for attribution
                           visualization.
-                          We plan to support visualizing and comparing multiple models
-                          in the future, but currently this supports only a single
-                          model.
             classes (list of string): List of strings corresponding to the names of
                           classes for classification.
             features (list of BaseFeature): List of BaseFeatures, which correspond
@@ -195,6 +193,7 @@ class scores.
         self.classes = classes
         self.features = features
         self.dataset = dataset
+        self.models = models
         self.attribution_calculation = AttributionCalculation(
             models, classes, features, score_func, use_label_for_attr
         )
@@ -203,13 +202,21 @@ class scores.
         self._dataset_iter = iter(dataset)
 
     def _calculate_attribution_from_cache(
-        self, index: int, target: Optional[Tensor]
+        self, input_index: int, model_index: int, target: Optional[Tensor]
     ) -> Optional[VisualizationOutput]:
-        c = self._outputs[index][1]
-        return self._calculate_vis_output(
-            c.inputs, c.additional_forward_args, c.label, torch.tensor(target)
+        c = self._outputs[input_index][1]
+        result = self._calculate_vis_output(
+            c.inputs,
+            c.additional_forward_args,
+            c.label,
+            torch.tensor(target),
+            model_index,
         )
 
+        if not result:
+            return None
+        return result[0]
+
     def _update_config(self, settings):
         self._config = FilterConfig(
             attribution_method=settings["attribution_method"],
@@ -344,67 +351,97 @@ def _should_keep_prediction(
         return True
 
     def _calculate_vis_output(
-        self, inputs, additional_forward_args, label, target=None
-    ) -> Optional[VisualizationOutput]:
-        actual_label_output = None
-        if label is not None and len(label) > 0:
-            label_index = int(label[0])
-            actual_label_output = OutputScore(
-                score=100, index=label_index, label=self.classes[label_index]
-            )
-
-        (
-            predicted_scores,
-            baselines,
-            transformed_inputs,
-        ) = self.attribution_calculation.calculate_predicted_scores(
-            inputs, additional_forward_args
+        self,
+        inputs,
+        additional_forward_args,
+        label,
+        target=None,
+        single_model_index=None,
+    ) -> Optional[List[VisualizationOutput]]:
+        # Use all models, unless the user wants to render data for a particular one
+        models_used = (
+            [self.models[single_model_index]]
+            if single_model_index is not None
+            else self.models
         )
+        results = []
+        for model_index, model in enumerate(models_used):
+            # Get list of model visualizations for each input
+            actual_label_output = None
+            if label is not None and len(label) > 0:
+                label_index = int(label[0])
+                actual_label_output = OutputScore(
+                    score=100, index=label_index, label=self.classes[label_index]
+                )
+
+            (
+                predicted_scores,
+                baselines,
+                transformed_inputs,
+            ) = self.attribution_calculation.calculate_predicted_scores(
+                inputs, additional_forward_args, model
+            )
 
-        # Filter based on UI configuration
-        if actual_label_output is None or not self._should_keep_prediction(
-            predicted_scores, actual_label_output
-        ):
-            return None
-
-        if target is None:
-            target = predicted_scores[0].index if len(predicted_scores) > 0 else None
-
-        # attributions are given per input*
-        # inputs given to the model are described via `self.features`
-        #
-        # *an input contains multiple features that represent it
-        #   e.g. all the pixels that describe an image is an input
-
-        attrs_per_input_feature = self.attribution_calculation.calculate_attribution(
-            baselines,
-            transformed_inputs,
-            additional_forward_args,
-            target,
-            self._config.attribution_method,
-            self._config.attribution_arguments,
-        )
+            # Filter based on UI configuration
+            if actual_label_output is None or not self._should_keep_prediction(
+                predicted_scores, actual_label_output
+            ):
+                continue
+
+            if target is None:
+                target = (
+                    predicted_scores[0].index if len(predicted_scores) > 0 else None
+                )
+
+            # attributions are given per input*
+            # inputs given to the model are described via `self.features`
+            #
+            # *an input contains multiple features that represent it
+            #   e.g. all the pixels that describe an image is an input
+
+            attrs_per_input_feature = (
+                self.attribution_calculation.calculate_attribution(
+                    baselines,
+                    transformed_inputs,
+                    additional_forward_args,
+                    target,
+                    self._config.attribution_method,
+                    self._config.attribution_arguments,
+                    model,
+                )
+            )
 
-        net_contrib = self.attribution_calculation.calculate_net_contrib(
-            attrs_per_input_feature
-        )
+            net_contrib = self.attribution_calculation.calculate_net_contrib(
+                attrs_per_input_feature
+            )
 
-        # the features per input given
-        features_per_input = [
-            feature.visualize(attr, data, contrib)
-            for feature, attr, data, contrib in zip(
-                self.features, attrs_per_input_feature, inputs, net_contrib
+            # the features per input given
+            features_per_input = [
+                feature.visualize(attr, data, contrib)
+                for feature, attr, data, contrib in zip(
+                    self.features, attrs_per_input_feature, inputs, net_contrib
+                )
+            ]
+
+            results.append(
+                VisualizationOutput(
+                    feature_outputs=features_per_input,
+                    actual=actual_label_output,
+                    predicted=predicted_scores,
+                    active_index=target
+                    if target is not None
+                    else actual_label_output.index,
+                    # Even if we only iterated over one model, the index should be fixed
+                    # to show the index the model would have had in the list
+                    model_index=single_model_index
+                    if single_model_index is not None
+                    else model_index,
+                )
             )
-        ]
 
-        return VisualizationOutput(
-            feature_outputs=features_per_input,
-            actual=actual_label_output,
-            predicted=predicted_scores,
-            active_index=target if target is not None else actual_label_output.index,
-        )
+        return results if results else None
 
-    def _get_outputs(self) -> List[Tuple[VisualizationOutput, SampleCache]]:
+    def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
         batch_data = next(self._dataset_iter)
         vis_outputs = []
 
diff --git a/captum/insights/attr_vis/attribution_calculation.py b/captum/insights/attr_vis/attribution_calculation.py
index 373a1480d4..1e67124d6f 100644
--- a/captum/insights/attr_vis/attribution_calculation.py
+++ b/captum/insights/attr_vis/attribution_calculation.py
@@ -41,42 +41,49 @@ def __init__(
         self.features = features
         self.score_func = score_func
         self.use_label_for_attr = use_label_for_attr
+        self.baseline_cache: dict = {}
+        self.transformed_input_cache: dict = {}
 
     def calculate_predicted_scores(
-        self, inputs, additional_forward_args
+        self, inputs, additional_forward_args, model
     ) -> Tuple[
         List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...]
     ]:
-        net = self.models[0]  # TODO process multiple models
-
-        # initialize baselines
-        baseline_transforms_len = 1  # todo support multiple baselines
-        baselines: List[List[Optional[Tensor]]] = [
-            [None] * len(self.features) for _ in range(baseline_transforms_len)
-        ]
-        transformed_inputs = list(inputs)
-
-        for feature_i, feature in enumerate(self.features):
-            transformed_inputs[feature_i] = self._transform(
-                feature.input_transforms, transformed_inputs[feature_i], True
-            )
-            for baseline_i in range(baseline_transforms_len):
-                if baseline_i > len(feature.baseline_transforms) - 1:
-                    baselines[baseline_i][feature_i] = torch.zeros_like(
-                        transformed_inputs[feature_i]
-                    )
-                else:
-                    baselines[baseline_i][feature_i] = self._transform(
-                        [feature.baseline_transforms[baseline_i]],
-                        transformed_inputs[feature_i],
-                        True,
-                    )
-
-        baselines = cast(List[List[Tensor]], baselines)
-        baselines_group = [tuple(b) for b in baselines]
+        # Check to see if these inputs already have caches baselines and transformed inputs
+        hashableInputs = tuple(inputs)
+        if hashableInputs in self.baseline_cache:
+            baselines_group = self.baseline_cache[hashableInputs]
+            transformed_inputs = self.transformed_input_cache[hashableInputs]
+        else:
+            # Initialize baselines
+            baseline_transforms_len = 1  # todo support multiple baselines
+            baselines: List[List[Optional[Tensor]]] = [
+                [None] * len(self.features) for _ in range(baseline_transforms_len)
+            ]
+            transformed_inputs = list(inputs)
+            for feature_i, feature in enumerate(self.features):
+                transformed_inputs[feature_i] = self._transform(
+                    feature.input_transforms, transformed_inputs[feature_i], True
+                )
+                for baseline_i in range(baseline_transforms_len):
+                    if baseline_i > len(feature.baseline_transforms) - 1:
+                        baselines[baseline_i][feature_i] = torch.zeros_like(
+                            transformed_inputs[feature_i]
+                        )
+                    else:
+                        baselines[baseline_i][feature_i] = self._transform(
+                            [feature.baseline_transforms[baseline_i]],
+                            transformed_inputs[feature_i],
+                            True,
+                        )
+
+            baselines = cast(List[List[Optional[Tensor]]], baselines)
+            baselines_group = [tuple(b) for b in baselines]
+            self.baseline_cache[hashableInputs] = baselines_group
+            self.transformed_input_cache[hashableInputs] = transformed_inputs
 
         outputs = _run_forward(
-            net,
+            model,
             tuple(transformed_inputs),
             additional_forward_args=additional_forward_args,
         )
@@ -105,10 +112,10 @@ def calculate_attribution(
         label: Optional[Union[Tensor]],
         attribution_method_name: str,
         attribution_arguments: Dict,
+        model: Module,
     ) -> Tuple[Tensor, ...]:
-        net = self.models[0]
         attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[attribution_method_name]
-        attribution_method = attribution_cls(net)
+        attribution_method = attribution_cls(model)
         param_config = ATTRIBUTION_METHOD_CONFIG[attribution_method_name]
         if param_config.post_process:
             for k, v in attribution_arguments.items():
diff --git a/captum/insights/attr_vis/frontend/src/App.module.css b/captum/insights/attr_vis/frontend/src/App.module.css
index fd4f894d07..0c658f89bb 100644
--- a/captum/insights/attr_vis/frontend/src/App.module.css
+++ b/captum/insights/attr_vis/frontend/src/App.module.css
@@ -58,8 +58,7 @@
   padding: 12px 8px;
 }
 
-.filter-panel__column__title,
-.panel__column__title {
+.filter-panel__column__title {
   font-weight: bold;
   color: #1c1e21;
   padding-bottom: 12px;
@@ -164,12 +163,19 @@
   padding: 24px;
   background: white;
   border-radius: 8px;
-  display: flex;
   box-shadow: 0px 3px 6px 0px rgba(0, 0, 0, 0.18);
   transition: opacity 0.2s; /* for loading */
   overflow-y: scroll;
 }
 
+.panel__column__title {
+  font-weight: 700;
+  border-bottom: 2px solid #c1c1c1;
+  color: #1c1e21;
+  padding-bottom: 2px;
+  margin-bottom: 15px;
+}
+
 .panel--loading {
   opacity: 0.5;
   pointer-events: none; /* disables all interactions inside panel */
@@ -346,3 +352,25 @@
     transform: rotate(360deg);
   }
 }
+
+.visualization-container {
+  display: flex;
+}
+
+.model-number {
+  display: block;
+  height: 2em;
+  font-size: 16px;
+  font-weight: 800;
+}
+
+.model-number-spacer {
+  display: block;
+  height: 2em;
+}
+
+.model-separator {
+  width: 100%;
+  border-bottom: 2px solid #c1c1c1;
+  margin: 10px 0px;
+}
diff --git a/captum/insights/attr_vis/frontend/src/App.tsx b/captum/insights/attr_vis/frontend/src/App.tsx
index b9bfc130ed..fb4ee1b275 100644
--- a/captum/insights/attr_vis/frontend/src/App.tsx
+++ b/captum/insights/attr_vis/frontend/src/App.tsx
@@ -5,14 +5,20 @@ import cx from "./utils/cx";
 import Spinner from "./components/Spinner";
 import FilterContainer from "./components/FilterContainer";
 import Visualization from "./components/Visualization";
+import VisualizationGroupDisplay from "./components/VisualizationGroup";
 import "./App.css";
-import { VisualizationOutput } from "./models/visualizationOutput";
+import { VisualizationGroup } from "./models/visualizationOutput";
 import { FilterConfig } from "./models/filter";
 
 interface VisualizationsProps {
   loading: boolean;
-  data: VisualizationOutput[];
-  onTargetClick: (labelIndex: number, instance: number, callback: () => void) => void;
+  data: VisualizationGroup[];
+  onTargetClick: (
+    labelIndex: number,
+    inputIndex: number,
+    modelIndex: number,
+    callback: () => void
+  ) => void;
 }
 
 function Visualizations(props: VisualizationsProps) {
@@ -41,11 +47,11 @@ function Visualizations(props: VisualizationsProps) {
   }
   return (
     <div className={styles.viz}>
-      {props.data.map((v, i) => (
-        <Visualization
-          data={v}
-          instance={i}
+      {props.data.map((vg, i) => (
+        <VisualizationGroupDisplay
+          data={vg}
           key={i}
+          inputIndex={i}
           onTargetClick={props.onTargetClick}
         />
       ))}
@@ -57,9 +63,14 @@ interface AppBaseProps {
   fetchInit: () => void;
   fetchData: (filter_config: FilterConfig) => void;
   config: any;
-  data: VisualizationOutput[];
+  data: VisualizationGroup[];
   loading: boolean;
-  onTargetClick: (labelIndex: number, instance: number, callback: () => void) => void;
+  onTargetClick: (
+    labelIndex: number,
+    inputIndex: number,
+    modelIndex: number,
+    callback: () => void
+  ) => void;
 }
 
 class AppBase extends React.Component<AppBaseProps> {
diff --git a/captum/insights/attr_vis/frontend/src/WebApp.tsx b/captum/insights/attr_vis/frontend/src/WebApp.tsx
index 57989c8c57..7390565ab6 100644
--- a/captum/insights/attr_vis/frontend/src/WebApp.tsx
+++ b/captum/insights/attr_vis/frontend/src/WebApp.tsx
@@ -1,12 +1,11 @@
 import React from "react";
 import AppBase from "./App";
-import { FilterConfig } from './models/filter';
-import { VisualizationOutput } from "./models/visualizationOutput";
+import { FilterConfig } from "./models/filter";
+import { VisualizationGroup } from "./models/visualizationOutput";
 import { InsightsConfig } from "./models/insightsConfig";
 
-
 interface WebAppState {
-  data: VisualizationOutput[];
+  data: VisualizationGroup[];
   config: InsightsConfig;
   loading: boolean;
 }
@@ -20,17 +19,17 @@ class WebApp extends React.Component<{}, WebAppState> {
         classes: [],
         methods: [],
         method_arguments: {},
-        selected_method: ""
+        selected_method: "",
       },
-      loading: false
+      loading: false,
     };
     this._fetchInit();
   }
 
   _fetchInit = () => {
     fetch("init")
-      .then(r => r.json())
-      .then(r => this.setState({ config: r }));
+      .then((r) => r.json())
+      .then((r) => this.setState({ config: r }));
   };
 
   fetchData = (filter_config: FilterConfig) => {
@@ -38,26 +37,31 @@ class WebApp extends React.Component<{}, WebAppState> {
     fetch("fetch", {
       method: "POST",
       headers: {
-        "Content-Type": "application/json"
+        "Content-Type": "application/json",
       },
-      body: JSON.stringify(filter_config)
+      body: JSON.stringify(filter_config),
     })
-      .then(response => response.json())
-      .then(response => this.setState({ data: response, loading: false }));
+      .then((response) => response.json())
+      .then((response) => this.setState({ data: response, loading: false }));
   };
 
-  onTargetClick = (labelIndex: number, instance: number, callback: () => void) => {
+  onTargetClick = (
+    labelIndex: number,
+    inputIndex: number,
+    modelIndex: number,
+    callback: () => void
+  ) => {
     fetch("attribute", {
       method: "POST",
       headers: {
-        "Content-Type": "application/json"
+        "Content-Type": "application/json",
       },
-      body: JSON.stringify({ labelIndex, instance })
+      body: JSON.stringify({ labelIndex, inputIndex, modelIndex }),
     })
-      .then(response => response.json())
-      .then(response => {
+      .then((response) => response.json())
+      .then((response) => {
         const data = this.state.data ?? [];
-        data[instance] = response;
+        data[inputIndex][modelIndex] = response;
         this.setState({ data });
         callback();
       });
diff --git a/captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx b/captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx
index e0652cb104..d169145bf2 100644
--- a/captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx
+++ b/captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx
@@ -10,18 +10,17 @@ interface ClassFilterProps {
 }
 
 function ClassFilter(props: ClassFilterProps) {
-
-  const handleAddition = (newTag: { id: number | string, name: string }) => {
+  const handleAddition = (newTag: { id: number | string; name: string }) => {
     /**
-     * Need this type check as we expect tagId to be number while the `react-tag-autocomplete` has 
-     * id as number | string. 
+     * Need this type check as we expect tagId to be number while the `react-tag-autocomplete` has
+     * id as number | string.
      */
-    if(typeof newTag.id === 'string') {
+    if (typeof newTag.id === "string") {
       throw Error("Invalid tag id received from ReactTags");
     } else {
-      props.handleClassAdd({id: newTag.id, name: newTag.name}); 
+      props.handleClassAdd({ id: newTag.id, name: newTag.name });
     }
-  }
+  };
 
   return (
     <ReactTags
diff --git a/captum/insights/attr_vis/frontend/src/components/Contributions.tsx b/captum/insights/attr_vis/frontend/src/components/Contributions.tsx
index d3fc8c39ec..18c401547a 100644
--- a/captum/insights/attr_vis/frontend/src/components/Contributions.tsx
+++ b/captum/insights/attr_vis/frontend/src/components/Contributions.tsx
@@ -8,24 +8,28 @@ interface ContributionsProps {
 }
 
 function Contributions(props: ContributionsProps) {
-  return <>{props.feature_outputs.map((f) => {
-    // pad bar height so features with 0 contribution can still be seen
-    // in graph
-    const contribution = f.contribution * 100;
-    const bar_height = contribution > 10 ? contribution : contribution + 10;
-    return (
-      <div className={styles["bar-chart__group"]}>
-        <div
-          className={styles["bar-chart__group__bar"]}
-          style={{
-            height: bar_height + "px",
-            backgroundColor: calcHSLFromScore(contribution),
-          }}
-        />
-        <div className={styles["bar-chart__group__title"]}>{f.name}</div>
-      </div>
-    );
-  })}</>
+  return (
+    <>
+      {props.feature_outputs.map((f) => {
+        // pad bar height so features with 0 contribution can still be seen
+        // in graph
+        const contribution = f.contribution * 100;
+        const bar_height = contribution > 10 ? contribution : contribution + 10;
+        return (
+          <div className={styles["bar-chart__group"]}>
+            <div
+              className={styles["bar-chart__group__bar"]}
+              style={{
+                height: bar_height + "px",
+                backgroundColor: calcHSLFromScore(contribution),
+              }}
+            />
+            <div className={styles["bar-chart__group__title"]}>{f.name}</div>
+          </div>
+        );
+      })}
+    </>
+  );
 }
 
 export default Contributions;
diff --git a/captum/insights/attr_vis/frontend/src/components/Feature.tsx b/captum/insights/attr_vis/frontend/src/components/Feature.tsx
index 20f10be4b3..23bd352958 100644
--- a/captum/insights/attr_vis/frontend/src/components/Feature.tsx
+++ b/captum/insights/attr_vis/frontend/src/components/Feature.tsx
@@ -7,6 +7,7 @@ import { FeatureOutput } from "../models/visualizationOutput";
 
 interface FeatureProps<T> {
   data: T;
+  hideHeaders?: boolean;
 }
 
 type ImageFeatureProps = FeatureProps<{
@@ -18,10 +19,13 @@ type ImageFeatureProps = FeatureProps<{
 function ImageFeature(props: ImageFeatureProps) {
   return (
     <>
-      <div className={styles["panel__column__title"]}>
-        {props.data.name} (Image)
-      </div>
+      {props.hideHeaders && (
+        <div className={styles["panel__column__title"]}>
+          {props.data.name} (Image)
+        </div>
+      )}
       <div className={styles["panel__column__body"]}>
+        <div className={styles["model-number-spacer"]} />
         <div className={styles.gallery}>
           <div className={styles["gallery__item"]}>
             <div className={styles["gallery__item__image"]}>
@@ -73,10 +77,15 @@ function TextFeature(props: TextFeatureProps) {
   });
   return (
     <>
-      <div className={styles["panel__column__title"]}>
-        {props.data.name} (Text)
+      {props.hideHeaders && (
+        <div className={styles["panel__column__title"]}>
+          {props.data.name} (Text)
+        </div>
+      )}
+      <div className={styles["panel__column__body"]}>
+        <div className={styles["model-number-spacer"]} />
+        {color_words}
       </div>
-      <div className={styles["panel__column__body"]}>{color_words}</div>
     </>
   );
 }
@@ -123,13 +132,13 @@ function GeneralFeature(props: GeneralFeatureProps) {
   );
 }
 
-function Feature(props: {data: FeatureOutput}) {
+function Feature(props: { data: FeatureOutput; hideHeaders: boolean }) {
   const data = props.data;
   switch (data.type) {
     case "image":
-      return <ImageFeature data={data} />;
+      return <ImageFeature data={data} hideHeaders={props.hideHeaders} />;
     case "text":
-      return <TextFeature data={data} />;
+      return <TextFeature data={data} hideHeaders={props.hideHeaders} />;
     case "general":
       return <GeneralFeature data={data} />;
     case "empty":
diff --git a/captum/insights/attr_vis/frontend/src/components/Filter.tsx b/captum/insights/attr_vis/frontend/src/components/Filter.tsx
index a9c8de28ee..25e235b6ec 100644
--- a/captum/insights/attr_vis/frontend/src/components/Filter.tsx
+++ b/captum/insights/attr_vis/frontend/src/components/Filter.tsx
@@ -3,11 +3,14 @@ import { StringArgument, EnumArgument, NumberArgument } from "./Arguments";
 import cx from "../utils/cx";
 import styles from "../App.module.css";
 import ClassFilter from "./ClassFilter";
-import { MethodsArguments, ArgumentConfig, ArgumentType } from "../models/insightsConfig";
+import {
+  MethodsArguments,
+  ArgumentConfig,
+  ArgumentType,
+} from "../models/insightsConfig";
 import { TagClass } from "../models/filter";
 import { UserInputField } from "../models/typeHelpers";
 
-
 interface FilterProps {
   prediction: string;
   selectedMethod: string;
diff --git a/captum/insights/attr_vis/frontend/src/components/FilterContainer.tsx b/captum/insights/attr_vis/frontend/src/components/FilterContainer.tsx
index 50ec05d329..52871ce67a 100644
--- a/captum/insights/attr_vis/frontend/src/components/FilterContainer.tsx
+++ b/captum/insights/attr_vis/frontend/src/components/FilterContainer.tsx
@@ -28,7 +28,10 @@ interface FilterContainerState {
   method_arguments: MethodsArguments;
 }
 
-class FilterContainer extends React.Component<FilterContainerProps, FilterContainerState> {
+class FilterContainer extends React.Component<
+  FilterContainerProps,
+  FilterContainerState
+> {
   constructor(props: FilterContainerProps) {
     super(props);
     const suggested_classes = props.config.classes.map((c, classId) => ({
@@ -47,7 +50,10 @@ class FilterContainer extends React.Component<FilterContainerProps, FilterContai
   handleClassDelete = (classId: number) => {
     const classes = this.state.classes.slice(0);
     const removed_class = classes.splice(classId, 1);
-    const suggested_classes = [...this.state.suggested_classes, ...removed_class];
+    const suggested_classes = [
+      ...this.state.suggested_classes,
+      ...removed_class,
+    ];
     this.setState({ classes, suggested_classes });
   };
 
diff --git a/captum/insights/attr_vis/frontend/src/components/LabelButton.tsx b/captum/insights/attr_vis/frontend/src/components/LabelButton.tsx
index 9c8a9aa113..68589be021 100644
--- a/captum/insights/attr_vis/frontend/src/components/LabelButton.tsx
+++ b/captum/insights/attr_vis/frontend/src/components/LabelButton.tsx
@@ -4,15 +4,20 @@ import styles from "../App.module.css";
 
 interface LabelButtonProps {
   labelIndex: number;
-  instance: number;
+  inputIndex: number;
+  modelIndex: number;
   active: boolean;
-  onTargetClick: (labelIndex: number, instance: number) => void;
+  onTargetClick: (
+    labelIndex: number,
+    inputIndex: number,
+    modelIndex: number
+  ) => void;
 }
 
 function LabelButton(props: React.PropsWithChildren<LabelButtonProps>) {
   const onClick = (e: React.MouseEvent<HTMLButtonElement>) => {
     e.preventDefault();
-    props.onTargetClick(props.labelIndex, props.instance);
+    props.onTargetClick(props.labelIndex, props.inputIndex, props.modelIndex);
   };
 
   return (
diff --git a/captum/insights/attr_vis/frontend/src/components/Visualization.tsx b/captum/insights/attr_vis/frontend/src/components/Visualization.tsx
index c0539a4f36..c70b612611 100644
--- a/captum/insights/attr_vis/frontend/src/components/Visualization.tsx
+++ b/captum/insights/attr_vis/frontend/src/components/Visualization.tsx
@@ -10,14 +10,22 @@ import { VisualizationOutput } from "../models/visualizationOutput";
 interface VisualizationProps {
   data: VisualizationOutput;
   instance: number;
-  onTargetClick: (labelIndex: number, instance: number, callback: () => void) => void;
+  onTargetClick: (
+    labelIndex: number,
+    inputIndex: number,
+    modelIndex: number,
+    callback: () => void
+  ) => void;
 }
 
 interface VisualizationState {
   loading: boolean;
 }
 
-class Visualization extends React.Component<VisualizationProps, VisualizationState> {
+class Visualization extends React.Component<
+  VisualizationProps,
+  VisualizationState
+> {
   constructor(props: VisualizationProps) {
     super(props);
     this.state = {
@@ -25,17 +33,25 @@ class Visualization extends React.Component<VisualizationProps, VisualizationSta
     };
   }
 
-  onTargetClick = (labelIndex: number, instance: number) => {
+  onTargetClick = (
+    labelIndex: number,
+    inputIndex: number,
+    modelIndex: number
+  ) => {
     this.setState({ loading: true });
-    this.props.onTargetClick(labelIndex, instance, () =>
+    this.props.onTargetClick(labelIndex, inputIndex, modelIndex, () =>
       this.setState({ loading: false })
     );
   };
 
+  //TODO: Refactor the visualization table as a <table> instead of columns, in order to have cleaner styling
   render() {
     const data = this.props.data;
+    const isFirstInGroup = this.props.data.model_index == 0;
     console.log("visualization", data);
-    const features = data.feature_outputs.map((f) => <Feature data={f} />);
+    const features = data.feature_outputs.map((f) => (
+      <Feature data={f} hideHeaders={isFirstInGroup} />
+    ));
 
     return (
       <>
@@ -44,22 +60,23 @@ class Visualization extends React.Component<VisualizationProps, VisualizationSta
             <Spinner />
           </div>
         )}
-        <div
-          className={cx({
-            [styles.panel]: true,
-            [styles["panel--long"]]: true,
-            [styles["panel--loading"]]: this.state.loading,
-          })}
-        >
+        {!isFirstInGroup && <div className={styles["model-separator"]} />}
+        <div className={styles["visualization-container"]}>
           <div className={styles["panel__column"]}>
-            <div className={styles["panel__column__title"]}>Predicted</div>
+            {isFirstInGroup && (
+              <div className={styles["panel__column__title"]}>Predicted</div>
+            )}
             <div className={styles["panel__column__body"]}>
+              <div className={styles["model-number"]}>
+                Model {data.model_index + 1}
+              </div>
               {data.predicted.map((p) => (
                 <div className={cx([styles.row, styles["row--padding"]])}>
                   <LabelButton
                     onTargetClick={this.onTargetClick}
                     labelIndex={p.index}
-                    instance={this.props.instance}
+                    inputIndex={this.props.instance}
+                    modelIndex={this.props.data.model_index}
                     active={p.index === data.active_index}
                   >
                     {p.label} ({p.score.toFixed(3)})
@@ -69,13 +86,17 @@ class Visualization extends React.Component<VisualizationProps, VisualizationSta
             </div>
           </div>
           <div className={styles["panel__column"]}>
-            <div className={styles["panel__column__title"]}>Label</div>
+            {isFirstInGroup && (
+              <div className={styles["panel__column__title"]}>Label</div>
+            )}
             <div className={styles["panel__column__body"]}>
+              <div className={styles["model-number-spacer"]} />
               <div className={cx([styles.row, styles["row--padding"]])}>
                 <LabelButton
                   onTargetClick={this.onTargetClick}
                   labelIndex={data.actual.index}
-                  instance={this.props.instance}
+                  inputIndex={this.props.instance}
+                  modelIndex={this.props.data.model_index}
                   active={data.actual.index === data.active_index}
                 >
                   {data.actual.label}
@@ -84,8 +105,11 @@ class Visualization extends React.Component<VisualizationProps, VisualizationSta
             </div>
           </div>
           <div className={styles["panel__column"]}>
-            <div className={styles["panel__column__title"]}>Contribution</div>
+            {isFirstInGroup && (
+              <div className={styles["panel__column__title"]}>Contribution</div>
+            )}
             <div className={styles["panel__column__body"]}>
+              <div className={styles["model-number-spacer"]} />
               <div className={styles["bar-chart"]}>
                 <Contributions feature_outputs={data.feature_outputs} />
               </div>
diff --git a/captum/insights/attr_vis/frontend/src/components/VisualizationGroup.tsx b/captum/insights/attr_vis/frontend/src/components/VisualizationGroup.tsx
new file mode 100644
index 0000000000..023699413f
--- /dev/null
+++ b/captum/insights/attr_vis/frontend/src/components/VisualizationGroup.tsx
@@ -0,0 +1,38 @@
+import React from "react";
+import styles from "../App.module.css";
+import cx from "../utils/cx";
+import Visualization from "../components/Visualization";
+import { VisualizationGroup } from "../models/visualizationOutput";
+
+interface VisualizationGroupDisplayProps {
+  inputIndex: number;
+  data: VisualizationGroup;
+  onTargetClick: (
+    labelIndex: number,
+    inputIndex: number,
+    modelIndex: number,
+    callback: () => void
+  ) => void;
+}
+
+function VisualizationGroupDisplay(props: VisualizationGroupDisplayProps) {
+  return (
+    <div
+      className={cx({
+        [styles.panel]: true,
+        [styles["panel--long"]]: true,
+      })}
+    >
+      {props.data.map((v, i) => (
+        <Visualization
+          data={v}
+          instance={props.inputIndex}
+          onTargetClick={props.onTargetClick}
+          key={i}
+        />
+      ))}
+    </div>
+  );
+}
+
+export default VisualizationGroupDisplay;
diff --git a/captum/insights/attr_vis/frontend/src/components/plotly.module.d.ts b/captum/insights/attr_vis/frontend/src/components/plotly.module.d.ts
index ff6718a43c..d12dc32f0e 100644
--- a/captum/insights/attr_vis/frontend/src/components/plotly.module.d.ts
+++ b/captum/insights/attr_vis/frontend/src/components/plotly.module.d.ts
@@ -1 +1 @@
-declare module 'plotly.js-basic-dist';
\ No newline at end of file
+declare module "plotly.js-basic-dist";
diff --git a/captum/insights/attr_vis/frontend/src/models/filter.ts b/captum/insights/attr_vis/frontend/src/models/filter.ts
index 9044bb2ba4..4002ccd44b 100644
--- a/captum/insights/attr_vis/frontend/src/models/filter.ts
+++ b/captum/insights/attr_vis/frontend/src/models/filter.ts
@@ -1,11 +1,11 @@
 export interface FilterConfig {
-    attribution_method: string;
-    arguments: { [key: string]: any};
-    prediction: string;
-    classes: string[];
+  attribution_method: string;
+  arguments: { [key: string]: any };
+  prediction: string;
+  classes: string[];
 }
 
 export interface TagClass {
-    id: number;
-    name: string;
+  id: number;
+  name: string;
 }
diff --git a/captum/insights/attr_vis/frontend/src/models/insightsConfig.ts b/captum/insights/attr_vis/frontend/src/models/insightsConfig.ts
index 08687ff0c0..54afb547ee 100644
--- a/captum/insights/attr_vis/frontend/src/models/insightsConfig.ts
+++ b/captum/insights/attr_vis/frontend/src/models/insightsConfig.ts
@@ -1,30 +1,30 @@
 export enum ArgumentType {
-    Number = "number",
-    Enum = "enum",
-    String = "string",
-    Boolean = "boolean"
+  Number = "number",
+  Enum = "enum",
+  String = "string",
+  Boolean = "boolean",
 }
 
 export type GenericArgumentConfig<T> = {
-    value: T;
-    limit: T[];
-}
+  value: T;
+  limit: T[];
+};
 
 export type ArgumentConfig =
-    { type: ArgumentType.Number } & GenericArgumentConfig<number> |
-    { type: ArgumentType.Enum } & GenericArgumentConfig<string> |
-    { type: ArgumentType.String } & { value: string } |
-    { type: ArgumentType.Boolean } & { value: boolean }
+  | ({ type: ArgumentType.Number } & GenericArgumentConfig<number>)
+  | ({ type: ArgumentType.Enum } & GenericArgumentConfig<string>)
+  | ({ type: ArgumentType.String } & { value: string })
+  | ({ type: ArgumentType.Boolean } & { value: boolean });
 
 export interface MethodsArguments {
-    [method_name: string]: {
-        [arg_name: string]: ArgumentConfig;
-    }
+  [method_name: string]: {
+    [arg_name: string]: ArgumentConfig;
+  };
 }
 
 export interface InsightsConfig {
-    classes: string[];
-    methods: string[];
-    method_arguments: MethodsArguments;
-    selected_method: string;
-}
\ No newline at end of file
+  classes: string[];
+  methods: string[];
+  method_arguments: MethodsArguments;
+  selected_method: string;
+}
diff --git a/captum/insights/attr_vis/frontend/src/models/typeHelpers.ts b/captum/insights/attr_vis/frontend/src/models/typeHelpers.ts
index ad17002b97..b1386f36ef 100644
--- a/captum/insights/attr_vis/frontend/src/models/typeHelpers.ts
+++ b/captum/insights/attr_vis/frontend/src/models/typeHelpers.ts
@@ -1 +1 @@
-export type UserInputField = HTMLInputElement | HTMLSelectElement;
\ No newline at end of file
+export type UserInputField = HTMLInputElement | HTMLSelectElement;
diff --git a/captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts b/captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts
index 26fb46b94d..923d5e5438 100644
--- a/captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts
+++ b/captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts
@@ -1,30 +1,41 @@
 interface OutputScore {
-    label: string;
-    index: number;
-    score: number;
+  label: string;
+  index: number;
+  score: number;
 }
 
 export enum FeatureType {
-    TEXT = "text",
-    IMAGE = "image",
-    GENERAL = "general",
-    EMPTY = "empty"
+  TEXT = "text",
+  IMAGE = "image",
+  GENERAL = "general",
+  EMPTY = "empty",
 }
 
 type GenericFeatureOutput<F extends FeatureType, T> = {
-    type: F,
-    name: string,
-    contribution: number
+  type: F;
+  name: string;
+  contribution: number;
 } & T;
 
-export type FeatureOutput = GenericFeatureOutput<FeatureType.TEXT, { base: number[], modified: number[] }>
-    | GenericFeatureOutput<FeatureType.IMAGE, { base: string, modified: string }>
-    | GenericFeatureOutput<FeatureType.GENERAL, { base: number[], modified: number[]}>
-    | GenericFeatureOutput<FeatureType.EMPTY, {}>
+export type FeatureOutput =
+  | GenericFeatureOutput<
+      FeatureType.TEXT,
+      { base: number[]; modified: number[] }
+    >
+  | GenericFeatureOutput<FeatureType.IMAGE, { base: string; modified: string }>
+  | GenericFeatureOutput<
+      FeatureType.GENERAL,
+      { base: number[]; modified: number[] }
+    >
+  | GenericFeatureOutput<FeatureType.EMPTY, {}>;
 
 export interface VisualizationOutput {
-    feature_outputs: FeatureOutput[];
-    actual: OutputScore;
-    predicted: OutputScore[];
-    active_index: number;
+  model_index: number;
+  feature_outputs: FeatureOutput[];
+  actual: OutputScore;
+  predicted: OutputScore[];
+  active_index: number;
 }
+
+//When multiple models are compared, visualizations are grouped together
+export type VisualizationGroup = VisualizationOutput[];
diff --git a/captum/insights/attr_vis/frontend/tsconfig.json b/captum/insights/attr_vis/frontend/tsconfig.json
index f2850b7161..af10394b4c 100644
--- a/captum/insights/attr_vis/frontend/tsconfig.json
+++ b/captum/insights/attr_vis/frontend/tsconfig.json
@@ -1,11 +1,7 @@
 {
   "compilerOptions": {
     "target": "es5",
-    "lib": [
-      "dom",
-      "dom.iterable",
-      "esnext"
-    ],
+    "lib": ["dom", "dom.iterable", "esnext"],
     "allowJs": true,
     "skipLibCheck": true,
     "esModuleInterop": true,
@@ -19,7 +15,5 @@
     "noEmit": true,
     "jsx": "react"
   },
-  "include": [
-    "src"
-  ]
+  "include": ["src"]
 }
diff --git a/captum/insights/attr_vis/frontend/widget/src/extension.js b/captum/insights/attr_vis/frontend/widget/src/extension.js
index 3ee405bd23..4a5214baeb 100644
--- a/captum/insights/attr_vis/frontend/widget/src/extension.js
+++ b/captum/insights/attr_vis/frontend/widget/src/extension.js
@@ -14,13 +14,13 @@ if (window.require) {
   window.require.config({
     map: {
       "*": {
-        "jupyter-captum-insights": "nbextensions/jupyter-captum-insights/index"
-      }
-    }
+        "jupyter-captum-insights": "nbextensions/jupyter-captum-insights/index",
+      },
+    },
   });
 }
 
 // Export the required load_ipython_extension
 module.exports = {
-  load_ipython_extension: function() {}
+  load_ipython_extension: function () {},
 };
diff --git a/captum/insights/attr_vis/frontend/widget/webpack.config.js b/captum/insights/attr_vis/frontend/widget/webpack.config.js
index 5937dc8a23..c292774dbc 100644
--- a/captum/insights/attr_vis/frontend/widget/webpack.config.js
+++ b/captum/insights/attr_vis/frontend/widget/webpack.config.js
@@ -1,118 +1,126 @@
-var path = require('path');
-var version = require('../package.json').version;
+var path = require("path");
+var version = require("../package.json").version;
 
 // Custom webpack rules are generally the same for all webpack bundles, hence
 // stored in a separate local variable.
 var rules = [
-    { test: /\.module.css$/, use: [
-      'style-loader',
-        {
-          loader: 'css-loader',
-          options: {
-            modules: true
-          }
+  {
+    test: /\.module.css$/,
+    use: [
+      "style-loader",
+      {
+        loader: "css-loader",
+        options: {
+          modules: true,
         },
-      ]
-    },
-    { test: /^((?!\.module).)*.css$/, use: ['style-loader', 'css-loader'] },
-    {
-      test: /\.(js|ts|tsx)$/,
-      exclude: /node_modules/,
-      loaders: 'babel-loader',
-      options: {
-         presets: ['@babel/preset-react', '@babel/preset-env', '@babel/preset-typescript'],
-         plugins: [
-            "@babel/plugin-proposal-class-properties"
-         ],
       },
-    }
-]
+    ],
+  },
+  { test: /^((?!\.module).)*.css$/, use: ["style-loader", "css-loader"] },
+  {
+    test: /\.(js|ts|tsx)$/,
+    exclude: /node_modules/,
+    loaders: "babel-loader",
+    options: {
+      presets: [
+        "@babel/preset-react",
+        "@babel/preset-env",
+        "@babel/preset-typescript",
+      ],
+      plugins: ["@babel/plugin-proposal-class-properties"],
+    },
+  },
+];
 
-var extensions = ['.js', '.ts', '.tsx']
+var extensions = [".js", ".ts", ".tsx"];
 
 module.exports = [
-    {// Notebook extension
-     //
-     // This bundle only contains the part of the JavaScript that is run on
-     // load of the notebook. This section generally only performs
-     // some configuration for requirejs, and provides the legacy
-     // "load_ipython_extension" function which is required for any notebook
-     // extension.
-     //
-        mode: 'production',
-        entry: './src/extension.js',
-        output: {
-            filename: 'extension.js',
-            path: path.resolve(__dirname, '..', '..', 'widget', 'static'),
-            libraryTarget: 'amd'
-        },
-        resolveLoader: {
-          modules: ['../node_modules'],
-          extensions: extensions
-        },
-        resolve: {
-          modules: ['../node_modules']
-        },
+  {
+    // Notebook extension
+    //
+    // This bundle only contains the part of the JavaScript that is run on
+    // load of the notebook. This section generally only performs
+    // some configuration for requirejs, and provides the legacy
+    // "load_ipython_extension" function which is required for any notebook
+    // extension.
+    //
+    mode: "production",
+    entry: "./src/extension.js",
+    output: {
+      filename: "extension.js",
+      path: path.resolve(__dirname, "..", "..", "widget", "static"),
+      libraryTarget: "amd",
     },
-    {// Bundle for the notebook containing the custom widget views and models
-     //
-     // This bundle contains the implementation for the custom widget views and
-     // custom widget.
-     // It must be an amd module
-     //
-        mode: 'production',
-        entry: './src/index.js',
-        output: {
-            filename: 'index.js',
-            path: path.resolve(__dirname, '..', '..', 'widget', 'static'),
-            libraryTarget: 'amd'
-        },
-        devtool: 'source-map',
-        module: {
-            rules: rules,
-        },
-        resolveLoader: {
-          modules: ['../node_modules']
-        },
-        resolve: {
-          modules: ['../node_modules'],
-          extensions: extensions
-        },
-        externals: ['@jupyter-widgets/base']
+    resolveLoader: {
+      modules: ["../node_modules"],
+      extensions: extensions,
     },
-    {// Embeddable jupyter-captum-insights bundle
-     //
-     // This bundle is generally almost identical to the notebook bundle
-     // containing the custom widget views and models.
-     //
-     // The only difference is in the configuration of the webpack public path
-     // for the static assets.
-     //
-     // It will be automatically distributed by unpkg to work with the static
-     // widget embedder.
-     //
-     // The target bundle is always `dist/index.js`, which is the path required
-     // by the custom widget embedder.
-     //
-        mode: 'production',
-        entry: './src/embed.js',
-        output: {
-            filename: 'index.js',
-            path: path.resolve(__dirname, '..', '..', 'widget', 'dist'),
-            libraryTarget: 'amd',
-            publicPath: 'https://unpkg.com/jupyter-captum-insights@' + version + '/dist/'
-        },
-        devtool: 'source-map',
-        module: {
-            rules: rules
-        },
-        resolveLoader: {
-          modules: ['../node_modules']
-        },
-        resolve: {
-          modules: ['../node_modules'],
-          extensions: extensions
-        },
-        externals: ['@jupyter-widgets/base']
-    }
+    resolve: {
+      modules: ["../node_modules"],
+    },
+  },
+  {
+    // Bundle for the notebook containing the custom widget views and models
+    //
+    // This bundle contains the implementation for the custom widget views and
+    // custom widget.
+    // It must be an amd module
+    //
+    mode: "production",
+    entry: "./src/index.js",
+    output: {
+      filename: "index.js",
+      path: path.resolve(__dirname, "..", "..", "widget", "static"),
+      libraryTarget: "amd",
+    },
+    devtool: "source-map",
+    module: {
+      rules: rules,
+    },
+    resolveLoader: {
+      modules: ["../node_modules"],
+    },
+    resolve: {
+      modules: ["../node_modules"],
+      extensions: extensions,
+    },
+    externals: ["@jupyter-widgets/base"],
+  },
+  {
+    // Embeddable jupyter-captum-insights bundle
+    //
+    // This bundle is generally almost identical to the notebook bundle
+    // containing the custom widget views and models.
+    //
+    // The only difference is in the configuration of the webpack public path
+    // for the static assets.
+    //
+    // It will be automatically distributed by unpkg to work with the static
+    // widget embedder.
+    //
+    // The target bundle is always `dist/index.js`, which is the path required
+    // by the custom widget embedder.
+    //
+    mode: "production",
+    entry: "./src/embed.js",
+    output: {
+      filename: "index.js",
+      path: path.resolve(__dirname, "..", "..", "widget", "dist"),
+      libraryTarget: "amd",
+      publicPath:
+        "https://unpkg.com/jupyter-captum-insights@" + version + "/dist/",
+    },
+    devtool: "source-map",
+    module: {
+      rules: rules,
+    },
+    resolveLoader: {
+      modules: ["../node_modules"],
+    },
+    resolve: {
+      modules: ["../node_modules"],
+      extensions: extensions,
+    },
+    externals: ["@jupyter-widgets/base"],
+  },
 ];
diff --git a/captum/insights/attr_vis/server.py b/captum/insights/attr_vis/server.py
index 15cee3906d..17bfd8e34f 100644
--- a/captum/insights/attr_vis/server.py
+++ b/captum/insights/attr_vis/server.py
@@ -42,7 +42,9 @@ def attribute():
     r = request.get_json(force=True)
     return jsonify(
         namedtuple_to_dict(
-            visualizer._calculate_attribution_from_cache(r["instance"], r["labelIndex"])
+            visualizer._calculate_attribution_from_cache(
+                r["inputIndex"], r["modelIndex"], r["labelIndex"]
+            )
         )
     )
 
diff --git a/tests/insights/test_contribution.py b/tests/insights/test_contribution.py
index b9a4b01e71..0f04b59ddc 100644
--- a/tests/insights/test_contribution.py
+++ b/tests/insights/test_contribution.py
@@ -167,7 +167,7 @@ def test_one_feature(self):
         outputs = visualizer.visualize()
 
         for output in outputs:
-            total_contrib = sum(abs(f.contribution) for f in output.feature_outputs)
+            total_contrib = sum(abs(f.contribution) for f in output[0].feature_outputs)
             self.assertAlmostEqual(total_contrib, 1.0, places=6)
 
     def test_multi_features(self):
@@ -210,7 +210,7 @@ def test_multi_features(self):
         outputs = visualizer.visualize()
 
         for output in outputs:
-            total_contrib = sum(abs(f.contribution) for f in output.feature_outputs)
+            total_contrib = sum(abs(f.contribution) for f in output[0].feature_outputs)
             self.assertAlmostEqual(total_contrib, 1.0, places=6)
 
     # TODO: add test for multiple models (related to TODO in captum/insights/api.py)