2
2
3
3
# pyre-strict
4
4
from enum import Enum
5
- from typing import Any , Callable , cast , Dict , List , Optional , Sequence , Tuple , Union
5
+ from typing import Any , cast , Dict , List , Optional , Sequence , Tuple , Union
6
6
7
7
import torch
8
8
from captum ._utils .common import (
14
14
_format_tensor_into_tuples ,
15
15
_is_tuple ,
16
16
)
17
- from captum ._utils .typing import TensorOrTupleOfTensorsGeneric
18
17
from captum .attr ._utils .attribution import Attribution , GradientAttribution
19
18
from captum .attr ._utils .common import _validate_noise_tunnel_type
20
19
from captum .log import log_usage
@@ -91,12 +90,10 @@ def attribute(
91
90
draw_baseline_from_distrib : bool = False ,
92
91
** kwargs : Any ,
93
92
) -> Union [
94
- Union [
95
- Tensor ,
96
- Tuple [Tensor , Tensor ],
97
- Tuple [Tensor , ...],
98
- Tuple [Tuple [Tensor , ...], Tensor ],
99
- ]
93
+ Tensor ,
94
+ Tuple [Tensor , Tensor ],
95
+ Tuple [Tensor , ...],
96
+ Tuple [Tuple [Tensor , ...], Tensor ],
100
97
]:
101
98
r"""
102
99
Args:
@@ -298,8 +295,7 @@ def attribute(
298
295
delta ,
299
296
)
300
297
301
- # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
302
- def attribute_future (self ) -> Callable :
298
+ def attribute_future (self ) -> None :
303
299
r"""
304
300
This method is not implemented for NoiseTunnel.
305
301
"""
@@ -490,11 +486,11 @@ def _apply_checks_and_return_attributions(
490
486
is_attrib_tuple : bool ,
491
487
return_convergence_delta : bool ,
492
488
delta : Union [None , Tensor ],
493
- # pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <:
494
- # [torch._tensor.Tensor, typing.Tuple[torch._tensor.Tensor, ...]]]`
495
- # isn't present in the function's parameters.
496
489
) -> Union [
497
- TensorOrTupleOfTensorsGeneric , Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]
490
+ Tensor ,
491
+ Tuple [Tensor , Tensor ],
492
+ Tuple [Tensor , ...],
493
+ Tuple [Tuple [Tensor , ...], Tensor ],
498
494
]:
499
495
attributions_tuple = _format_output (is_attrib_tuple , attributions )
500
496
@@ -503,17 +499,15 @@ def _apply_checks_and_return_attributions(
503
499
if self .is_delta_supported and return_convergence_delta
504
500
else attributions_tuple
505
501
)
506
- ret = cast (
507
- # pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <:
508
- # [torch._tensor.Tensor, typing.Tuple[torch._tensor.Tensor, ...]]]`
509
- # isn't present in the function's parameters.
502
+ return cast (
510
503
Union [
511
- TensorOrTupleOfTensorsGeneric ,
512
- Tuple [TensorOrTupleOfTensorsGeneric , Tensor ],
504
+ Tensor ,
505
+ Tuple [Tensor , Tensor ],
506
+ Tuple [Tensor , ...],
507
+ Tuple [Tuple [Tensor , ...], Tensor ],
513
508
],
514
509
ret ,
515
510
)
516
- return ret
517
511
518
512
def has_convergence_delta (self ) -> bool :
519
513
return self .is_delta_supported
0 commit comments