Commit d5b99d2

yucu authored and facebook-github-bot committed

Add capability to pass additional grad_kwargs for LayerGradientXActivation (#1286)

Summary: Pull Request resolved: #1286
Differential Revision: D57756842

1 parent 1157ad8 · commit d5b99d2

File tree: 5 files changed, +47 −2 lines changed
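For orientation before the per-file diffs, a minimal usage sketch of the new parameter. The model and layer choices mirror the tests added in this commit; the tests.helpers import path is an assumption based on the repo's test layout:

import torch
from captum.attr import LayerGradientXActivation
from tests.helpers.basic_models import BasicModel_MultiLayer

model = BasicModel_MultiLayer(multi_input_module=True)
# model.relu does not contribute to the output in this configuration,
# so without grad_kwargs the underlying torch.autograd.grad call would
# error on the unused input.
layer_act = LayerGradientXActivation(model, [model.linear1, model.relu])
attributions = layer_act.attribute(
    inputs=torch.tensor([[3.0, 4.0, 0.0]], requires_grad=True),
    target=0,
    grad_kwargs={"materialize_grads": True},  # forwarded to torch.autograd.grad
)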

captum/_utils/gradient.py
Lines changed: 10 additions & 1 deletion

@@ -485,6 +485,7 @@ def compute_layer_gradients_and_eval(
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
+    grad_kwargs: Optional[Dict[str, Any]] = None,
 ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]]: ...
 
 
@@ -499,6 +500,7 @@ def compute_layer_gradients_and_eval(
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
+    grad_kwargs: Optional[Dict[str, Any]] = None,
 ) -> Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]]: ...
 
 
@@ -513,6 +515,7 @@ def compute_layer_gradients_and_eval(
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
+    grad_kwargs: Optional[Dict[str, Any]] = None,
 ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: ...
 
 
@@ -528,6 +531,7 @@ def compute_layer_gradients_and_eval(
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
+    grad_kwargs: Optional[Dict[str, Any]] = None,
 ) -> Union[
     Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
     Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]],
@@ -572,6 +576,7 @@ def compute_layer_gradients_and_eval(
         args: Additional input arguments that forward function requires.
             It takes an empty tuple (no additional arguments) if no
             additional arguments are required
+        grad_kwargs: Additional keyword arguments for torch.autograd.grad
 
 
     Returns:
@@ -640,7 +645,11 @@ def compute_layer_gradients_and_eval(
                 for device_id in key_list
                 for layer_tensor in saved_layer[single_layer][device_id]
             )
-            saved_grads = torch.autograd.grad(torch.unbind(output), grad_inputs)
+            saved_grads = torch.autograd.grad(
+                outputs=torch.unbind(output),
+                inputs=grad_inputs,
+                **grad_kwargs or {},
+            )
 
             offset = 0
             all_grads: List[Tuple[Tensor, ...]] = []
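Because the kwargs are splatted as **grad_kwargs or {} (i.e. **(grad_kwargs or {})), any keyword that torch.autograd.grad accepts can be threaded through. A standalone sketch of the option the new tests rely on, materialize_grads, assuming a PyTorch version recent enough to support it:

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
out = (x * 2).sum()  # y never participates in the graph

# materialize_grads=True implies allow_unused=True and returns zero
# tensors instead of None for inputs the outputs do not depend on.
grad_x, grad_y = torch.autograd.grad(out, (x, y), materialize_grads=True)
print(grad_x)  # tensor([2., 2.])
print(grad_y)  # tensor([0., 0.])

This is what lets compute_layer_gradients_and_eval tolerate a layer that does not affect the target output, as exercised in the tests further down.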

captum/attr/_core/layer/layer_gradient_x_activation.py
Lines changed: 4 additions & 1 deletion

@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 from captum._utils.common import (
     _format_additional_forward_args,
@@ -76,6 +76,7 @@ def attribute(
         target: TargetType = None,
         additional_forward_args: Any = None,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
         r"""
         Args:
@@ -132,6 +133,7 @@ def attribute(
                         layer input, otherwise it will be computed with respect
                         to layer output.
                         Default: False
+            grad_kwargs: Additional keyword arguments for torch.autograd.grad
 
         Returns:
             *Tensor* or *tuple[Tensor, ...]* or list of **attributions**:
@@ -175,6 +177,7 @@ def attribute(
             additional_forward_args,
             device_ids=self.device_ids,
             attribute_to_layer_input=attribute_to_layer_input,
+            grad_kwargs=grad_kwargs,
         )
         if isinstance(self.layer, Module):
             return _format_output(

tests/attr/layer/test_layer_gradient_x_activation.py
Lines changed: 12 additions & 0 deletions

@@ -129,6 +129,18 @@ def test_gradient_activation_embedding_no_grad(self) -> None:
             list(layer_act.attribute(inputs=(input1, input2)).shape), [4, 100]
         )
 
+    def test_simple_multi_gradient_activation_with_unused_layer(self) -> None:
+        model = BasicModel_MultiLayer(multi_input_module=True)
+        test_input1 = torch.tensor([[3.0, 4.0, 0.0]], requires_grad=True)
+        # test_input2 = torch.tensor([[0.0, 4.0, 5.0]], requires_grad=True)
+        layer_act = LayerGradientXActivation(model, [model.linear1, model.relu])
+        attributions = layer_act.attribute(
+            inputs=test_input1, target=0, grad_kwargs={"materialize_grads": True}
+        )
+        self.assertEqual(len(attributions), 2)
+        self.assertEqual(list(attributions[0].shape), [1, 4])
+        self.assertEqual(list(attributions[1].shape), [1, 4])
+
     def _layer_activation_test_assert(
         self,
         model: Module,

tests/helpers/basic_models.py
Lines changed: 4 additions & 0 deletions

@@ -407,6 +407,10 @@ def forward(
         if self.multi_input_module:
             relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
             relu_out = relu_out1 + relu_out2
+            # relu is not used when multi_input_module is set to True,
+            # so this call intentionally exercises an unused layer for
+            # testing; its output is not part of the return value
+            self.relu(lin1_out)
         else:
             relu_out = self.relu(lin1_out)
         lin2_out = self.linear2(relu_out)

tests/utils/test_gradient.py
Lines changed: 17 additions & 0 deletions

@@ -243,3 +243,20 @@ def test_layer_gradient_output(self) -> None:
         )
         assertTensorAlmostEqual(self, grads[0], [[0.0, 1.0]], delta=0.01, mode="max")
         assertTensorAlmostEqual(self, eval[0], [[26.0, 28.0]], delta=0.01, mode="max")
+
+    def test_layer_gradient_unused_layer(self) -> None:
+        model = BasicModel_MultiLayer(multi_input_module=True)
+        input = torch.tensor([[5.0, 2.0, 1.0]], requires_grad=True)
+        grads, eval = compute_layer_gradients_and_eval(
+            model,
+            [model.linear1, model.relu],
+            input,
+            target_ind=1,
+            grad_kwargs={"materialize_grads": True},
+        )
+        assertTensorAlmostEqual(
+            self, grads[0][0], [[0.0, 1.0, 1.0, 1.0]], delta=0, mode="max"
+        )
+        assertTensorAlmostEqual(
+            self, eval[0][0], [[-2.0, 9.0, 9.0, 9.0]], delta=0, mode="max"
+        )
