5
5
import random
6
6
import unittest
7
7
8
- from typing import Callable , Generator
8
+ from typing import Any , Callable , Generator , Tuple , TypeVar , Union
9
9
10
10
import numpy as np
11
11
import torch
12
12
from captum .log import patch_methods
13
13
from torch import Tensor
14
14
15
+ ReturnType = TypeVar ("ReturnType" )
15
16
16
- # pyre-fixme[3]: Return type must be annotated.
17
- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
18
- def deep_copy_args (func : Callable ):
19
- # pyre-fixme[3]: Return type must be annotated.
20
- # pyre-fixme[2]: Parameter must be annotated.
21
- def copy_args (* args , ** kwargs ):
17
+
18
+ def deep_copy_args (func : Callable [..., ReturnType ]) -> Callable [..., ReturnType ]:
19
+ def copy_args (* args : Any , ** kwargs : Any ) -> ReturnType :
22
20
return func (
23
21
* (copy .deepcopy (x ) for x in args ),
24
22
** {k : copy .deepcopy (v ) for k , v in kwargs .items ()},
@@ -28,8 +26,7 @@ def copy_args(*args, **kwargs):
28
26
29
27
30
28
def assertTensorAlmostEqual (
31
- # pyre-fixme[2]: Parameter must be annotated.
32
- test ,
29
+ test : unittest .TestCase ,
33
30
# pyre-fixme[2]: Parameter must be annotated.
34
31
actual ,
35
32
# pyre-fixme[2]: Parameter must be annotated.
@@ -75,8 +72,7 @@ def assertTensorAlmostEqual(
75
72
76
73
77
74
def assertTensorTuplesAlmostEqual (
78
- # pyre-fixme[2]: Parameter must be annotated.
79
- test ,
75
+ test : unittest .TestCase ,
80
76
# pyre-fixme[2]: Parameter must be annotated.
81
77
actual ,
82
78
# pyre-fixme[2]: Parameter must be annotated.
@@ -95,15 +91,17 @@ def assertTensorTuplesAlmostEqual(
95
91
assertTensorAlmostEqual (test , actual , expected , delta , mode )
96
92
97
93
98
- # pyre-fixme[2]: Parameter must be annotated.
99
- def assertAttributionComparision (test , attributions1 , attributions2 ) -> None :
94
+ def assertAttributionComparision (
95
+ test : unittest .TestCase ,
96
+ attributions1 : Union [Tensor , Tuple [Tensor , ...]],
97
+ attributions2 : Union [Tensor , Tuple [Tensor , ...]],
98
+ ) -> None :
100
99
for attribution1 , attribution2 in zip (attributions1 , attributions2 ):
101
100
for attr_row1 , attr_row2 in zip (attribution1 , attribution2 ):
102
101
assertTensorAlmostEqual (test , attr_row1 , attr_row2 , 0.05 , "max" )
103
102
104
103
105
- # pyre-fixme[2]: Parameter must be annotated.
106
- def assert_delta (test , delta ) -> None :
104
+ def assert_delta (test : unittest .TestCase , delta : Tensor ) -> None :
107
105
delta_condition = (delta .abs () < 0.00001 ).all ()
108
106
test .assertTrue (
109
107
delta_condition ,
0 commit comments