1
1
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
2
3
3
# pyre-strict
4
- from typing import Any , Tuple
4
+ from typing import List , Tuple
5
5
6
6
import torch
7
7
from captum ._utils .gradient import compute_gradients
10
10
from torch .nn import Module
11
11
12
12
13
- # pyre-fixme[3]: Return annotation cannot contain `Any`.
14
- def get_basic_config () -> Tuple [Module , Tensor , Tensor , Any ]:
13
+ def get_basic_config () -> Tuple [Module , Tensor , Tensor , None ]:
15
14
input = torch .tensor ([1.0 , 2.0 , 3.0 , 0.0 , - 1.0 , 7.0 ], requires_grad = True ).T
16
15
# manually percomputed gradients
17
16
grads = torch .tensor ([- 0.0 , - 0.0 , - 0.0 , 1.0 , 1.0 , - 0.0 ])
18
17
return BasicModel (), input , grads , None
19
18
20
19
21
- # pyre-fixme[3]: Return annotation cannot contain `Any`.
22
20
def get_multiargs_basic_config () -> (
23
- Tuple [Module , Tuple [Tensor , ...], Tuple [Tensor , ...], Any ]
21
+ Tuple [Module , Tuple [Tensor , ...], Tuple [Tensor , ...], Tuple [ List [ int ], int ] ]
24
22
):
25
23
model = BasicModel5_MultiArgs ()
26
24
additional_forward_args = ([2 , 3 ], 1 )
@@ -34,9 +32,8 @@ def get_multiargs_basic_config() -> (
34
32
return model , inputs , grads , additional_forward_args
35
33
36
34
37
- # pyre-fixme[3]: Return annotation cannot contain `Any`.
38
35
def get_multiargs_basic_config_large () -> (
39
- Tuple [Module , Tuple [Tensor , ...], Tuple [Tensor , ...], Any ]
36
+ Tuple [Module , Tuple [Tensor , ...], Tuple [Tensor , ...], Tuple [ List [ int ], int ] ]
40
37
):
41
38
model = BasicModel5_MultiArgs ()
42
39
additional_forward_args = ([2 , 3 ], 1 )
0 commit comments