
Commit 37b2ca5

vivekmig authored and facebook-github-bot committed
Migrate to register_full_backward_hook (#837)
Summary: Modifies all module backward hooks to use the new register_full_backward_hook API documented here. This new API resolves many issues we previously encountered with backward module hooks. Since this API is available only in torch 1.8 and later, a fall-back to the original backward hook approach is kept for earlier versions. Due to issues described [here](pytorch/pytorch#57157), we are also deprecating attribution with respect to neuron outputs for NeuronDeepLift, NeuronGuidedBackprop, and NeuronDeconvolution; these methods require attributing with respect to neuron input (which is typically equivalent to attributing with respect to the previous layer's output). Additionally, in-place modules are not supported by full backward hooks, so they are no longer supported for DeepLift, LRP, GuidedBackprop / Deconvolution, and their corresponding variants. Documentation has been updated accordingly.

Pull Request resolved: #837

Reviewed By: NarineK

Differential Revision: D34380993

Pulled By: vivekmig

fbshipit-source-id: 8568d5f3783d6c05f76cfbf9f43f5276c7b30930
1 parent dad55b0 commit 37b2ca5

17 files changed: +153 −103 lines
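
Note: the practical, user-facing consequence of this change is that models passed to DeepLift, LRP, GuidedBackprop, and Deconvolution (and their layer/neuron variants) should avoid in-place nonlinearities. A minimal sketch of the supported usage pattern follows; the toy model, tensor shapes, and baselines are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn
from captum.attr import DeepLift

# Hypothetical toy model. Note inplace=False on the ReLU: full backward
# hooks (torch >= 1.8) do not support in-place modules, so inplace=True
# would not be supported with DeepLift after this change.
model = nn.Sequential(
    nn.Linear(3, 3),
    nn.ReLU(inplace=False),
    nn.Linear(3, 1),
)

inputs = torch.randn(2, 3)
baselines = torch.zeros(2, 3)

dl = DeepLift(model)
attributions = dl.attribute(inputs, baselines=baselines)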

captum/_utils/gradient.py

Lines changed: 21 additions & 0 deletions
@@ -84,6 +84,27 @@ def undo_gradient_requirements(
         input.requires_grad_(False)
 
 
+def register_backward_hook(
+    module: Module, hook: Callable, attr_obj: Any
+) -> torch.utils.hooks.RemovableHandle:
+    # Special case for supporting output attributions for neuron methods
+    # This can be removed after deprecation of neuron output attributions
+    # for NeuronDeepLift, NeuronDeconvolution, and NeuronGuidedBackprop
+    # in v0.6.0
+    if (
+        hasattr(attr_obj, "skip_new_hook_layer")
+        and attr_obj.skip_new_hook_layer == module
+    ):
+        return module.register_backward_hook(hook)
+
+    try:
+        # Only supported for torch >= 1.8
+        return module.register_full_backward_hook(hook)
+    except AttributeError:
+        # Fallback for previous versions of PyTorch
+        return module.register_backward_hook(hook)
+
+
 def compute_gradients(
     forward_fn: Callable,
     inputs: Union[Tensor, Tuple[Tensor, ...]],
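
A rough usage sketch of the new helper (the module, hook, and attribution stand-in below are illustrative assumptions): on torch >= 1.8 it registers a full backward hook, on older versions it silently falls back to register_backward_hook, and an attribution object whose skip_new_hook_layer matches the module forces the legacy hook for that module.

import torch
import torch.nn as nn
from captum._utils.gradient import register_backward_hook

relu = nn.ReLU()  # hypothetical module to hook

def print_grad_shapes(module, grad_input, grad_output):
    # Example hook body: inspect gradients flowing through the module.
    print([g.shape for g in grad_output if g is not None])

class DummyAttr:
    # Illustrative stand-in for an attribution object; setting
    # skip_new_hook_layer to the hooked module would force the legacy
    # register_backward_hook path for that module.
    skip_new_hook_layer = None

handle = register_backward_hook(relu, print_grad_shapes, DummyAttr())
out = relu(torch.randn(2, 3, requires_grad=True))
out.sum().backward()  # triggers the hook
handle.remove()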

captum/attr/_core/deep_lift.py

Lines changed: 9 additions & 3 deletions
@@ -21,6 +21,7 @@
 )
 from captum._utils.gradient import (
     apply_gradient_requirements,
+    register_backward_hook,
     undo_gradient_requirements,
 )
 from captum._utils.typing import (
@@ -112,7 +113,10 @@ def __init__(
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance.
+            model (nn.Module): The reference to PyTorch model instance. Model cannot
+                        contain any in-place nonlinear submodules; these are not
+                        supported by the register_full_backward_hook PyTorch API
+                        starting from PyTorch v1.8.
             multiply_by_inputs (bool, optional): Indicates whether to factor
                         model inputs' multiplier in the final attribution scores.
                         In the literature this is also known as local vs global
@@ -542,7 +546,7 @@ def _register_hooks(
         # adds forward hook to leaf nodes that are non-linear
         forward_handle = module.register_forward_hook(self._forward_hook)
         pre_forward_handle = module.register_forward_pre_hook(self._forward_pre_hook)
-        backward_handle = module.register_backward_hook(self._backward_hook)
+        backward_handle = register_backward_hook(module, self._backward_hook, self)
         self.forward_handles.append(forward_handle)
         self.forward_handles.append(pre_forward_handle)
         self.backward_handles.append(backward_handle)
@@ -622,7 +626,9 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance.
+            model (nn.Module): The reference to PyTorch model instance. Model cannot
+                        contain any in-place nonlinear submodules; these are not
+                        supported by the register_full_backward_hook PyTorch API.
             multiply_by_inputs (bool, optional): Indicates whether to factor
                         model inputs' multiplier in the final attribution scores.
                         In the literature this is also known as local vs global

captum/attr/_core/guided_backprop_deconvnet.py

Lines changed: 8 additions & 3 deletions
@@ -7,6 +7,7 @@
 from captum._utils.common import _format_input, _format_output, _is_tuple
 from captum._utils.gradient import (
     apply_gradient_requirements,
+    register_backward_hook,
     undo_gradient_requirements,
 )
 from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
@@ -74,7 +75,7 @@ def attribute(
 
     def _register_hooks(self, module: Module):
         if isinstance(module, torch.nn.ReLU):
-            hook = module.register_backward_hook(self._backward_hook)
+            hook = register_backward_hook(module, self._backward_hook, self)
             self.backward_hooks.append(hook)
 
     def _backward_hook(
@@ -116,7 +117,9 @@ def __init__(self, model: Module) -> None:
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance.
+            model (nn.Module): The reference to PyTorch model instance. Model cannot
+                        contain any in-place ReLU submodules; these are not
+                        supported by the register_full_backward_hook PyTorch API.
         """
         ModifiedReluGradientAttribution.__init__(
             self, model, use_relu_grad_output=False
@@ -227,7 +230,9 @@ def __init__(self, model: Module) -> None:
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance.
+            model (nn.Module): The reference to PyTorch model instance. Model cannot
+                        contain any in-place ReLU submodules; these are not
+                        supported by the register_full_backward_hook PyTorch API.
         """
         ModifiedReluGradientAttribution.__init__(self, model, use_relu_grad_output=True)
 

captum/attr/_core/guided_grad_cam.py

Lines changed: 5 additions & 1 deletion
@@ -51,7 +51,10 @@ def __init__(
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance.
+            model (nn.Module): The reference to PyTorch model instance. Model cannot
+                        contain any in-place ReLU submodules; these are not
+                        supported by the register_full_backward_hook PyTorch API
+                        starting from PyTorch v1.8.
             layer (torch.nn.Module): Layer for which GradCAM attributions are computed.
                         Currently, only layers with a single tensor output are
                         supported.
@@ -194,6 +197,7 @@ def attribute(
                     "outputs is not supported."
                 )
             grad_cam_attr = grad_cam_attr[0]
+
         guided_backprop_attr = self.guided_backprop.attribute.__wrapped__(
             self.guided_backprop,  # self
             inputs=inputs,

captum/attr/_core/layer/layer_deep_lift.py

Lines changed: 8 additions & 2 deletions
@@ -69,7 +69,10 @@ def __init__(
         r"""
         Args:
 
-            model (torch.nn.Module): The reference to PyTorch model instance.
+            model (nn.Module): The reference to PyTorch model instance. Model cannot
+                        contain any in-place nonlinear submodules; these are not
+                        supported by the register_full_backward_hook PyTorch API
+                        starting from PyTorch v1.8.
             layer (torch.nn.Module): Layer for which attributions are computed.
                         The size and dimensionality of the attributions
                         corresponds to the size and dimensionality of the layer's
@@ -397,7 +400,10 @@ def __init__(
         r"""
         Args:
 
-            model (torch.nn.Module): The reference to PyTorch model instance.
+            model (nn.Module): The reference to PyTorch model instance. Model cannot
+                        contain any in-place nonlinear submodules; these are not
+                        supported by the register_full_backward_hook PyTorch API
+                        starting from PyTorch v1.8.
             layer (torch.nn.Module): Layer for which attributions are computed.
                         The size and dimensionality of the attributions
                         corresponds to the size and dimensionality of the layer's

captum/attr/_core/layer/layer_lrp.py

Lines changed: 5 additions & 0 deletions
@@ -42,6 +42,11 @@ def __init__(self, model: Module, layer: ModuleOrModuleList) -> None:
                     any modification of it. Custom rules for a given layer need to
                     be defined as attribute
                     `module.rule` and need to be of type PropagationRule.
+                    Model cannot contain any in-place nonlinear submodules;
+                    these are not supported by the register_full_backward_hook
+                    PyTorch API starting from PyTorch v1.8.
+
+
             layer (torch.nn.Module or list(torch.nn.Module)): Layer or layers
                     for which attributions are computed.
                     The size and dimensionality of the attributions

captum/attr/_core/lrp.py

Lines changed: 7 additions & 3 deletions
@@ -8,6 +8,7 @@
 from captum._utils.common import _format_input, _format_output, _is_tuple, _run_forward
 from captum._utils.gradient import (
     apply_gradient_requirements,
+    register_backward_hook,
     undo_gradient_requirements,
 )
 from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric
@@ -43,7 +44,10 @@ def __init__(self, model: Module) -> None:
                 it. Custom rules for a given layer need to be defined as attribute
                 `module.rule` and need to be of type PropagationRule. If no rule is
                 specified for a layer, a pre-defined default rule for the module type
-                is used.
+                is used. Model cannot contain any in-place nonlinear submodules;
+                these are not supported by the register_full_backward_hook
+                PyTorch API starting from PyTorch v1.8.
+
         """
         GradientAttribution.__init__(self, model)
         self.model = model
@@ -305,8 +309,8 @@ def _check_rules(self) -> None:
     def _register_forward_hooks(self) -> None:
         for layer in self.layers:
             if type(layer) in SUPPORTED_NON_LINEAR_LAYERS:
-                backward_handle = layer.register_backward_hook(
-                    PropagationRule.backward_hook_activation
+                backward_handle = register_backward_hook(
+                    layer, PropagationRule.backward_hook_activation, self
                 )
                 self.backward_handles.append(backward_handle)
             else:

captum/attr/_core/neuron/neuron_deep_lift.py

Lines changed: 32 additions & 2 deletions
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+import warnings
 from typing import Any, Callable, Tuple, Union, cast
 
 from captum._utils.gradient import construct_neuron_grad_fn
@@ -45,7 +46,10 @@ def __init__(
         r"""
         Args:
 
-            model (torch.nn.Module): The reference to PyTorch model instance.
+            model (nn.Module): The reference to PyTorch model instance. Model cannot
+                        contain any in-place nonlinear submodules; these are not
+                        supported by the register_full_backward_hook PyTorch API
+                        starting from PyTorch v1.8.
             layer (torch.nn.Module): Layer for which neuron attributions are computed.
                         Attributions for a particular neuron for the input or output
                         of this layer are computed using the argument neuron_selector
@@ -227,6 +231,17 @@ def attribute(
         >>> attribution = dl.attribute(input, (4,1,2))
         """
         dl = DeepLift(cast(Module, self.forward_func), self.multiplies_by_inputs)
+        if not attribute_to_neuron_input:
+            warnings.warn(
+                "Attribution to neuron output is no longer supported for"
+                " NeuronDeepLift and will be deprecated in Captum"
+                " 0.6.0 due to changes in PyTorch's full backward hook"
+                " behavior. To obtain attributions for a neuron's"
+                " output, please attribute with respect to the next layer's input"
+            )
+            dl.skip_new_hook_layer = self.layer  # type: ignore
+        else:
+            dl.skip_new_hook_layer = None  # type: ignore
         dl.gradient_func = construct_neuron_grad_fn(
             self.layer,
             neuron_selector,
@@ -274,7 +289,10 @@ def __init__(
         r"""
         Args:
 
-            model (torch.nn.Module): The reference to PyTorch model instance.
+            model (nn.Module): The reference to PyTorch model instance. Model cannot
+                        contain any in-place nonlinear submodules; these are not
+                        supported by the register_full_backward_hook PyTorch API
+                        starting from PyTorch v1.8.
             layer (torch.nn.Module): Layer for which neuron attributions are computed.
                         Attributions for a particular neuron for the input or output
                         of this layer are computed using the argument neuron_selector
@@ -448,7 +466,19 @@ def attribute(
         >>> # index (4,1,2).
         >>> attribution = dl.attribute(input, (4,1,2))
         """
+
         dl = DeepLiftShap(cast(Module, self.forward_func), self.multiplies_by_inputs)
+        if not attribute_to_neuron_input:
+            warnings.warn(
+                "Attribution to neuron output is no longer supported for"
+                " NeuronDeepLiftShap and will be deprecated in Captum"
+                " 0.6.0 due to changes in PyTorch's full backward hook"
+                " behavior. To obtain attributions for a neuron's"
+                " output, please attribute with respect to the next layer's input"
+            )
+            dl.skip_new_hook_layer = self.layer  # type: ignore
+        else:
+            dl.skip_new_hook_layer = None  # type: ignore
         dl.gradient_func = construct_neuron_grad_fn(
             self.layer,
             neuron_selector,
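
For callers hitting the new NeuronDeepLift / NeuronDeepLiftShap warning, the migration path is to attribute with respect to the input of the following layer rather than the output of the layer of interest. A hedged sketch, where the two-layer model and the chosen neuron index are illustrative assumptions:

import torch
import torch.nn as nn
from captum.attr import NeuronDeepLift

class TinyNet(nn.Module):
    # Hypothetical model; fc1's output is exactly relu's input.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 4)
        self.relu = nn.ReLU(inplace=False)
        self.fc2 = nn.Linear(4, 1)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

model = TinyNet()
inputs = torch.randn(2, 3)

# Instead of attributing to the *output* of fc1 (deprecated path),
# attribute to the *input* of the next layer, relu, which is the same tensor.
neuron_dl = NeuronDeepLift(model, model.relu)
attr = neuron_dl.attribute(inputs, neuron_selector=2, attribute_to_neuron_input=True)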

captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py

Lines changed: 33 additions & 2 deletions
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+import warnings
 from typing import Any, Callable, List, Tuple, Union
 
 from captum._utils.gradient import construct_neuron_grad_fn
@@ -34,7 +35,10 @@ def __init__(
         r"""
         Args:
 
-            model (Module): The reference to PyTorch model instance.
+            model (nn.Module): The reference to PyTorch model instance. Model cannot
+                        contain any in-place ReLU submodules; these are not
+                        supported by the register_full_backward_hook PyTorch API
+                        starting from PyTorch v1.8.
             layer (Module): Layer for which attributions are computed.
                         Output size of attribute matches this layer's input or
                         output dimensions, depending on whether we attribute to
@@ -159,6 +163,18 @@ def attribute(
         >>> # index (4,1,2).
         >>> attribution = neuron_deconv.attribute(input, (4,1,2))
         """
+        if not attribute_to_neuron_input:
+            warnings.warn(
+                "Attribution to neuron output is no longer supported for"
+                " NeuronDeconvolution and will be deprecated in Captum"
+                " 0.6.0 due to changes in PyTorch's full backward hook"
+                " behavior. To obtain attributions for a neuron's"
+                " output, please attribute with respect to the next layer's input"
+            )
+            self.deconv.skip_new_hook_layer = self.layer  # type: ignore
+        else:
+            self.deconv.skip_new_hook_layer = None  # type: ignore
+
         self.deconv.gradient_func = construct_neuron_grad_fn(
             self.layer, neuron_selector, self.device_ids, attribute_to_neuron_input
         )
@@ -191,7 +207,10 @@ def __init__(
         r"""
         Args:
 
-            model (Module): The reference to PyTorch model instance.
+            model (nn.Module): The reference to PyTorch model instance. Model cannot
+                        contain any in-place ReLU submodules; these are not
+                        supported by the register_full_backward_hook PyTorch API
+                        starting from PyTorch v1.8.
             layer (Module): Layer for which neuron attributions are computed.
                         Attributions for a particular neuron in the output of
                         this layer are computed using the argument neuron_selector
@@ -313,6 +332,18 @@ def attribute(
         >>> # index (4,1,2).
         >>> attribution = neuron_gb.attribute(input, (4,1,2))
         """
+        if not attribute_to_neuron_input:
+            warnings.warn(
+                "Attribution to neuron output is no longer supported for"
+                " NeuronGuidedBackprop and will be deprecated in Captum"
+                " 0.6.0 due to changes in PyTorch's full backward hook"
+                " behavior. To obtain attributions for a neuron's"
+                " output, please attribute with respect to the next layer's input"
+            )
+            self.guided_backprop.skip_new_hook_layer = self.layer  # type: ignore
+        else:
+            self.guided_backprop.skip_new_hook_layer = None  # type: ignore
+
         self.guided_backprop.gradient_func = construct_neuron_grad_fn(
             self.layer, neuron_selector, self.device_ids, attribute_to_neuron_input
         )
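
The same migration applies to NeuronDeconvolution and NeuronGuidedBackprop: pass the next layer and set attribute_to_neuron_input=True. A brief sketch with an illustrative model (the layer sizes and neuron index are assumptions):

import torch
import torch.nn as nn
from captum.attr import NeuronGuidedBackprop

# Illustrative model; the ReLU must not be in-place.
model = nn.Sequential(nn.Conv2d(1, 2, 3), nn.ReLU(inplace=False), nn.Conv2d(2, 2, 3))
inputs = torch.randn(1, 1, 8, 8)

# Attribute to the input of the ReLU (supported) rather than to the output
# of the preceding conv layer (the deprecated path that triggers the warning).
neuron_gbp = NeuronGuidedBackprop(model, model[1])
attr = neuron_gbp.attribute(inputs, neuron_selector=(0, 2, 2), attribute_to_neuron_input=True)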

tests/attr/layer/test_layer_deeplift.py

Lines changed: 6 additions & 6 deletions
@@ -25,7 +25,7 @@
 
 class TestDeepLift(BaseTest):
     def test_relu_layer_deeplift(self) -> None:
-        model = ReLULinearModel(inplace=True)
+        model = ReLULinearModel(inplace=False)
         inputs, baselines = _create_inps_and_base_for_deeplift_neuron_layer_testing()
 
         layer_dl = LayerDeepLift(model, model.relu)
@@ -39,7 +39,7 @@ def test_relu_layer_deeplift(self) -> None:
         assert_delta(self, delta)
 
     def test_relu_layer_deeplift_wo_mutliplying_by_inputs(self) -> None:
-        model = ReLULinearModel(inplace=True)
+        model = ReLULinearModel(inplace=False)
         inputs, baselines = _create_inps_and_base_for_deeplift_neuron_layer_testing()
 
         layer_dl = LayerDeepLift(model, model.relu, multiply_by_inputs=False)
@@ -83,7 +83,7 @@ def test_relu_layer_deeplift_add_args(self) -> None:
         assert_delta(self, delta)
 
     def test_linear_layer_deeplift(self) -> None:
-        model = ReLULinearModel(inplace=True)
+        model = ReLULinearModel(inplace=False)
         inputs, baselines = _create_inps_and_base_for_deeplift_neuron_layer_testing()
 
         layer_dl = LayerDeepLift(model, model.l3)
@@ -103,7 +103,7 @@ def test_relu_deeplift_with_custom_attr_func(self) -> None:
         self._relu_custom_attr_func_assert(attr_method, inputs, baselines, [[2.0]])
 
     def test_inplace_maxpool_relu_with_custom_attr_func(self) -> None:
-        model = BasicModel_MaxPool_ReLU(inplace=True)
+        model = BasicModel_MaxPool_ReLU(inplace=False)
         inp = torch.tensor([[[1.0, 2.0, -4.0], [-3.0, -2.0, -1.0]]])
         dl = LayerDeepLift(model, model.maxpool)
 
@@ -116,7 +116,7 @@ def custom_att_func(mult, inp, baseline):
         dl.attribute(inp, custom_attribution_func=custom_att_func)
 
     def test_linear_layer_deeplift_batch(self) -> None:
-        model = ReLULinearModel(inplace=True)
+        model = ReLULinearModel(inplace=False)
         _, baselines = _create_inps_and_base_for_deeplift_neuron_layer_testing()
         x1 = torch.tensor(
             [[-10.0, 1.0, -5.0], [-10.0, 1.0, -5.0], [-10.0, 1.0, -5.0]],
@@ -197,7 +197,7 @@ def test_relu_layer_deepliftshap_multiple_output(self) -> None:
         assert_delta(self, delta)
 
     def test_linear_layer_deepliftshap(self) -> None:
-        model = ReLULinearModel(inplace=True)
+        model = ReLULinearModel(inplace=False)
         (
             inputs,
             baselines,
