Skip to content

Commit 4e63118

Browse files
Vivek Miglani authored and facebook-github-bot committed
Fix InputXGradient pyre fixme issues (#1607)
Summary: Pull Request resolved: #1607 Fixing unresolved pyre fixme issues in corresponding file Differential Revision: D76736961
1 parent 6154f28 commit 4e63118

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

captum/attr/_core/input_x_gradient.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22

33
# pyre-strict
4-
from typing import Callable, Optional
4+
from typing import Callable, cast, Optional
55

66
from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
77
from captum._utils.gradient import (
@@ -126,12 +126,11 @@ def attribute(
126126
)
127127

128128
undo_gradient_requirements(inputs_tuple, gradient_mask)
129-
# pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
130-
# `Tuple[Tensor, ...]`.
131-
return _format_output(is_inputs_tuple, attributions)
129+
return cast(
130+
TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, attributions)
131+
)
132132

133-
# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
134-
def attribute_future(self) -> Callable:
133+
def attribute_future(self) -> None:
135134
r"""
136135
This method is not implemented for InputXGradient.
137136
"""

tests/attr/test_input_x_gradient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_futures_not_implemented(self) -> None:
5555
input_x_grad = InputXGradient(model.forward)
5656
attributions = None
5757
with self.assertRaises(NotImplementedError):
58-
attributions = input_x_grad.attribute_future()
58+
attributions = input_x_grad.attribute_future() # type: ignore
5959
self.assertEqual(attributions, None)
6060

6161
def _input_x_gradient_base_assert(

0 commit comments

Comments (0)