Skip to content

Commit d0b6196

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix sensitivity pyre fixme issues (#1599)
Summary: Pull Request resolved: #1599 Fixing unresolved pyre fixme issues in corresponding file Differential Revision: D76737670
1 parent e5f0074 commit d0b6196

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

captum/metrics/_core/sensitivity.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,13 @@ def default_perturb_func(
4747
original inputs.
4848
4949
"""
50-
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
51-
# `Tuple[Tensor, ...]`.
52-
inputs = _format_tensor_into_tuples(inputs)
50+
inputs_tuple = _format_tensor_into_tuples(inputs)
5351
perturbed_input = tuple(
5452
input
5553
+ torch.FloatTensor(input.size()) # type: ignore
5654
.uniform_(-perturb_radius, perturb_radius)
5755
.to(input.device)
58-
for input in inputs
56+
for input in inputs_tuple
5957
)
6058
return perturbed_input
6159

@@ -64,8 +62,9 @@ def default_perturb_func(
6462
def sensitivity_max(
6563
explanation_func: Callable[..., TensorOrTupleOfTensorsGeneric],
6664
inputs: TensorOrTupleOfTensorsGeneric,
67-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
68-
perturb_func: Callable = default_perturb_func,
65+
perturb_func: Callable[
66+
..., Union[Tensor, Tuple[Tensor, ...]]
67+
] = default_perturb_func,
6968
perturb_radius: float = 0.02,
7069
n_perturb_samples: int = 10,
7170
norm_ord: str = "fro",
@@ -202,7 +201,7 @@ def sensitivity_max(
202201

203202
def _generate_perturbations(
204203
current_n_perturb_samples: int,
205-
) -> TensorOrTupleOfTensorsGeneric:
204+
) -> Union[Tensor, Tuple[Tensor, ...]]:
206205
r"""
207206
The perturbations are generated for each example
208207
`current_n_perturb_samples` times.
@@ -228,8 +227,7 @@ def max_values(input_tnsr: Tensor) -> Tensor:
228227
return torch.max(input_tnsr, dim=1).values # type: ignore
229228

230229
kwarg_expanded_for = None
231-
# pyre-fixme[33]: Given annotation cannot be `Any`.
232-
kwargs_copy: Any = None
230+
kwargs_copy: object = None
233231

234232
def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
235233
inputs_perturbed = _generate_perturbations(current_n_perturb_samples)
@@ -254,9 +252,7 @@ def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
254252
)
255253
if (
256254
isinstance(baselines[0], Tensor)
257-
# pyre-fixme[16]: Item `float` of `Union[float, int, Tensor]`
258-
# has no attribute `shape`.
259-
and baselines[0].shape == inputs[0].shape
255+
and cast(Tensor, baselines[0]).shape == inputs[0].shape
260256
):
261257
_expand_and_update_baselines(
262258
cast(Tuple[Tensor, ...], inputs),

0 commit comments

Comments
 (0)