
Commit ffee56d

craymichael authored and facebook-github-bot committed on Oct 23, 2024
Correct remaining typing.Literal imports (#1412)
Summary: Change remaining imports of `Literal` to be from the `typing` library.

Reviewed By: vivekmig

Differential Revision: D64807610
1 parent: b80e488 · commit: ffee56d

File tree: 12 files changed (+42, −143 lines)
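
Every file below follows the same one-line pattern. As a minimal sketch (the alias on the last line is illustrative, not part of the commit):

# Removed in this commit (Captum's internal re-export):
#     from captum._utils.typing import Literal
# Replacement, straight from the standard library (Python 3.8+):
from typing import Literal

# Typical use, mirroring the overloads updated below:
ReturnConvergenceDelta = Literal[True, False]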
 

captum/_utils/common.py

Lines changed: 15 additions & 27 deletions

@@ -5,13 +5,23 @@
 from enum import Enum
 from functools import reduce
 from inspect import signature
-from typing import Any, Callable, cast, Dict, List, overload, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Literal,
+    overload,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import numpy as np
 import torch
 from captum._utils.typing import (
     BaselineType,
-    Literal,
     TargetType,
     TensorOrTupleOfTensorsGeneric,
     TupleOrTensorOrBoolGeneric,
@@ -71,23 +81,17 @@ def safe_div(
 
 
 @typing.overload
-# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
-# is incompatible with the return type of the implementation (`bool`).
-# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
 def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
 
 
 @typing.overload
-# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
-# is incompatible with the return type of the implementation (`bool`).
-# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
 def _is_tuple(inputs: Tensor) -> Literal[False]: ...
 
 
 @typing.overload
-def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...
+def _is_tuple(
+    inputs: TensorOrTupleOfTensorsGeneric,
+) -> bool: ...  # type: ignore
 
 
 def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
@@ -480,22 +484,14 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None:
 
 
 @typing.overload
-# pyre-fixme[43]: The implementation of `_format_output` does not accept all
-# possible arguments of overload defined on line `449`.
 def _format_output(
-    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     is_inputs_tuple: Literal[True],
     output: Tuple[Tensor, ...],
 ) -> Tuple[Tensor, ...]: ...
 
 
 @typing.overload
-# pyre-fixme[43]: The implementation of `_format_output` does not accept all
-# possible arguments of overload defined on line `455`.
 def _format_output(
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
    is_inputs_tuple: Literal[False],
     output: Tuple[Tensor, ...],
 ) -> Tensor: ...
@@ -526,22 +522,14 @@ def _format_output(
 
 
 @typing.overload
-# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
-# possible arguments of overload defined on line `483`.
 def _format_outputs(
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     is_multiple_inputs: Literal[False],
     outputs: List[Tuple[Tensor, ...]],
 ) -> Union[Tensor, Tuple[Tensor, ...]]: ...
 
 
 @typing.overload
-# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
-# possible arguments of overload defined on line `489`.
 def _format_outputs(
-    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     is_multiple_inputs: Literal[True],
     outputs: List[Tuple[Tensor, ...]],
 ) -> List[Union[Tensor, Tuple[Tensor, ...]]]: ...
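
The `_is_tuple` overloads above are the simplest instance of the pattern this commit touches: `Literal` return types tell a static checker which boolean comes back for which input. A self-contained sketch, simplified from the code above:

import typing
from typing import Literal, Tuple, Union

from torch import Tensor


@typing.overload
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
@typing.overload
def _is_tuple(inputs: Tensor) -> Literal[False]: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
    # Runtime behavior is a plain isinstance check; the overloads above
    # exist only for static type checkers.
    return isinstance(inputs, tuple)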

captum/_utils/gradient.py

Lines changed: 12 additions & 8 deletions

@@ -5,7 +5,18 @@
 import typing
 import warnings
 from collections import defaultdict
-from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import torch
 from captum._utils.common import (
@@ -16,7 +27,6 @@
 )
 from captum._utils.sample_gradient import SampleGradientWrapper
 from captum._utils.typing import (
-    Literal,
     ModuleOrModuleList,
     TargetType,
     TensorOrTupleOfTensorsGeneric,
@@ -226,9 +236,6 @@ def _forward_layer_distributed_eval(
     # pyre-fixme[2]: Parameter annotation cannot be `Any`.
     additional_forward_args: Any = None,
     attribute_to_layer_input: bool = False,
-    # pyre-fixme[9]: forward_hook_with_return has type `Literal[]`; used as `bool`.
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     forward_hook_with_return: Literal[False] = False,
     require_layer_grads: bool = False,
 ) -> Dict[Module, Dict[device, Tuple[Tensor, ...]]]: ...
@@ -246,8 +253,6 @@ def _forward_layer_distributed_eval(
     additional_forward_args: Any = None,
     attribute_to_layer_input: bool = False,
     *,
-    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     forward_hook_with_return: Literal[True],
     require_layer_grads: bool = False,
 ) -> Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor]: ...
@@ -675,7 +680,6 @@ def compute_layer_gradients_and_eval(
         target_ind=target_ind,
         additional_forward_args=additional_forward_args,
         attribute_to_layer_input=attribute_to_layer_input,
-        # pyre-fixme[6]: For 7th argument expected `Literal[]` but got `bool`.
         forward_hook_with_return=True,
         require_layer_grads=True,
     )

captum/_utils/progress.py

Lines changed: 1 addition & 9 deletions

@@ -5,9 +5,7 @@
 import sys
 import warnings
 from time import time
-from typing import Any, cast, Iterable, Optional, Sized, TextIO
-
-from captum._utils.typing import Literal
+from typing import Any, cast, Iterable, Literal, Optional, Sized, TextIO
 
 try:
     from tqdm.auto import tqdm
@@ -75,10 +73,7 @@ def __enter__(self) -> "NullProgress":
         return self
 
     # pyre-fixme[2]: Parameter must be annotated.
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
-        # pyre-fixme[7]: Expected `Literal[]` but got `bool`.
         return False
 
     # pyre-fixme[3]: Return type must be annotated.
@@ -139,11 +134,8 @@ def __enter__(self) -> "SimpleProgress":
         return self
 
     # pyre-fixme[2]: Parameter must be annotated.
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
         self.close()
-        # pyre-fixme[7]: Expected `Literal[]` but got `bool`.
         return False
 
     # pyre-fixme[3]: Return type must be annotated.
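
The `Literal[False]` annotation on `__exit__` is meaningful: a falsy return tells Python not to suppress exceptions raised inside the `with` block, and `Literal[False]` states that more precisely than `bool`. A condensed, runnable sketch of `NullProgress` (the parameter annotations are added here for illustration; the original leaves them unannotated):

from types import TracebackType
from typing import Literal, Optional, Type


class NullProgress:
    """No-op progress indicator, condensed from captum/_utils/progress.py."""

    def __enter__(self) -> "NullProgress":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[TracebackType],
    ) -> Literal[False]:
        # Returning False means any exception from the with-block propagates.
        return False


with NullProgress():
    pass  # wrapped work goes here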

captum/attr/_core/layer/layer_conductance.py

Lines changed: 2 additions & 7 deletions

@@ -2,7 +2,7 @@
 
 # pyre-strict
 import typing
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -12,7 +12,7 @@
     _format_output,
 )
 from captum._utils.gradient import compute_layer_gradients_and_eval
-from captum._utils.typing import BaselineType, Literal, TargetType
+from captum._utils.typing import BaselineType, TargetType
 from captum.attr._utils.approximation_methods import approximation_parameters
 from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
 from captum.attr._utils.batching import _batch_attribution
@@ -86,8 +86,6 @@ def attribute(
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
         grad_kwargs: Optional[Dict[str, Any]] = None,
@@ -105,9 +103,6 @@ def attribute(
         n_steps: int = 50,
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
         grad_kwargs: Optional[Dict[str, Any]] = None,
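
This overload pair recurs in every layer attribution class below: calling `attribute(..., return_convergence_delta=True)` is typed as returning an `(attributions, delta)` pair, while the `False` default returns attributions alone. A free-standing sketch of the pattern; the body is a placeholder, not Captum's conductance algorithm:

from typing import Literal, Tuple, Union, overload

import torch
from torch import Tensor


@overload
def attribute(
    inputs: Tensor, *, return_convergence_delta: Literal[True]
) -> Tuple[Tensor, Tensor]: ...
@overload
def attribute(
    inputs: Tensor, return_convergence_delta: Literal[False] = False
) -> Tensor: ...


def attribute(
    inputs: Tensor, return_convergence_delta: bool = False
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    attributions = inputs * 0.5  # placeholder computation
    if return_convergence_delta:
        # Placeholder delta: one value per example in the batch.
        return attributions, torch.zeros(inputs.shape[0])
    return attributions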

captum/attr/_core/layer/layer_deep_lift.py

Lines changed: 2 additions & 27 deletions

@@ -2,7 +2,7 @@
 
 # pyre-strict
 import typing
-from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Dict, Literal, Optional, Sequence, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -13,12 +13,7 @@
     ExpansionTypes,
 )
 from captum._utils.gradient import compute_layer_gradients_and_eval
-from captum._utils.typing import (
-    BaselineType,
-    Literal,
-    TargetType,
-    TensorOrTupleOfTensorsGeneric,
-)
+from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._core.deep_lift import DeepLift, DeepLiftShap
 from captum.attr._utils.attribution import LayerAttribution
 from captum.attr._utils.common import (
@@ -101,8 +96,6 @@ def __init__(
 
     # Ignoring mypy error for inconsistent signature with DeepLift
     @typing.overload  # type: ignore
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `117`.
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -111,27 +104,20 @@ def attribute(
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
         custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
         grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `104`.
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
         baselines: BaselineType = None,
         target: TargetType = None,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
         custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -382,8 +368,6 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
             inputs,
             additional_forward_args,
             target,
-            # pyre-fixme[31]: Expression `Literal[False])]` is not a valid type.
-            # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
             cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
         )
 
@@ -464,8 +448,6 @@ def attribute(
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
         custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -483,9 +465,6 @@ def attribute(
         target: TargetType = None,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
         custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -686,10 +665,6 @@ def attribute(
             target=exp_target,
             additional_forward_args=exp_addit_args,
             return_convergence_delta=cast(
-                # pyre-fixme[31]: Expression `Literal[(True, False)]` is not a valid
-                # type.
-                # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take
-                # parameters.
                 Literal[True, False],
                 return_convergence_delta,
             ),
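
One subtlety above: `len(attributions) > 1` is a plain `bool` at runtime, but the overloaded callee is typed to take `Literal[True]`/`Literal[False]`. The code bridges the gap with `typing.cast`, which performs no conversion; a minimal sketch:

from typing import Literal, Union, cast

attributions = (1, 2)  # stand-in for the tuple of attribution tensors
is_inputs_tuple = cast(Union[Literal[True], Literal[False]], len(attributions) > 1)
# cast() is a no-op at runtime; it only tells the checker to treat the
# bool as the Literal type the overloaded callee expects.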

captum/attr/_core/layer/layer_gradient_shap.py

Lines changed: 2 additions & 18 deletions

@@ -3,12 +3,12 @@
 # pyre-strict
 
 import typing
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from captum._utils.gradient import _forward_layer_eval, compute_layer_gradients_and_eval
-from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric
+from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._core.gradient_shap import _scale_input
 from captum.attr._core.noise_tunnel import NoiseTunnel
 from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
@@ -117,8 +117,6 @@ def attribute(
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
@@ -135,9 +133,6 @@ def attribute(
         stdevs: Union[float, Tuple[float, ...]] = 0.0,
         target: TargetType = None,
         additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
     ) -> Union[Tensor, Tuple[Tensor, ...]]: ...
@@ -392,8 +387,6 @@ def __init__(
         self._multiply_by_inputs = multiply_by_inputs
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `385`.
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -402,26 +395,19 @@ def attribute(
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
         grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `373`.
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
         baselines: Union[Tensor, Tuple[Tensor, ...]],
         target: TargetType = None,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
         grad_kwargs: Optional[Dict[str, Any]] = None,
@@ -505,8 +491,6 @@ def attribute(  # type: ignore
             inputs,
             additional_forward_args,
             target,
-            # pyre-fixme[31]: Expression `Literal[False])]` is not a valid type.
-            # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
             cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
         )

captum/attr/_core/layer/layer_integrated_gradients.py

Lines changed: 2 additions & 10 deletions

@@ -3,7 +3,7 @@
 # pyre-strict
 import functools
 import warnings
-from typing import Any, Callable, cast, List, overload, Tuple, Union
+from typing import Any, Callable, cast, List, Literal, overload, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -12,7 +12,7 @@
     _format_outputs,
 )
 from captum._utils.gradient import _forward_layer_eval, _run_forward
-from captum._utils.typing import BaselineType, Literal, ModuleOrModuleList, TargetType
+from captum._utils.typing import BaselineType, ModuleOrModuleList, TargetType
 from captum.attr._core.integrated_gradients import IntegratedGradients
 from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
 from captum.attr._utils.common import (
@@ -110,8 +110,6 @@ def __init__(
         )
 
     @overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `112`.
     def attribute(
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -122,15 +120,11 @@ def attribute(
         n_steps: int,
         method: str,
         internal_batch_size: Union[None, int],
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False],
         attribute_to_layer_input: bool,
     ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ...
 
     @overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `126`.
     def attribute(  # type: ignore
         self,
         inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -141,8 +135,6 @@ def attribute(  # type: ignore
         n_steps: int,
         method: str,
         internal_batch_size: Union[None, int],
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool,
     ) -> Tuple[

captum/attr/_core/layer/layer_lrp.py

Lines changed: 1 addition & 11 deletions

@@ -2,7 +2,7 @@
 
 # pyre-strict
 import typing
-from typing import Any, cast, List, Tuple, Union
+from typing import Any, cast, List, Literal, Tuple, Union
 
 from captum._utils.common import (
     _format_tensor_into_tuples,
@@ -15,7 +15,6 @@
     undo_gradient_requirements,
 )
 from captum._utils.typing import (
-    Literal,
     ModuleOrModuleList,
     TargetType,
     TensorOrTupleOfTensorsGeneric,
@@ -64,17 +63,13 @@ def __init__(self, model: Module, layer: ModuleOrModuleList) -> None:
             self.device_ids = cast(List[int], self.model.device_ids)
 
     @typing.overload  # type: ignore
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `77`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
         verbose: bool = False,
@@ -84,17 +79,12 @@ def attribute(
     ]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    # arguments of overload defined on line `66`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         additional_forward_args: Any = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
         verbose: bool = False,

captum/attr/_utils/common.py

Lines changed: 2 additions & 16 deletions

@@ -3,7 +3,7 @@
 # pyre-strict
 import typing
 from inspect import signature
-from typing import Any, Callable, List, Tuple, TYPE_CHECKING, Union
+from typing import Any, Callable, List, Literal, Tuple, TYPE_CHECKING, Union
 
 import torch
 from captum._utils.common import (
@@ -12,12 +12,7 @@
     _format_tensor_into_tuples,
     _validate_input as _validate_input_basic,
 )
-from captum._utils.typing import (
-    BaselineType,
-    Literal,
-    TargetType,
-    TensorOrTupleOfTensorsGeneric,
-)
+from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.approximation_methods import SUPPORTED_METHODS
 from torch import Tensor
 
@@ -206,8 +201,6 @@ def _format_and_verify_sliding_window_shapes(
 
 
 @typing.overload
-# pyre-fixme[43]: The implementation of `_compute_conv_delta_and_format_attrs` does
-# not accept all possible arguments of overload defined on line `212`.
 def _compute_conv_delta_and_format_attrs(
     attr_algo: "GradientAttribution",
     return_convergence_delta: bool,
@@ -217,15 +210,11 @@ def _compute_conv_delta_and_format_attrs(
     # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    additional_forward_args: Any,
     target: TargetType,
-    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     is_inputs_tuple: Literal[True],
 ) -> Union[Tuple[Tensor, ...], Tuple[Tuple[Tensor, ...], Tensor]]: ...
 
 
 @typing.overload
-# pyre-fixme[43]: The implementation of `_compute_conv_delta_and_format_attrs` does
-# not accept all possible arguments of overload defined on line `199`.
 def _compute_conv_delta_and_format_attrs(
     attr_algo: "GradientAttribution",
     return_convergence_delta: bool,
@@ -235,9 +224,6 @@ def _compute_conv_delta_and_format_attrs(
     # pyre-fixme[2]: Parameter annotation cannot be `Any`.
     additional_forward_args: Any,
     target: TargetType,
-    # pyre-fixme[9]: is_inputs_tuple has type `Literal[]`; used as `bool`.
-    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
     is_inputs_tuple: Literal[False] = False,
 ) -> Union[Tensor, Tuple[Tensor, Tensor]]: ...

tests/attr/helpers/attribution_delta_util.py

Lines changed: 1 addition & 1 deletion

@@ -4,8 +4,8 @@
 from typing import Tuple, Union
 
 import torch
-from captum._utils.typing import Tensor
 from tests.helpers import BaseTest
+from torch import Tensor
 
 
 def assert_attribution_delta(

tests/attr/layer/test_layer_lrp.py

Lines changed: 0 additions & 3 deletions

@@ -65,7 +65,6 @@ def test_lrp_basic_attributions(self) -> None:
         relevance, delta = lrp.attribute(  # type: ignore
             inputs,
             classIndex.item(),
-            # pyre-fixme[6]: For 3rd argument expected `Literal[]` but got `bool`.
             return_convergence_delta=True,
         )
         assertTensorAlmostEqual(
@@ -82,7 +81,6 @@ def test_lrp_simple_attributions(self) -> None:
         relevance_upper, delta = lrp_upper.attribute(
             inputs,
             attribute_to_layer_input=True,
-            # pyre-fixme[6]: For 3rd argument expected `Literal[]` but got `bool`.
             return_convergence_delta=True,
         )
         lrp_lower = LayerLRP(model, model.linear)
@@ -185,7 +183,6 @@ def test_lrp_simple_attributions_all_layers_delta(self) -> None:
         relevance, delta = lrp.attribute(
             inputs,
             attribute_to_layer_input=True,
-            # pyre-fixme[6]: For 3rd argument expected `Literal[]` but got `bool`.
             return_convergence_delta=True,
         )
         self.assertEqual(len(relevance), len(delta))

tests/attr/test_interpretable_input.py

Lines changed: 2 additions & 6 deletions

@@ -2,10 +2,9 @@
 
 # pyre-unsafe
 
-from typing import List, Optional, overload, Union
+from typing import List, Literal, Optional, overload, Union
 
 import torch
-from captum._utils.typing import Literal
 from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
 from parameterized import parameterized
 from tests.helpers import BaseTest
@@ -22,10 +21,7 @@ def __init__(self, vocab_list) -> None:
     @overload
     def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
     @overload
-    # pyre-fixme[43]: Incompatible overload. The implementation of
-    # `DummyTokenizer.encode` does not accept all possible arguments of overload.
-    # pyre-ignore[11]: Annotation `pt` is not defined as a type
-    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...  # type: ignore # noqa: E501 line too long
+    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
 
     def encode(
         self, text: str, return_tensors: Optional[str] = "pt"
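
`Literal` accepts strings as well as booleans, which is how the test's `DummyTokenizer` mimics the Hugging Face `return_tensors="pt"` convention: the literal string `"pt"` selects the `Tensor` overload. A runnable sketch with a toy encoding (the code-point vocabulary is illustrative only):

from typing import List, Literal, Optional, Union, overload

import torch
from torch import Tensor


class DummyTokenizer:
    @overload
    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
    @overload
    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

    def encode(
        self, text: str, return_tensors: Optional[str] = None
    ) -> Union[List[int], Tensor]:
        ids = [ord(ch) for ch in text]  # toy vocabulary: Unicode code points
        if return_tensors == "pt":
            return torch.tensor([ids])
        return ids


tokens = DummyTokenizer().encode("hi", return_tensors="pt")  # typed as Tensor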
