@@ -37,8 +37,7 @@ class Occlusion(FeatureAblation):
37
37
/tensorflow/methods.py#L401
38
38
"""
39
39
40
- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
41
- def __init__ (self , forward_func : Callable ) -> None :
40
+ def __init__ (self , forward_func : Callable [..., Tensor ]) -> None :
42
41
r"""
43
42
Args:
44
43
@@ -58,8 +57,7 @@ def attribute( # type: ignore
58
57
] = None ,
59
58
baselines : BaselineType = None ,
60
59
target : TargetType = None ,
61
- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
62
- additional_forward_args : Any = None ,
60
+ additional_forward_args : object = None ,
63
61
perturbations_per_eval : int = 1 ,
64
62
show_progress : bool = False ,
65
63
) -> TensorOrTupleOfTensorsGeneric :
@@ -377,9 +375,7 @@ def _occlusion_mask(
377
375
padded_tensor = torch .nn .functional .pad (
378
376
sliding_window_tsr , tuple (pad_values ) # type: ignore
379
377
)
380
- # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` and
381
- # `Size`.
382
- return padded_tensor .reshape ((1 ,) + padded_tensor .shape )
378
+ return padded_tensor .reshape ((1 ,) + tuple (padded_tensor .shape ))
383
379
384
380
def _get_feature_range_and_mask (
385
381
self , input : Tensor , input_mask : Optional [Tensor ], ** kwargs : Any
@@ -389,8 +385,7 @@ def _get_feature_range_and_mask(
389
385
390
386
def _get_feature_counts (
391
387
self ,
392
- # pyre-fixme[2]: Parameter must be annotated.
393
- inputs ,
388
+ inputs : TensorOrTupleOfTensorsGeneric ,
394
389
feature_mask : Tuple [Tensor , ...],
395
390
** kwargs : Any ,
396
391
) -> Tuple [int , ...]:
0 commit comments