@@ -47,15 +47,13 @@ def default_perturb_func(
47
47
original inputs.
48
48
49
49
"""
50
- # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
51
- # `Tuple[Tensor, ...]`.
52
- inputs = _format_tensor_into_tuples (inputs )
50
+ inputs_tuple = _format_tensor_into_tuples (inputs )
53
51
perturbed_input = tuple (
54
52
input
55
53
+ torch .FloatTensor (input .size ()) # type: ignore
56
54
.uniform_ (- perturb_radius , perturb_radius )
57
55
.to (input .device )
58
- for input in inputs
56
+ for input in inputs_tuple
59
57
)
60
58
return perturbed_input
61
59
@@ -64,8 +62,9 @@ def default_perturb_func(
64
62
def sensitivity_max (
65
63
explanation_func : Callable [..., TensorOrTupleOfTensorsGeneric ],
66
64
inputs : TensorOrTupleOfTensorsGeneric ,
67
- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
68
- perturb_func : Callable = default_perturb_func ,
65
+ perturb_func : Callable [
66
+ ..., Union [Tensor , Tuple [Tensor , ...]]
67
+ ] = default_perturb_func ,
69
68
perturb_radius : float = 0.02 ,
70
69
n_perturb_samples : int = 10 ,
71
70
norm_ord : str = "fro" ,
@@ -202,7 +201,7 @@ def sensitivity_max(
202
201
203
202
def _generate_perturbations (
204
203
current_n_perturb_samples : int ,
205
- ) -> TensorOrTupleOfTensorsGeneric :
204
+ ) -> Union [ Tensor , Tuple [ Tensor , ...]] :
206
205
r"""
207
206
The perturbations are generated for each example
208
207
`current_n_perturb_samples` times.
@@ -228,8 +227,7 @@ def max_values(input_tnsr: Tensor) -> Tensor:
228
227
return torch .max (input_tnsr , dim = 1 ).values # type: ignore
229
228
230
229
kwarg_expanded_for = None
231
- # pyre-fixme[33]: Given annotation cannot be `Any`.
232
- kwargs_copy : Any = None
230
+ kwargs_copy : object = None
233
231
234
232
def _next_sensitivity_max (current_n_perturb_samples : int ) -> Tensor :
235
233
inputs_perturbed = _generate_perturbations (current_n_perturb_samples )
@@ -254,9 +252,7 @@ def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
254
252
)
255
253
if (
256
254
isinstance (baselines [0 ], Tensor )
257
- # pyre-fixme[16]: Item `float` of `Union[float, int, Tensor]`
258
- # has no attribute `shape`.
259
- and baselines [0 ].shape == inputs [0 ].shape
255
+ and cast (Tensor , baselines [0 ]).shape == inputs [0 ].shape
260
256
):
261
257
_expand_and_update_baselines (
262
258
cast (Tuple [Tensor , ...], inputs ),
0 commit comments