|
3 | 3 | # pyre-strict
|
4 | 4 | import warnings
|
5 | 5 | from enum import Enum
|
6 |
| -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union |
| 6 | +from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union |
7 | 7 |
|
8 | 8 | import matplotlib
|
9 | 9 |
|
@@ -444,7 +444,7 @@ def visualize_image_attr_multiple(
|
444 | 444 | fig_size: Tuple[int, int] = (8, 6),
|
445 | 445 | use_pyplot: bool = True,
|
446 | 446 | **kwargs: Any,
|
447 |
| -) -> Tuple[Figure, Axes]: |
| 447 | +) -> Tuple[Figure, Union[Axes, List[Axes]]]: |
448 | 448 | r"""
|
449 | 449 | Visualizes attribution using multiple visualization methods displayed
|
450 | 450 | in a 1 x k grid, where k is the number of desired visualizations.
|
@@ -516,15 +516,19 @@ def visualize_image_attr_multiple(
|
516 | 516 | plt_fig = plt.figure(figsize=fig_size)
|
517 | 517 | else:
|
518 | 518 | plt_fig = Figure(figsize=fig_size)
|
519 |
| - plt_axis = plt_fig.subplots(1, len(methods)) |
| 519 | + plt_axis_np = plt_fig.subplots(1, len(methods), squeeze=True) |
520 | 520 |
|
| 521 | + plt_axis: Union[Axes, List[Axes]] |
521 | 522 | plt_axis_list: List[Axes] = []
|
522 | 523 | # When visualizing one
|
523 | 524 | if len(methods) == 1:
|
524 |
| - plt_axis_list = [plt_axis] # type: ignore |
| 525 | + plt_axis = cast(Axes, plt_axis_np) |
| 526 | + plt_axis_list = [plt_axis] |
525 | 527 | # Figure.subplots returns Axes or array of Axes
|
526 | 528 | else:
|
527 |
| - plt_axis_list = plt_axis # type: ignore |
| 529 | + # https://github.com/numpy/numpy/issues/24738 |
| 530 | + plt_axis = cast(List[Axes], cast(npt.NDArray, plt_axis_np).tolist()) |
| 531 | + plt_axis_list = plt_axis |
528 | 532 | # Figure.subplots returns Axes or array of Axes
|
529 | 533 |
|
530 | 534 | for i in range(len(methods)):
|
|
0 commit comments