Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1a07286

Browse files
vivekmigfacebook-github-bot
authored andcommittedJun 16, 2025
Fix IntegratedGradients pyre fixme issues
Summary: Fixing unresolved pyre fixme issues in corresponding file Differential Revision: D76736985
1 parent 79c01b8 commit 1a07286

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed
 

‎captum/attr/_core/integrated_gradients.py‎

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pyre-strict
44
import typing
5-
from typing import Callable, List, Literal, Optional, Tuple, Union
5+
from typing import Callable, cast, List, Literal, Optional, Tuple, Union
66

77
import torch
88
from captum._utils.common import (
@@ -301,16 +301,18 @@ def attribute( # type: ignore
301301
additional_forward_args=additional_forward_args,
302302
target=target,
303303
)
304-
# pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen...
305-
return _format_output(is_inputs_tuple, attributions), delta
306-
# pyre-fixme[7]: Expected
307-
# `Union[Tuple[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
308-
# typing.Tuple[Tensor, ...]]], Tensor], Variable[TensorOrTupleOfTensorsGeneric
309-
# <: [Tensor, typing.Tuple[Tensor, ...]]]]` but got `Tuple[Tensor, ...]`.
310-
return _format_output(is_inputs_tuple, attributions)
311-
312-
# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
313-
def attribute_future(self) -> Callable:
304+
return (
305+
cast(
306+
TensorOrTupleOfTensorsGeneric,
307+
_format_output(is_inputs_tuple, attributions),
308+
),
309+
delta,
310+
)
311+
return cast(
312+
TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, attributions)
313+
)
314+
315+
def attribute_future(self) -> None:
314316
r"""
315317
This method is not implemented for IntegratedGradients.
316318
"""

‎tests/attr/test_integrated_gradients_basic.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def test_futures_not_implemented(self) -> None:
161161
ig = IntegratedGradients(model, multiply_by_inputs=True)
162162
attributions = None
163163
with self.assertRaises(NotImplementedError):
164-
attributions = ig.attribute_future()
164+
attributions = ig.attribute_future() # type: ignore
165165
self.assertEqual(attributions, None)
166166

167167
def _assert_multi_variable(

0 commit comments

Comments
 (0)
Please sign in to comment.