Commit 692053a

craymichael authored and facebook-github-bot committed
Reduce complexity of FeatureAblation.attribute_future (pytorch#1368)
Summary: Reduce complexity of FeatureAblation.attribute_future by refactoring the inner function into a private method.

Reviewed By: cyrjano

Differential Revision: D64361191
1 parent fd758e0 · commit 692053a

File tree

1 file changed: +65 lines, -65 lines


captum/attr/_core/feature_ablation.py

Lines changed: 65 additions & 65 deletions
@@ -555,73 +555,9 @@ def attribute_future(
                     ]
                 )

-                def eval_fut_to_ablated_out_fut(
-                    # pyre-ignore Invalid type parameters [24]
-                    eval_futs: Future[List[Future[List[object]]]],
-                    current_inputs: Tuple[Tensor, ...],
-                    current_mask: Tensor,
-                    i: int,
-                    perturbations_per_eval: int,
-                    num_examples: int,
-                    formatted_inputs: Tuple[Tensor, ...],
-                ) -> Tuple[List[Tensor], List[Tensor]]:
-                    try:
-                        modified_eval = cast(Tensor, eval_futs.value()[1].value())
-                        initial_eval_tuple = cast(
-                            Tuple[
-                                List[Tensor],
-                                List[Tensor],
-                                Tensor,
-                                Tensor,
-                                int,
-                                dtype,
-                            ],
-                            eval_futs.value()[0].value(),
-                        )
-                        if len(initial_eval_tuple) != 6:
-                            raise AssertionError(
-                                "eval_fut_to_ablated_out_fut: "
-                                "initial_eval_tuple should have 6 elements: "
-                                "total_attrib, weights, initial_eval, "
-                                "flattened_initial_eval, n_outputs, attrib_type "
-                            )
-                        if not isinstance(modified_eval, Tensor):
-                            raise AssertionError(
-                                "eval_fut_to_ablated_out_fut: "
-                                "modified eval should be a Tensor"
-                            )
-                        (
-                            total_attrib,
-                            weights,
-                            initial_eval,
-                            flattened_initial_eval,
-                            n_outputs,
-                            attrib_type,
-                        ) = initial_eval_tuple
-                        result = self._process_ablated_out(  # type: ignore # noqa: E501 line too long
-                            modified_eval=modified_eval,
-                            current_inputs=current_inputs,
-                            current_mask=current_mask,
-                            perturbations_per_eval=perturbations_per_eval,
-                            num_examples=num_examples,
-                            initial_eval=initial_eval,
-                            flattened_initial_eval=flattened_initial_eval,
-                            inputs=formatted_inputs,
-                            n_outputs=n_outputs,
-                            total_attrib=total_attrib,
-                            weights=weights,
-                            i=i,
-                            attrib_type=attrib_type,
-                        )
-                    except FeatureAblationFutureError as e:
-                        raise FeatureAblationFutureError(
-                            "eval_fut_to_ablated_out_fut func failed)"
-                        ) from e
-                    return result
-
                 ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
                     eval_futs.then(
-                        lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
+                        lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
                             eval_futs=eval_futs,
                             current_inputs=current_inputs,
                             current_mask=current_mask,
@@ -660,6 +596,70 @@ def _attribute_progress_setup(
         )
         return attr_progress

+    def _eval_fut_to_ablated_out_fut(
+        self,
+        # pyre-ignore Invalid type parameters [24]
+        eval_futs: Future[List[Future[List[object]]]],
+        current_inputs: Tuple[Tensor, ...],
+        current_mask: Tensor,
+        i: int,
+        perturbations_per_eval: int,
+        num_examples: int,
+        formatted_inputs: Tuple[Tensor, ...],
+    ) -> Tuple[List[Tensor], List[Tensor]]:
+        try:
+            modified_eval = cast(Tensor, eval_futs.value()[1].value())
+            initial_eval_tuple = cast(
+                Tuple[
+                    List[Tensor],
+                    List[Tensor],
+                    Tensor,
+                    Tensor,
+                    int,
+                    dtype,
+                ],
+                eval_futs.value()[0].value(),
+            )
+            if len(initial_eval_tuple) != 6:
+                raise AssertionError(
+                    "eval_fut_to_ablated_out_fut: "
+                    "initial_eval_tuple should have 6 elements: "
+                    "total_attrib, weights, initial_eval, "
+                    "flattened_initial_eval, n_outputs, attrib_type "
+                )
+            if not isinstance(modified_eval, Tensor):
+                raise AssertionError(
+                    "eval_fut_to_ablated_out_fut: " "modified eval should be a Tensor"
+                )
+            (
+                total_attrib,
+                weights,
+                initial_eval,
+                flattened_initial_eval,
+                n_outputs,
+                attrib_type,
+            ) = initial_eval_tuple
+            result = self._process_ablated_out(  # type: ignore # noqa: E501 line too long
+                modified_eval=modified_eval,
+                current_inputs=current_inputs,
+                current_mask=current_mask,
+                perturbations_per_eval=perturbations_per_eval,
+                num_examples=num_examples,
+                initial_eval=initial_eval,
+                flattened_initial_eval=flattened_initial_eval,
+                inputs=formatted_inputs,
+                n_outputs=n_outputs,
+                total_attrib=total_attrib,
+                weights=weights,
+                i=i,
+                attrib_type=attrib_type,
+            )
+        except FeatureAblationFutureError as e:
+            raise FeatureAblationFutureError(
+                "eval_fut_to_ablated_out_fut func failed)"
+            ) from e
+        return result
+
     # pyre-fixme[3]: Return type must be specified as type that does not contain `Any`
     def _ith_input_ablation_generator(
         self,
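
The change is purely structural: the closure previously passed to eval_futs.then(...) now lives on the class as _eval_fut_to_ablated_out_fut, and the lambda only forwards its captured arguments. Below is a minimal, hypothetical sketch of that pattern with torch.futures; the Ablator class, the _to_output helper, and the tensor arithmetic are illustrative assumptions, not Captum's actual API.

# Sketch of the refactoring pattern in this commit: a nested closure used as a
# Future.then callback is hoisted into a private method, keeping the public
# method short. Class and helper names here are hypothetical.
from typing import List

import torch
from torch.futures import Future, collect_all


class Ablator:
    def attribute_future(self, evals: List[Future]) -> Future:
        # Wait for all sub-futures, then post-process them in a bound method
        # rather than in a long inline closure defined at this point.
        gathered: Future[List[Future]] = collect_all(evals)
        return gathered.then(lambda futs: self._to_output(futs))

    def _to_output(self, futs: Future[List[Future]]) -> torch.Tensor:
        # Unpack the completed sub-futures and combine their values.
        initial, modified = (f.value() for f in futs.value())
        return modified - initial


# Usage: the sub-futures would normally be resolved by asynchronous forward passes.
f1: Future = Future()
f2: Future = Future()
out = Ablator().attribute_future([f1, f2])
f1.set_result(torch.zeros(3))
f2.set_result(torch.ones(3))
print(out.wait())  # tensor([1., 1., 1.])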
