@@ -103,7 +103,12 @@ class DeepLift(GradientAttribution):
    https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
    """

-    def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
+    def __init__(
+        self,
+        model: Module,
+        multiply_by_inputs: bool = True,
+        eps: float = 1e-10,
+    ) -> None:
        r"""
        Args:

@@ -123,9 +128,16 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
                        are being multiplied by (inputs - baselines).
                        This flag applies only if `custom_attribution_func` is
                        set to None.
+
+            eps (float, optional): A value at which to consider output/input change
+                        significant when computing the gradients for non-linear layers.
+                        This is useful to adjust, depending on your model's bit depth,
+                        to avoid numerical issues during the gradient computation.
+                        Default: 1e-10
        """
        GradientAttribution.__init__(self, model)
        self.model = model
+        self.eps = eps
        self.forward_handles: List[RemovableHandle] = []
        self.backward_handles: List[RemovableHandle] = []
        self._multiply_by_inputs = multiply_by_inputs
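With this hunk, `eps` travels with the attributor instead of being hard-coded in the backward hook. A minimal usage sketch (the toy model and tensors below are illustrative only, not part of the diff):

```python
import torch
import torch.nn as nn
from captum.attr import DeepLift

# Toy model for illustration only.
model = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 1))
model.eval()

inputs = torch.randn(2, 3)
baselines = torch.zeros(2, 3)

# A larger eps can help when the model runs at reduced precision
# (e.g. float16), where tiny input/output deltas are mostly noise.
dl = DeepLift(model, multiply_by_inputs=True, eps=1e-6)
attributions = dl.attribute(inputs, baselines=baselines, target=0)
```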
@@ -322,7 +334,6 @@ def attribute(  # type: ignore
                activations. The hooks and attributes will be removed
                after the attribution is finished"""
            )
-
        baselines = _tensorize_baseline(inputs, baselines)
        main_model_hooks = []
        try:
@@ -471,7 +482,6 @@ def _backward_hook(
        module: Module,
        grad_input: Union[Tensor, Tuple[Tensor, ...]],
        grad_output: Union[Tensor, Tuple[Tensor, ...]],
-        eps: float = 1e-10,
    ):
        r"""
        `grad_input` is the gradient of the neuron with respect to its input
@@ -495,7 +505,12 @@ def _backward_hook(
        )
        multipliers = tuple(
            SUPPORTED_NON_LINEAR[type(module)](
-                module, module.input, module.output, grad_input, grad_output, eps=eps
+                module,
+                module.input,
+                module.output,
+                grad_input,
+                grad_output,
+                eps=self.eps,
            )
        )
        # remove all the properties that we set for the inputs and output
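For context, `eps` guards the DeepLIFT multiplier computation inside the `SUPPORTED_NON_LINEAR` rules. Schematically, it works as in the simplified sketch below (the function name and the flattened `delta_in`/`delta_out` arguments are illustrative, not Captum's exact signature):

```python
import torch

def nonlinear_multiplier_sketch(delta_in, delta_out, grad_input, grad_output,
                                eps=1e-10):
    # The DeepLIFT multiplier is delta_out / delta_in. Where the input barely
    # changed (|delta_in| < eps) that ratio is numerically unstable, so the
    # rule falls back to the ordinary gradient for those elements.
    return torch.where(
        delta_in.abs() < eps,
        grad_input,
        grad_output * delta_out / delta_in,
    )
```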