|
2 | 2 |
|
3 | 3 | # pyre-strict
|
4 | 4 | import typing
|
5 |
| -from typing import Callable, List, Literal, Optional, Tuple, Union |
| 5 | +from typing import Callable, cast, List, Literal, Optional, Tuple, Union |
6 | 6 |
|
7 | 7 | import torch
|
8 | 8 | from captum._utils.common import (
|
@@ -301,16 +301,18 @@ def attribute( # type: ignore
|
301 | 301 | additional_forward_args=additional_forward_args,
|
302 | 302 | target=target,
|
303 | 303 | )
|
304 |
| - # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... |
305 |
| - return _format_output(is_inputs_tuple, attributions), delta |
306 |
| - # pyre-fixme[7]: Expected |
307 |
| - # `Union[Tuple[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, |
308 |
| - # typing.Tuple[Tensor, ...]]], Tensor], Variable[TensorOrTupleOfTensorsGeneric |
309 |
| - # <: [Tensor, typing.Tuple[Tensor, ...]]]]` but got `Tuple[Tensor, ...]`. |
310 |
| - return _format_output(is_inputs_tuple, attributions) |
311 |
| - |
312 |
| - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. |
313 |
| - def attribute_future(self) -> Callable: |
| 304 | + return ( |
| 305 | + cast( |
| 306 | + TensorOrTupleOfTensorsGeneric, |
| 307 | + _format_output(is_inputs_tuple, attributions), |
| 308 | + ), |
| 309 | + delta, |
| 310 | + ) |
| 311 | + return cast( |
| 312 | + TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, attributions) |
| 313 | + ) |
| 314 | + |
| 315 | + def attribute_future(self) -> None: |
314 | 316 | r"""
|
315 | 317 | This method is not implemented for IntegratedGradients.
|
316 | 318 | """
|
|
0 commit comments