Skip to content

Commit e492cec

Browse files
craymichaelfacebook-github-bot
authored andcommitted
Fix mypy issue in visualization.py
Summary: visualize_image_attr_multiple can return a List[Axes], adds proper annotations to satisfy mypy Why casting is necessary: numpy/numpy#24738 https://github.com/matplotlib/matplotlib/blob/v3.9.2/lib/matplotlib/pyplot.py#L1583C41-L1584C1 Differential Revision: D64998799
1 parent f4b7a1f commit e492cec

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

captum/attr/_utils/visualization.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# pyre-strict
44
import warnings
55
from enum import Enum
6-
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
6+
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
77

88
import matplotlib
99

@@ -444,7 +444,7 @@ def visualize_image_attr_multiple(
444444
fig_size: Tuple[int, int] = (8, 6),
445445
use_pyplot: bool = True,
446446
**kwargs: Any,
447-
) -> Tuple[Figure, Axes]:
447+
) -> Tuple[Figure, Union[Axes, List[Axes]]]:
448448
r"""
449449
Visualizes attribution using multiple visualization methods displayed
450450
in a 1 x k grid, where k is the number of desired visualizations.
@@ -516,15 +516,19 @@ def visualize_image_attr_multiple(
516516
plt_fig = plt.figure(figsize=fig_size)
517517
else:
518518
plt_fig = Figure(figsize=fig_size)
519-
plt_axis = plt_fig.subplots(1, len(methods))
519+
plt_axis_np = plt_fig.subplots(1, len(methods), squeeze=True)
520520

521+
plt_axis: Union[Axes, List[Axes]]
521522
plt_axis_list: List[Axes] = []
522523
# When visualizing one
523524
if len(methods) == 1:
524-
plt_axis_list = [plt_axis] # type: ignore
525+
plt_axis = cast(Axes, plt_axis_np)
526+
plt_axis_list = [plt_axis]
525527
# Figure.subplots returns Axes or array of Axes
526528
else:
527-
plt_axis_list = plt_axis # type: ignore
529+
# https://github.com/numpy/numpy/issues/24738
530+
plt_axis = cast(List[Axes], cast(npt.NDArray, plt_axis_np).tolist())
531+
plt_axis_list = plt_axis
528532
# Figure.subplots returns Axes or array of Axes
529533

530534
for i in range(len(methods)):

0 commit comments

Comments
 (0)