Skip to content

Commit 97e2195

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix pyre errors in Feature Ablation (pytorch#1392)
Summary: Initial work on fixing Pyre errors in Feature Ablation Reviewed By: jjuncho Differential Revision: D64677337
1 parent 38592ae commit 97e2195

File tree

1 file changed

+13
-31
lines changed

1 file changed

+13
-31
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ class FeatureAblation(PerturbationAttribution):
4747
first dimension (i.e. a feature mask requires to be applied to all inputs).
4848
"""
4949

50-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
51-
def __init__(self, forward_func: Callable) -> None:
50+
def __init__(
51+
self, forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]]
52+
) -> None:
5253
r"""
5354
Args:
5455
@@ -74,8 +75,7 @@ def attribute(
7475
inputs: TensorOrTupleOfTensorsGeneric,
7576
baselines: BaselineType = None,
7677
target: TargetType = None,
77-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
78-
additional_forward_args: Any = None,
78+
additional_forward_args: object = None,
7979
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
8080
perturbations_per_eval: int = 1,
8181
show_progress: bool = False,
@@ -261,17 +261,13 @@ def attribute(
261261
"""
262262
# Keeps track whether original input is a tuple or not before
263263
# converting it into a tuple.
264-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
265-
# `TensorOrTupleOfTensorsGeneric`.
266264
is_inputs_tuple = _is_tuple(inputs)
267265

268266
formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
269267
formatted_additional_forward_args = _format_additional_forward_args(
270268
additional_forward_args
271269
)
272270
num_examples = formatted_inputs[0].shape[0]
273-
# pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
274-
# `TensorOrTupleOfTensorsGeneric`.
275271
formatted_feature_mask = _format_feature_mask(feature_mask, formatted_inputs)
276272

277273
assert (
@@ -384,8 +380,6 @@ def attribute(
384380
# pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <:
385381
# [Tensor, typing.Tuple[Tensor, ...]]]`
386382
# but got `Union[Tensor, typing.Tuple[Tensor, ...]]`.
387-
# pyre-fixme[6]: In call `FeatureAblation._generate_result`,
388-
# for 3rd positional argument, expected `bool` but got `Literal[]`.
389383
return self._generate_result(total_attrib, weights, is_inputs_tuple) # type: ignore # noqa: E501 line too long
390384

391385
def _initial_eval_to_processed_initial_eval_fut(
@@ -414,8 +408,7 @@ def attribute_future(
414408
inputs: TensorOrTupleOfTensorsGeneric,
415409
baselines: BaselineType = None,
416410
target: TargetType = None,
417-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
418-
additional_forward_args: Any = None,
411+
additional_forward_args: object = None,
419412
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
420413
perturbations_per_eval: int = 1,
421414
show_progress: bool = False,
@@ -428,8 +421,6 @@ def attribute_future(
428421

429422
# Keeps track whether original input is a tuple or not before
430423
# converting it into a tuple.
431-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
432-
# `TensorOrTupleOfTensorsGeneric`.
433424
is_inputs_tuple = _is_tuple(inputs)
434425
formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
435426
formatted_additional_forward_args = _format_additional_forward_args(
@@ -660,13 +651,11 @@ def _eval_fut_to_ablated_out_fut(
660651
) from e
661652
return result
662653

663-
# pyre-fixme[3]: Return type must be specified as type that does not contain `Any`
664654
def _ith_input_ablation_generator(
665655
self,
666656
i: int,
667657
inputs: TensorOrTupleOfTensorsGeneric,
668-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
669-
additional_args: Any,
658+
additional_args: object,
670659
target: TargetType,
671660
baselines: BaselineType,
672661
input_mask: Union[None, Tensor, Tuple[Tensor, ...]],
@@ -675,7 +664,7 @@ def _ith_input_ablation_generator(
675664
) -> Generator[
676665
Tuple[
677666
Tuple[Tensor, ...],
678-
Any,
667+
object,
679668
TargetType,
680669
Tensor,
681670
],
@@ -705,10 +694,9 @@ def _ith_input_ablation_generator(
705694
perturbations_per_eval = min(perturbations_per_eval, num_features)
706695
baseline = baselines[i] if isinstance(baselines, tuple) else baselines
707696
if isinstance(baseline, torch.Tensor):
708-
# pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]`
709-
# and `Size`.
710-
baseline = baseline.reshape((1,) + baseline.shape)
697+
baseline = baseline.reshape((1,) + tuple(baseline.shape))
711698

699+
additional_args_repeated: object
712700
if perturbations_per_eval > 1:
713701
# Repeat features and additional args for batch size.
714702
all_features_repeated = [
@@ -727,6 +715,7 @@ def _ith_input_ablation_generator(
727715
target_repeated = target
728716

729717
num_features_processed = min_feature
718+
current_additional_args: object
730719
while num_features_processed < num_features:
731720
current_num_ablated_features = min(
732721
perturbations_per_eval, num_features - num_features_processed
@@ -762,9 +751,7 @@ def _ith_input_ablation_generator(
762751
# dimension of this tensor.
763752
current_reshaped = current_features[i].reshape(
764753
(current_num_ablated_features, -1)
765-
# pyre-fixme[58]: `+` is not supported for operand types
766-
# `Tuple[int, int]` and `Size`.
767-
+ current_features[i].shape[1:]
754+
+ tuple(current_features[i].shape[1:])
768755
)
769756

770757
ablated_features, current_mask = self._construct_ablated_input(
@@ -780,10 +767,7 @@ def _ith_input_ablation_generator(
780767
# (current_num_ablated_features * num_examples, inputs[i].shape[1:]),
781768
# which can be provided to the model as input.
782769
current_features[i] = ablated_features.reshape(
783-
(-1,)
784-
# pyre-fixme[58]: `+` is not supported for operand types
785-
# `Tuple[int]` and `Size`.
786-
+ ablated_features.shape[2:]
770+
(-1,) + tuple(ablated_features.shape[2:])
787771
)
788772
yield tuple(
789773
current_features
@@ -818,9 +802,7 @@ def _construct_ablated_input(
818802
thus counted towards ablations for that feature) and 0s otherwise.
819803
"""
820804
current_mask = torch.stack(
821-
# pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
822-
# Tuple[Tensor, ...]]` but got `List[Union[bool, Tensor]]`.
823-
[input_mask == j for j in range(start_feature, end_feature)], # type: ignore # noqa: E501 line too long
805+
cast(List[Tensor], [input_mask == j for j in range(start_feature, end_feature)]), # type: ignore # noqa: E501 line too long
824806
dim=0,
825807
).long()
826808
current_mask = current_mask.to(expanded_input.device)

0 commit comments

Comments
 (0)