Skip to content

Commit f8c25a4

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix testing/helpers/basic pyre fixme issues (#1603)
Summary: Pull Request resolved: #1603 Fixing unresolved pyre fixme issues in corresponding file Reviewed By: styusuf Differential Revision: D76737459 fbshipit-source-id: d4a9d4abc0799017874c2dd5e73cdfe4dc005691
1 parent e39dd6a commit f8c25a4

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

captum/testing/helpers/basic.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,18 @@
55
import random
66
import unittest
77

8-
from typing import Callable, Generator
8+
from typing import Any, Callable, Generator, Tuple, TypeVar, Union
99

1010
import numpy as np
1111
import torch
1212
from captum.log import patch_methods
1313
from torch import Tensor
1414

15+
ReturnType = TypeVar("ReturnType")
1516

16-
# pyre-fixme[3]: Return type must be annotated.
17-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
18-
def deep_copy_args(func: Callable):
19-
# pyre-fixme[3]: Return type must be annotated.
20-
# pyre-fixme[2]: Parameter must be annotated.
21-
def copy_args(*args, **kwargs):
17+
18+
def deep_copy_args(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
19+
def copy_args(*args: Any, **kwargs: Any) -> ReturnType:
2220
return func(
2321
*(copy.deepcopy(x) for x in args),
2422
**{k: copy.deepcopy(v) for k, v in kwargs.items()},
@@ -28,8 +26,7 @@ def copy_args(*args, **kwargs):
2826

2927

3028
def assertTensorAlmostEqual(
31-
# pyre-fixme[2]: Parameter must be annotated.
32-
test,
29+
test: unittest.TestCase,
3330
# pyre-fixme[2]: Parameter must be annotated.
3431
actual,
3532
# pyre-fixme[2]: Parameter must be annotated.
@@ -75,8 +72,7 @@ def assertTensorAlmostEqual(
7572

7673

7774
def assertTensorTuplesAlmostEqual(
78-
# pyre-fixme[2]: Parameter must be annotated.
79-
test,
75+
test: unittest.TestCase,
8076
# pyre-fixme[2]: Parameter must be annotated.
8177
actual,
8278
# pyre-fixme[2]: Parameter must be annotated.
@@ -95,15 +91,17 @@ def assertTensorTuplesAlmostEqual(
9591
assertTensorAlmostEqual(test, actual, expected, delta, mode)
9692

9793

98-
# pyre-fixme[2]: Parameter must be annotated.
99-
def assertAttributionComparision(test, attributions1, attributions2) -> None:
94+
def assertAttributionComparision(
95+
test: unittest.TestCase,
96+
attributions1: Union[Tensor, Tuple[Tensor, ...]],
97+
attributions2: Union[Tensor, Tuple[Tensor, ...]],
98+
) -> None:
10099
for attribution1, attribution2 in zip(attributions1, attributions2):
101100
for attr_row1, attr_row2 in zip(attribution1, attribution2):
102101
assertTensorAlmostEqual(test, attr_row1, attr_row2, 0.05, "max")
103102

104103

105-
# pyre-fixme[2]: Parameter must be annotated.
106-
def assert_delta(test, delta) -> None:
104+
def assert_delta(test: unittest.TestCase, delta: Tensor) -> None:
107105
delta_condition = (delta.abs() < 0.00001).all()
108106
test.assertTrue(
109107
delta_condition,

0 commit comments

Comments
 (0)