@@ -555,73 +555,9 @@ def attribute_future(
555
555
]
556
556
)
557
557
558
- def eval_fut_to_ablated_out_fut (
559
- # pyre-ignore Invalid type parameters [24]
560
- eval_futs : Future [List [Future [List [object ]]]],
561
- current_inputs : Tuple [Tensor , ...],
562
- current_mask : Tensor ,
563
- i : int ,
564
- perturbations_per_eval : int ,
565
- num_examples : int ,
566
- formatted_inputs : Tuple [Tensor , ...],
567
- ) -> Tuple [List [Tensor ], List [Tensor ]]:
568
- try :
569
- modified_eval = cast (Tensor , eval_futs .value ()[1 ].value ())
570
- initial_eval_tuple = cast (
571
- Tuple [
572
- List [Tensor ],
573
- List [Tensor ],
574
- Tensor ,
575
- Tensor ,
576
- int ,
577
- dtype ,
578
- ],
579
- eval_futs .value ()[0 ].value (),
580
- )
581
- if len (initial_eval_tuple ) != 6 :
582
- raise AssertionError (
583
- "eval_fut_to_ablated_out_fut: "
584
- "initial_eval_tuple should have 6 elements: "
585
- "total_attrib, weights, initial_eval, "
586
- "flattened_initial_eval, n_outputs, attrib_type "
587
- )
588
- if not isinstance (modified_eval , Tensor ):
589
- raise AssertionError (
590
- "eval_fut_to_ablated_out_fut: "
591
- "modified eval should be a Tensor"
592
- )
593
- (
594
- total_attrib ,
595
- weights ,
596
- initial_eval ,
597
- flattened_initial_eval ,
598
- n_outputs ,
599
- attrib_type ,
600
- ) = initial_eval_tuple
601
- result = self ._process_ablated_out ( # type: ignore # noqa: E501 line too long
602
- modified_eval = modified_eval ,
603
- current_inputs = current_inputs ,
604
- current_mask = current_mask ,
605
- perturbations_per_eval = perturbations_per_eval ,
606
- num_examples = num_examples ,
607
- initial_eval = initial_eval ,
608
- flattened_initial_eval = flattened_initial_eval ,
609
- inputs = formatted_inputs ,
610
- n_outputs = n_outputs ,
611
- total_attrib = total_attrib ,
612
- weights = weights ,
613
- i = i ,
614
- attrib_type = attrib_type ,
615
- )
616
- except FeatureAblationFutureError as e :
617
- raise FeatureAblationFutureError (
618
- "eval_fut_to_ablated_out_fut func failed)"
619
- ) from e
620
- return result
621
-
622
558
ablated_out_fut : Future [Tuple [List [Tensor ], List [Tensor ]]] = (
623
559
eval_futs .then (
624
- lambda eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : eval_fut_to_ablated_out_fut ( # type: ignore # noqa: E501 line too long
560
+ lambda eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : self . _eval_fut_to_ablated_out_fut ( # type: ignore # noqa: E501 line too long
625
561
eval_futs = eval_futs ,
626
562
current_inputs = current_inputs ,
627
563
current_mask = current_mask ,
@@ -660,6 +596,70 @@ def _attribute_progress_setup(
660
596
)
661
597
return attr_progress
662
598
599
+ def _eval_fut_to_ablated_out_fut (
600
+ self ,
601
+ # pyre-ignore Invalid type parameters [24]
602
+ eval_futs : Future [List [Future [List [object ]]]],
603
+ current_inputs : Tuple [Tensor , ...],
604
+ current_mask : Tensor ,
605
+ i : int ,
606
+ perturbations_per_eval : int ,
607
+ num_examples : int ,
608
+ formatted_inputs : Tuple [Tensor , ...],
609
+ ) -> Tuple [List [Tensor ], List [Tensor ]]:
610
+ try :
611
+ modified_eval = cast (Tensor , eval_futs .value ()[1 ].value ())
612
+ initial_eval_tuple = cast (
613
+ Tuple [
614
+ List [Tensor ],
615
+ List [Tensor ],
616
+ Tensor ,
617
+ Tensor ,
618
+ int ,
619
+ dtype ,
620
+ ],
621
+ eval_futs .value ()[0 ].value (),
622
+ )
623
+ if len (initial_eval_tuple ) != 6 :
624
+ raise AssertionError (
625
+ "eval_fut_to_ablated_out_fut: "
626
+ "initial_eval_tuple should have 6 elements: "
627
+ "total_attrib, weights, initial_eval, "
628
+ "flattened_initial_eval, n_outputs, attrib_type "
629
+ )
630
+ if not isinstance (modified_eval , Tensor ):
631
+ raise AssertionError (
632
+ "eval_fut_to_ablated_out_fut: " "modified eval should be a Tensor"
633
+ )
634
+ (
635
+ total_attrib ,
636
+ weights ,
637
+ initial_eval ,
638
+ flattened_initial_eval ,
639
+ n_outputs ,
640
+ attrib_type ,
641
+ ) = initial_eval_tuple
642
+ result = self ._process_ablated_out ( # type: ignore # noqa: E501 line too long
643
+ modified_eval = modified_eval ,
644
+ current_inputs = current_inputs ,
645
+ current_mask = current_mask ,
646
+ perturbations_per_eval = perturbations_per_eval ,
647
+ num_examples = num_examples ,
648
+ initial_eval = initial_eval ,
649
+ flattened_initial_eval = flattened_initial_eval ,
650
+ inputs = formatted_inputs ,
651
+ n_outputs = n_outputs ,
652
+ total_attrib = total_attrib ,
653
+ weights = weights ,
654
+ i = i ,
655
+ attrib_type = attrib_type ,
656
+ )
657
+ except FeatureAblationFutureError as e :
658
+ raise FeatureAblationFutureError (
659
+ "eval_fut_to_ablated_out_fut func failed)"
660
+ ) from e
661
+ return result
662
+
663
663
# pyre-fixme[3]: Return type must be specified as type that does not contain `Any`
664
664
def _ith_input_ablation_generator (
665
665
self ,
0 commit comments