@@ -47,8 +47,9 @@ class FeatureAblation(PerturbationAttribution):
47
47
first dimension (i.e. a feature mask requires to be applied to all inputs).
48
48
"""
49
49
50
- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
51
- def __init__ (self , forward_func : Callable ) -> None :
50
+ def __init__ (
51
+ self , forward_func : Callable [..., Union [int , float , Tensor , Future [Tensor ]]]
52
+ ) -> None :
52
53
r"""
53
54
Args:
54
55
@@ -74,8 +75,7 @@ def attribute(
74
75
inputs : TensorOrTupleOfTensorsGeneric ,
75
76
baselines : BaselineType = None ,
76
77
target : TargetType = None ,
77
- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
78
- additional_forward_args : Any = None ,
78
+ additional_forward_args : object = None ,
79
79
feature_mask : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
80
80
perturbations_per_eval : int = 1 ,
81
81
show_progress : bool = False ,
@@ -261,17 +261,13 @@ def attribute(
261
261
"""
262
262
# Keeps track whether original input is a tuple or not before
263
263
# converting it into a tuple.
264
- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
265
- # `TensorOrTupleOfTensorsGeneric`.
266
264
is_inputs_tuple = _is_tuple (inputs )
267
265
268
266
formatted_inputs , baselines = _format_input_baseline (inputs , baselines )
269
267
formatted_additional_forward_args = _format_additional_forward_args (
270
268
additional_forward_args
271
269
)
272
270
num_examples = formatted_inputs [0 ].shape [0 ]
273
- # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
274
- # `TensorOrTupleOfTensorsGeneric`.
275
271
formatted_feature_mask = _format_feature_mask (feature_mask , formatted_inputs )
276
272
277
273
assert (
@@ -384,8 +380,6 @@ def attribute(
384
380
# pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <:
385
381
# [Tensor, typing.Tuple[Tensor, ...]]]`
386
382
# but got `Union[Tensor, typing.Tuple[Tensor, ...]]`.
387
- # pyre-fixme[6]: In call `FeatureAblation._generate_result`,
388
- # for 3rd positional argument, expected `bool` but got `Literal[]`.
389
383
return self ._generate_result (total_attrib , weights , is_inputs_tuple ) # type: ignore # noqa: E501 line too long
390
384
391
385
def _initial_eval_to_processed_initial_eval_fut (
@@ -414,8 +408,7 @@ def attribute_future(
414
408
inputs : TensorOrTupleOfTensorsGeneric ,
415
409
baselines : BaselineType = None ,
416
410
target : TargetType = None ,
417
- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
418
- additional_forward_args : Any = None ,
411
+ additional_forward_args : object = None ,
419
412
feature_mask : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
420
413
perturbations_per_eval : int = 1 ,
421
414
show_progress : bool = False ,
@@ -428,8 +421,6 @@ def attribute_future(
428
421
429
422
# Keeps track whether original input is a tuple or not before
430
423
# converting it into a tuple.
431
- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
432
- # `TensorOrTupleOfTensorsGeneric`.
433
424
is_inputs_tuple = _is_tuple (inputs )
434
425
formatted_inputs , baselines = _format_input_baseline (inputs , baselines )
435
426
formatted_additional_forward_args = _format_additional_forward_args (
@@ -660,13 +651,11 @@ def _eval_fut_to_ablated_out_fut(
660
651
) from e
661
652
return result
662
653
663
- # pyre-fixme[3]: Return type must be specified as type that does not contain `Any`
664
654
def _ith_input_ablation_generator (
665
655
self ,
666
656
i : int ,
667
657
inputs : TensorOrTupleOfTensorsGeneric ,
668
- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
669
- additional_args : Any ,
658
+ additional_args : object ,
670
659
target : TargetType ,
671
660
baselines : BaselineType ,
672
661
input_mask : Union [None , Tensor , Tuple [Tensor , ...]],
@@ -675,7 +664,7 @@ def _ith_input_ablation_generator(
675
664
) -> Generator [
676
665
Tuple [
677
666
Tuple [Tensor , ...],
678
- Any ,
667
+ object ,
679
668
TargetType ,
680
669
Tensor ,
681
670
],
@@ -705,10 +694,9 @@ def _ith_input_ablation_generator(
705
694
perturbations_per_eval = min (perturbations_per_eval , num_features )
706
695
baseline = baselines [i ] if isinstance (baselines , tuple ) else baselines
707
696
if isinstance (baseline , torch .Tensor ):
708
- # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]`
709
- # and `Size`.
710
- baseline = baseline .reshape ((1 ,) + baseline .shape )
697
+ baseline = baseline .reshape ((1 ,) + tuple (baseline .shape ))
711
698
699
+ additional_args_repeated : object
712
700
if perturbations_per_eval > 1 :
713
701
# Repeat features and additional args for batch size.
714
702
all_features_repeated = [
@@ -727,6 +715,7 @@ def _ith_input_ablation_generator(
727
715
target_repeated = target
728
716
729
717
num_features_processed = min_feature
718
+ current_additional_args : object
730
719
while num_features_processed < num_features :
731
720
current_num_ablated_features = min (
732
721
perturbations_per_eval , num_features - num_features_processed
@@ -762,9 +751,7 @@ def _ith_input_ablation_generator(
762
751
# dimension of this tensor.
763
752
current_reshaped = current_features [i ].reshape (
764
753
(current_num_ablated_features , - 1 )
765
- # pyre-fixme[58]: `+` is not supported for operand types
766
- # `Tuple[int, int]` and `Size`.
767
- + current_features [i ].shape [1 :]
754
+ + tuple (current_features [i ].shape [1 :])
768
755
)
769
756
770
757
ablated_features , current_mask = self ._construct_ablated_input (
@@ -780,10 +767,7 @@ def _ith_input_ablation_generator(
780
767
# (current_num_ablated_features * num_examples, inputs[i].shape[1:]),
781
768
# which can be provided to the model as input.
782
769
current_features [i ] = ablated_features .reshape (
783
- (- 1 ,)
784
- # pyre-fixme[58]: `+` is not supported for operand types
785
- # `Tuple[int]` and `Size`.
786
- + ablated_features .shape [2 :]
770
+ (- 1 ,) + tuple (ablated_features .shape [2 :])
787
771
)
788
772
yield tuple (
789
773
current_features
@@ -818,9 +802,7 @@ def _construct_ablated_input(
818
802
thus counted towards ablations for that feature) and 0s otherwise.
819
803
"""
820
804
current_mask = torch .stack (
821
- # pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
822
- # Tuple[Tensor, ...]]` but got `List[Union[bool, Tensor]]`.
823
- [input_mask == j for j in range (start_feature , end_feature )], # type: ignore # noqa: E501 line too long
805
+ cast (List [Tensor ], [input_mask == j for j in range (start_feature , end_feature )]), # type: ignore # noqa: E501 line too long
824
806
dim = 0 ,
825
807
).long ()
826
808
current_mask = current_mask .to (expanded_input .device )
0 commit comments