Correct remaining typing.Literal imports #1412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
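This PR replaces the remaining imports of the project-local Literal shim (captum._utils.typing.Literal) with the standard-library typing.Literal and removes the pyre-fixme suppressions that the shim forced onto every Literal[True] / Literal[False] annotation. A minimal before/after sketch of the import change, assuming nothing beyond the standard library and what the diff below shows:

# Before: Literal came from a project-local compatibility shim that pyre
# treated as non-generic, so each Literal[...] annotation needed pyre-fixme
# comments.
# from captum._utils.typing import Literal

# After: Literal is imported directly from the standard library
# (typing.Literal exists since Python 3.8), so the subscripted forms
# type-check cleanly without suppressions.
from typing import Literal

ReturnDelta = Literal[True, False]  # runtime-equivalent to bool, but explicit for overloads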
42 changes: 15 additions & 27 deletions captum/_utils/common.py
@@ -5,13 +5,23 @@
from enum import Enum
from functools import reduce
from inspect import signature
from typing import Any, Callable, cast, Dict, List, overload, Sequence, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Literal,
overload,
Sequence,
Tuple,
Union,
)

import numpy as np
import torch
from captum._utils.typing import (
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
TupleOrTensorOrBoolGeneric,
@@ -71,23 +81,17 @@ def safe_div(


@typing.overload
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
# is incompatible with the return type of the implementation (`bool`).
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...


@typing.overload
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
# is incompatible with the return type of the implementation (`bool`).
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def _is_tuple(inputs: Tensor) -> Literal[False]: ...


@typing.overload
def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...
def _is_tuple(
inputs: TensorOrTupleOfTensorsGeneric,
) -> bool: ... # type: ignore


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
@@ -480,22 +484,14 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None:


@typing.overload
# pyre-fixme[43]: The implementation of `_format_output` does not accept all
# possible arguments of overload defined on line `449`.
def _format_output(
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_inputs_tuple: Literal[True],
output: Tuple[Tensor, ...],
) -> Tuple[Tensor, ...]: ...


@typing.overload
# pyre-fixme[43]: The implementation of `_format_output` does not accept all
# possible arguments of overload defined on line `455`.
def _format_output(
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_inputs_tuple: Literal[False],
output: Tuple[Tensor, ...],
) -> Tensor: ...
@@ -526,22 +522,14 @@ def _format_output(


@typing.overload
# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
# possible arguments of overload defined on line `483`.
def _format_outputs(
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_multiple_inputs: Literal[False],
outputs: List[Tuple[Tensor, ...]],
) -> Union[Tensor, Tuple[Tensor, ...]]: ...


@typing.overload
# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
# possible arguments of overload defined on line `489`.
def _format_outputs(
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_multiple_inputs: Literal[True],
outputs: List[Tuple[Tensor, ...]],
) -> List[Union[Tensor, Tuple[Tensor, ...]]]: ...
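With the standard Literal in place, the _is_tuple / _format_output overloads above let a type checker carry the tensor-vs-tuple distinction from input to output. A self-contained sketch of the same pattern with illustrative names (not Captum's actual helpers):

from typing import Literal, Tuple, Union, overload

import torch
from torch import Tensor


@overload
def is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
@overload
def is_tuple(inputs: Tensor) -> Literal[False]: ...

def is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
    # One runtime implementation; the overloads only refine the static type.
    return isinstance(inputs, tuple)


@overload
def format_output(
    is_inputs_tuple: Literal[True], output: Tuple[Tensor, ...]
) -> Tuple[Tensor, ...]: ...
@overload
def format_output(
    is_inputs_tuple: Literal[False], output: Tuple[Tensor, ...]
) -> Tensor: ...

def format_output(
    is_inputs_tuple: bool, output: Tuple[Tensor, ...]
) -> Union[Tensor, Tuple[Tensor, ...]]:
    # Unwrap the single-element tuple when the original input was a bare tensor.
    return output if is_inputs_tuple else output[0]


x = torch.ones(3)
result = format_output(is_tuple(x), (x * 2,))  # statically inferred as Tensor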
20 changes: 12 additions & 8 deletions captum/_utils/gradient.py
@@ -5,7 +5,18 @@
import typing
import warnings
from collections import defaultdict
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)

import torch
from captum._utils.common import (
@@ -16,7 +27,6 @@
)
from captum._utils.sample_gradient import SampleGradientWrapper
from captum._utils.typing import (
Literal,
ModuleOrModuleList,
TargetType,
TensorOrTupleOfTensorsGeneric,
@@ -226,9 +236,6 @@ def _forward_layer_distributed_eval(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
attribute_to_layer_input: bool = False,
# pyre-fixme[9]: forward_hook_with_return has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
forward_hook_with_return: Literal[False] = False,
require_layer_grads: bool = False,
) -> Dict[Module, Dict[device, Tuple[Tensor, ...]]]: ...
@@ -246,8 +253,6 @@ def _forward_layer_distributed_eval(
additional_forward_args: Any = None,
attribute_to_layer_input: bool = False,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
forward_hook_with_return: Literal[True],
require_layer_grads: bool = False,
) -> Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor]: ...
@@ -675,7 +680,6 @@ def compute_layer_gradients_and_eval(
target_ind=target_ind,
additional_forward_args=additional_forward_args,
attribute_to_layer_input=attribute_to_layer_input,
# pyre-fixme[6]: For 7th argument expected `Literal[]` but got `bool`.
forward_hook_with_return=True,
require_layer_grads=True,
)
10 changes: 1 addition & 9 deletions captum/_utils/progress.py
@@ -5,9 +5,7 @@
import sys
import warnings
from time import time
from typing import Any, cast, Iterable, Optional, Sized, TextIO

from captum._utils.typing import Literal
from typing import Any, cast, Iterable, Literal, Optional, Sized, TextIO

try:
from tqdm.auto import tqdm
@@ -75,10 +73,7 @@ def __enter__(self) -> "NullProgress":
return self

# pyre-fixme[2]: Parameter must be annotated.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
# pyre-fixme[7]: Expected `Literal[]` but got `bool`.
return False

# pyre-fixme[3]: Return type must be annotated.
@@ -139,11 +134,8 @@ def __enter__(self) -> "SimpleProgress":
return self

# pyre-fixme[2]: Parameter must be annotated.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
self.close()
# pyre-fixme[7]: Expected `Literal[]` but got `bool`.
return False

# pyre-fixme[3]: Return type must be annotated.
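The Literal[False] return on __exit__ (kept in NullProgress and SimpleProgress above) tells the type checker that the context manager never suppresses exceptions. A standalone sketch of that convention, using a made-up class name:

from types import TracebackType
from typing import Literal, Optional, Type


class NoOpContext:
    def __enter__(self) -> "NoOpContext":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[TracebackType],
    ) -> Literal[False]:
        # Always returning False means exceptions propagate; Literal[False]
        # states that guarantee more precisely than a plain bool would.
        return False


with NoOpContext():
    pass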
9 changes: 2 additions & 7 deletions captum/attr/_core/layer/layer_conductance.py
@@ -2,7 +2,7 @@

# pyre-strict
import typing
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
from captum._utils.common import (
@@ -12,7 +12,7 @@
_format_output,
)
from captum._utils.gradient import compute_layer_gradients_and_eval
from captum._utils.typing import BaselineType, Literal, TargetType
from captum._utils.typing import BaselineType, TargetType
from captum.attr._utils.approximation_methods import approximation_parameters
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
from captum.attr._utils.batching import _batch_attribution
@@ -86,8 +86,6 @@ def attribute(
method: str = "gausslegendre",
internal_batch_size: Union[None, int] = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool = False,
grad_kwargs: Optional[Dict[str, Any]] = None,
@@ -105,9 +103,6 @@ def attribute(
n_steps: int = 50,
method: str = "gausslegendre",
internal_batch_size: Union[None, int] = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
attribute_to_layer_input: bool = False,
grad_kwargs: Optional[Dict[str, Any]] = None,
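This attribute overload pair, like the ones in the DeepLift, GradientShap, IntegratedGradients and LRP layer files further down, follows one recurring shape: a keyword-only return_convergence_delta: Literal[True] overload returns an (attributions, delta) pair, while the Literal[False] = False overload returns attributions alone. A reduced, hypothetical sketch of the pattern (the attribution math is a stand-in, not Captum's):

from typing import Literal, Tuple, Union, overload

import torch
from torch import Tensor


@overload
def attribute(
    inputs: Tensor, *, return_convergence_delta: Literal[True]
) -> Tuple[Tensor, Tensor]: ...
@overload
def attribute(
    inputs: Tensor, return_convergence_delta: Literal[False] = False
) -> Tensor: ...

def attribute(
    inputs: Tensor, return_convergence_delta: bool = False
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    attributions = inputs * 2.0  # placeholder for the real attribution computation
    if return_convergence_delta:
        delta = torch.zeros(1)  # placeholder convergence delta
        return attributions, delta
    return attributions


only_attrs = attribute(torch.ones(3))                                   # inferred Tensor
attrs, delta = attribute(torch.ones(3), return_convergence_delta=True)  # inferred Tuple[Tensor, Tensor]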
29 changes: 2 additions & 27 deletions captum/attr/_core/layer/layer_deep_lift.py
@@ -2,7 +2,7 @@

# pyre-strict
import typing
from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, cast, Dict, Literal, Optional, Sequence, Tuple, Union

import torch
from captum._utils.common import (
@@ -13,12 +13,7 @@
ExpansionTypes,
)
from captum._utils.gradient import compute_layer_gradients_and_eval
from captum._utils.typing import (
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.deep_lift import DeepLift, DeepLiftShap
from captum.attr._utils.attribution import LayerAttribution
from captum.attr._utils.common import (
@@ -101,8 +96,6 @@ def __init__(

# Ignoring mypy error for inconsistent signature with DeepLift
@typing.overload # type: ignore
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `117`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -111,27 +104,20 @@ def attribute(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
grad_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `104`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -382,8 +368,6 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
inputs,
additional_forward_args,
target,
# pyre-fixme[31]: Expression `Literal[False])]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
)

@@ -464,8 +448,6 @@ def attribute(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -483,9 +465,6 @@ def attribute(
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -686,10 +665,6 @@ def attribute(
target=exp_target,
additional_forward_args=exp_addit_args,
return_convergence_delta=cast(
# pyre-fixme[31]: Expression `Literal[(True, False)]` is not a valid
# type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take
# parameters.
Literal[True, False],
return_convergence_delta,
),
20 changes: 2 additions & 18 deletions captum/attr/_core/layer/layer_gradient_shap.py
@@ -3,12 +3,12 @@
# pyre-strict

import typing
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import torch
from captum._utils.gradient import _forward_layer_eval, compute_layer_gradients_and_eval
from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.gradient_shap import _scale_input
from captum.attr._core.noise_tunnel import NoiseTunnel
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
@@ -117,8 +117,6 @@ def attribute(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool = False,
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
@@ -135,9 +133,6 @@ def attribute(
stdevs: Union[float, Tuple[float, ...]] = 0.0,
target: TargetType = None,
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
attribute_to_layer_input: bool = False,
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
@@ -392,8 +387,6 @@ def __init__(
self._multiply_by_inputs = multiply_by_inputs

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `385`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -402,26 +395,19 @@ def attribute(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool = False,
grad_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `373`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
baselines: Union[Tensor, Tuple[Tensor, ...]],
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
attribute_to_layer_input: bool = False,
grad_kwargs: Optional[Dict[str, Any]] = None,
@@ -505,8 +491,6 @@ def attribute( # type: ignore
inputs,
additional_forward_args,
target,
# pyre-fixme[31]: Expression `Literal[False])]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
)
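The cast(Union[Literal[True], Literal[False]], len(attributions) > 1) calls retained here and in layer_deep_lift.py exist because the comparison is statically typed as plain bool, which (at least under pyre) does not match either Literal-specific overload of _format_output; the cast narrows the type without changing the runtime value. A small, self-contained sketch of that idiom with illustrative names:

from typing import Literal, Union, cast, overload


@overload
def describe(flag: Literal[True]) -> str: ...
@overload
def describe(flag: Literal[False]) -> int: ...

def describe(flag: bool) -> Union[str, int]:
    return "multiple outputs" if flag else 0

n_attributions = 3
# The comparison is typed as bool; casting it to Literal[True, False] lets
# the overloaded signature be applied without changing runtime behavior.
result = describe(cast(Union[Literal[True], Literal[False]], n_attributions > 1))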

12 changes: 2 additions & 10 deletions captum/attr/_core/layer/layer_integrated_gradients.py
@@ -3,7 +3,7 @@
# pyre-strict
import functools
import warnings
from typing import Any, Callable, cast, List, overload, Tuple, Union
from typing import Any, Callable, cast, List, Literal, overload, Tuple, Union

import torch
from captum._utils.common import (
@@ -12,7 +12,7 @@
_format_outputs,
)
from captum._utils.gradient import _forward_layer_eval, _run_forward
from captum._utils.typing import BaselineType, Literal, ModuleOrModuleList, TargetType
from captum._utils.typing import BaselineType, ModuleOrModuleList, TargetType
from captum.attr._core.integrated_gradients import IntegratedGradients
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
from captum.attr._utils.common import (
@@ -227,8 +227,6 @@ def layer_forward_hook(
return _gradient_func

@overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `112`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -239,15 +237,11 @@ def attribute(
n_steps: int,
method: str,
internal_batch_size: Union[None, int],
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False],
attribute_to_layer_input: bool,
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ...

@overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `126`.
def attribute( # type: ignore
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -258,8 +252,6 @@ def attribute( # type: ignore
n_steps: int,
method: str,
internal_batch_size: Union[None, int],
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool,
) -> Tuple[
12 changes: 1 addition & 11 deletions captum/attr/_core/layer/layer_lrp.py
@@ -2,7 +2,7 @@

# pyre-strict
import typing
from typing import Any, cast, List, Tuple, Union
from typing import Any, cast, List, Literal, Tuple, Union

from captum._utils.common import (
_format_tensor_into_tuples,
@@ -15,7 +15,6 @@
undo_gradient_requirements,
)
from captum._utils.typing import (
Literal,
ModuleOrModuleList,
TargetType,
TensorOrTupleOfTensorsGeneric,
@@ -64,17 +63,13 @@ def __init__(self, model: Module, layer: ModuleOrModuleList) -> None:
self.device_ids = cast(List[int], self.model.device_ids)

@typing.overload # type: ignore
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `77`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool = False,
verbose: bool = False,
@@ -84,17 +79,12 @@ def attribute(
]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `66`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
attribute_to_layer_input: bool = False,
verbose: bool = False,
18 changes: 2 additions & 16 deletions captum/attr/_utils/common.py
@@ -3,7 +3,7 @@
# pyre-strict
import typing
from inspect import signature
from typing import Any, Callable, List, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, List, Literal, Tuple, TYPE_CHECKING, Union

import torch
from captum._utils.common import (
@@ -12,12 +12,7 @@
_format_tensor_into_tuples,
_validate_input as _validate_input_basic,
)
from captum._utils.typing import (
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.approximation_methods import SUPPORTED_METHODS
from torch import Tensor

@@ -206,8 +201,6 @@ def _format_and_verify_sliding_window_shapes(


@typing.overload
# pyre-fixme[43]: The implementation of `_compute_conv_delta_and_format_attrs` does
# not accept all possible arguments of overload defined on line `212`.
def _compute_conv_delta_and_format_attrs(
attr_algo: "GradientAttribution",
return_convergence_delta: bool,
@@ -217,15 +210,11 @@ def _compute_conv_delta_and_format_attrs(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any,
target: TargetType,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_inputs_tuple: Literal[True],
) -> Union[Tuple[Tensor, ...], Tuple[Tuple[Tensor, ...], Tensor]]: ...


@typing.overload
# pyre-fixme[43]: The implementation of `_compute_conv_delta_and_format_attrs` does
# not accept all possible arguments of overload defined on line `199`.
def _compute_conv_delta_and_format_attrs(
attr_algo: "GradientAttribution",
return_convergence_delta: bool,
@@ -235,9 +224,6 @@ def _compute_conv_delta_and_format_attrs(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any,
target: TargetType,
# pyre-fixme[9]: is_inputs_tuple has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
is_inputs_tuple: Literal[False] = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]: ...

2 changes: 1 addition & 1 deletion tests/attr/helpers/attribution_delta_util.py
@@ -4,8 +4,8 @@
from typing import Tuple, Union

import torch
from captum._utils.typing import Tensor
from tests.helpers import BaseTest
from torch import Tensor


def assert_attribution_delta(
3 changes: 0 additions & 3 deletions tests/attr/layer/test_layer_lrp.py
@@ -65,7 +65,6 @@ def test_lrp_basic_attributions(self) -> None:
relevance, delta = lrp.attribute( # type: ignore
inputs,
classIndex.item(),
# pyre-fixme[6]: For 3rd argument expected `Literal[]` but got `bool`.
return_convergence_delta=True,
)
assertTensorAlmostEqual(
@@ -82,7 +81,6 @@ def test_lrp_simple_attributions(self) -> None:
relevance_upper, delta = lrp_upper.attribute(
inputs,
attribute_to_layer_input=True,
# pyre-fixme[6]: For 3rd argument expected `Literal[]` but got `bool`.
return_convergence_delta=True,
)
lrp_lower = LayerLRP(model, model.linear)
@@ -185,7 +183,6 @@ def test_lrp_simple_attributions_all_layers_delta(self) -> None:
relevance, delta = lrp.attribute(
inputs,
attribute_to_layer_input=True,
# pyre-fixme[6]: For 3rd argument expected `Literal[]` but got `bool`.
return_convergence_delta=True,
)
self.assertEqual(len(relevance), len(delta))
8 changes: 2 additions & 6 deletions tests/attr/test_interpretable_input.py
@@ -2,10 +2,9 @@

# pyre-unsafe

from typing import List, Optional, overload, Union
from typing import List, Literal, Optional, overload, Union

import torch
from captum._utils.typing import Literal
from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
from parameterized import parameterized
from tests.helpers import BaseTest
@@ -22,10 +21,7 @@ def __init__(self, vocab_list) -> None:
@overload
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
@overload
# pyre-fixme[43]: Incompatible overload. The implementation of
# `DummyTokenizer.encode` does not accept all possible arguments of overload.
# pyre-ignore[11]: Annotation `pt` is not defined as a type
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ... # type: ignore # noqa: E501 line too long
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

def encode(
self, text: str, return_tensors: Optional[str] = "pt"
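The final hunk applies the same idea to a string-valued Literal: DummyTokenizer.encode is overloaded on return_tensors so that "pt" yields a Tensor and None yields a plain list of ids. A self-contained sketch with a toy tokenizer (illustrative, not the test helper itself):

from typing import List, Literal, Optional, Union, overload

import torch
from torch import Tensor


class ToyTokenizer:
    def __init__(self, vocab: List[str]) -> None:
        self.ids = {tok: i for i, tok in enumerate(vocab)}

    @overload
    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
    @overload
    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

    def encode(
        self, text: str, return_tensors: Optional[str] = None
    ) -> Union[List[int], Tensor]:
        ids = [self.ids[tok] for tok in text.split()]
        # "pt" mirrors the Hugging Face convention for returning PyTorch tensors.
        return torch.tensor(ids) if return_tensors == "pt" else ids


tok = ToyTokenizer(["hello", "world"])
as_list = tok.encode("hello world")                          # inferred List[int]
as_tensor = tok.encode("hello world", return_tensors="pt")   # inferred Tensor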