Skip to content

Commit f33198b

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix NoiseTunnel pyre fixme issues
Summary: Fixing unresolved pyre fixme issues in corresponding file Differential Revision: D76738229
1 parent 20478cf commit f33198b

File tree

1 file changed

+15
-21
lines changed

1 file changed

+15
-21
lines changed

captum/attr/_core/noise_tunnel.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pyre-strict
44
from enum import Enum
5-
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
5+
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
66

77
import torch
88
from captum._utils.common import (
@@ -14,7 +14,6 @@
1414
_format_tensor_into_tuples,
1515
_is_tuple,
1616
)
17-
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
1817
from captum.attr._utils.attribution import Attribution, GradientAttribution
1918
from captum.attr._utils.common import _validate_noise_tunnel_type
2019
from captum.log import log_usage
@@ -91,12 +90,10 @@ def attribute(
9190
draw_baseline_from_distrib: bool = False,
9291
**kwargs: Any,
9392
) -> Union[
94-
Union[
95-
Tensor,
96-
Tuple[Tensor, Tensor],
97-
Tuple[Tensor, ...],
98-
Tuple[Tuple[Tensor, ...], Tensor],
99-
]
93+
Tensor,
94+
Tuple[Tensor, Tensor],
95+
Tuple[Tensor, ...],
96+
Tuple[Tuple[Tensor, ...], Tensor],
10097
]:
10198
r"""
10299
Args:
@@ -298,8 +295,7 @@ def attribute(
298295
delta,
299296
)
300297

301-
# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
302-
def attribute_future(self) -> Callable:
298+
def attribute_future(self) -> None:
303299
r"""
304300
This method is not implemented for NoiseTunnel.
305301
"""
@@ -490,11 +486,11 @@ def _apply_checks_and_return_attributions(
490486
is_attrib_tuple: bool,
491487
return_convergence_delta: bool,
492488
delta: Union[None, Tensor],
493-
# pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <:
494-
# [torch._tensor.Tensor, typing.Tuple[torch._tensor.Tensor, ...]]]`
495-
# isn't present in the function's parameters.
496489
) -> Union[
497-
TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
490+
Tensor,
491+
Tuple[Tensor, Tensor],
492+
Tuple[Tensor, ...],
493+
Tuple[Tuple[Tensor, ...], Tensor],
498494
]:
499495
attributions_tuple = _format_output(is_attrib_tuple, attributions)
500496

@@ -503,17 +499,15 @@ def _apply_checks_and_return_attributions(
503499
if self.is_delta_supported and return_convergence_delta
504500
else attributions_tuple
505501
)
506-
ret = cast(
507-
# pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <:
508-
# [torch._tensor.Tensor, typing.Tuple[torch._tensor.Tensor, ...]]]`
509-
# isn't present in the function's parameters.
502+
return cast(
510503
Union[
511-
TensorOrTupleOfTensorsGeneric,
512-
Tuple[TensorOrTupleOfTensorsGeneric, Tensor],
504+
Tensor,
505+
Tuple[Tensor, Tensor],
506+
Tuple[Tensor, ...],
507+
Tuple[Tuple[Tensor, ...], Tensor],
513508
],
514509
ret,
515510
)
516-
return ret
517511

518512
def has_convergence_delta(self) -> bool:
519513
return self.is_delta_supported

0 commit comments

Comments
 (0)