@@ -76,7 +76,7 @@ def _get_context():
76
76
77
77
78
78
VisualizationOutput = namedtuple (
79
- "VisualizationOutput" , "feature_outputs actual predicted active_index"
79
+ "VisualizationOutput" , "feature_outputs actual predicted active_index model_index "
80
80
)
81
81
Contribution = namedtuple ("Contribution" , "name percent" )
82
82
SampleCache = namedtuple ("SampleCache" , "inputs additional_forward_args label" )
@@ -149,11 +149,8 @@ def __init__(
149
149
r"""
150
150
Args:
151
151
152
- models (torch.nn.module): PyTorch module (model ) for attribution
152
+ models (torch.nn.module): One or more PyTorch modules (models ) for attribution
153
153
visualization.
154
- We plan to support visualizing and comparing multiple models
155
- in the future, but currently this supports only a single
156
- model.
157
154
classes (list of string): List of strings corresponding to the names of
158
155
classes for classification.
159
156
features (list of BaseFeature): List of BaseFeatures, which correspond
@@ -195,6 +192,7 @@ class scores.
195
192
self .classes = classes
196
193
self .features = features
197
194
self .dataset = dataset
195
+ self .models = models
198
196
self .attribution_calculation = AttributionCalculation (
199
197
models , classes , features , score_func , use_label_for_attr
200
198
)
@@ -203,12 +201,16 @@ class scores.
203
201
self ._dataset_iter = iter (dataset )
204
202
205
203
def _calculate_attribution_from_cache (
206
- self , index : int , target : Optional [Tensor ]
204
+ self , input_index : int , model_index : int , target : Optional [Tensor ]
207
205
) -> Optional [VisualizationOutput ]:
208
- c = self ._outputs [index ][1 ]
206
+ c = self ._outputs [input_index ][1 ]
209
207
return self ._calculate_vis_output (
210
- c .inputs , c .additional_forward_args , c .label , torch .tensor (target )
211
- )
208
+ c .inputs ,
209
+ c .additional_forward_args ,
210
+ c .label ,
211
+ torch .tensor (target ),
212
+ model_index ,
213
+ )[0 ]
212
214
213
215
def _update_config (self , settings ):
214
216
self ._config = FilterConfig (
@@ -344,67 +346,97 @@ def _should_keep_prediction(
344
346
return True
345
347
346
348
def _calculate_vis_output (
347
- self , inputs , additional_forward_args , label , target = None
348
- ) -> Optional [VisualizationOutput ]:
349
- actual_label_output = None
350
- if label is not None and len (label ) > 0 :
351
- label_index = int (label [0 ])
352
- actual_label_output = OutputScore (
353
- score = 100 , index = label_index , label = self .classes [label_index ]
354
- )
355
-
356
- (
357
- predicted_scores ,
358
- baselines ,
359
- transformed_inputs ,
360
- ) = self .attribution_calculation .calculate_predicted_scores (
361
- inputs , additional_forward_args
349
+ self ,
350
+ inputs ,
351
+ additional_forward_args ,
352
+ label ,
353
+ target = None ,
354
+ single_model_index = None ,
355
+ ) -> Optional [List [VisualizationOutput ]]:
356
+ # Use all models, unless the user wants to render data for a particular one
357
+ models_used = (
358
+ [self .models [single_model_index ]]
359
+ if single_model_index is not None
360
+ else self .models
362
361
)
362
+ results = []
363
+ for model_index , model in enumerate (models_used ):
364
+ # Get list of model visualizations for each input
365
+ actual_label_output = None
366
+ if label is not None and len (label ) > 0 :
367
+ label_index = int (label [0 ])
368
+ actual_label_output = OutputScore (
369
+ score = 100 , index = label_index , label = self .classes [label_index ]
370
+ )
371
+
372
+ (
373
+ predicted_scores ,
374
+ baselines ,
375
+ transformed_inputs ,
376
+ ) = self .attribution_calculation .calculate_predicted_scores (
377
+ inputs , additional_forward_args , model
378
+ )
363
379
364
- # Filter based on UI configuration
365
- if actual_label_output is None or not self ._should_keep_prediction (
366
- predicted_scores , actual_label_output
367
- ):
368
- return None
369
-
370
- if target is None :
371
- target = predicted_scores [0 ].index if len (predicted_scores ) > 0 else None
372
-
373
- # attributions are given per input*
374
- # inputs given to the model are described via `self.features`
375
- #
376
- # *an input contains multiple features that represent it
377
- # e.g. all the pixels that describe an image is an input
378
-
379
- attrs_per_input_feature = self .attribution_calculation .calculate_attribution (
380
- baselines ,
381
- transformed_inputs ,
382
- additional_forward_args ,
383
- target ,
384
- self ._config .attribution_method ,
385
- self ._config .attribution_arguments ,
386
- )
380
+ # Filter based on UI configuration
381
+ if actual_label_output is None or not self ._should_keep_prediction (
382
+ predicted_scores , actual_label_output
383
+ ):
384
+ continue
385
+
386
+ if target is None :
387
+ target = (
388
+ predicted_scores [0 ].index if len (predicted_scores ) > 0 else None
389
+ )
390
+
391
+ # attributions are given per input*
392
+ # inputs given to the model are described via `self.features`
393
+ #
394
+ # *an input contains multiple features that represent it
395
+ # e.g. all the pixels that describe an image is an input
396
+
397
+ attrs_per_input_feature = (
398
+ self .attribution_calculation .calculate_attribution (
399
+ baselines ,
400
+ transformed_inputs ,
401
+ additional_forward_args ,
402
+ target ,
403
+ self ._config .attribution_method ,
404
+ self ._config .attribution_arguments ,
405
+ model ,
406
+ )
407
+ )
387
408
388
- net_contrib = self .attribution_calculation .calculate_net_contrib (
389
- attrs_per_input_feature
390
- )
409
+ net_contrib = self .attribution_calculation .calculate_net_contrib (
410
+ attrs_per_input_feature
411
+ )
391
412
392
- # the features per input given
393
- features_per_input = [
394
- feature .visualize (attr , data , contrib )
395
- for feature , attr , data , contrib in zip (
396
- self .features , attrs_per_input_feature , inputs , net_contrib
413
+ # the features per input given
414
+ features_per_input = [
415
+ feature .visualize (attr , data , contrib )
416
+ for feature , attr , data , contrib in zip (
417
+ self .features , attrs_per_input_feature , inputs , net_contrib
418
+ )
419
+ ]
420
+
421
+ results .append (
422
+ VisualizationOutput (
423
+ feature_outputs = features_per_input ,
424
+ actual = actual_label_output ,
425
+ predicted = predicted_scores ,
426
+ active_index = target
427
+ if target is not None
428
+ else actual_label_output .index ,
429
+ # Even if we only iterated over one model, the index should be fixed
430
+ # to show the index the model would have had in the list
431
+ model_index = single_model_index
432
+ if single_model_index is not None
433
+ else model_index ,
434
+ )
397
435
)
398
- ]
399
436
400
- return VisualizationOutput (
401
- feature_outputs = features_per_input ,
402
- actual = actual_label_output ,
403
- predicted = predicted_scores ,
404
- active_index = target if target is not None else actual_label_output .index ,
405
- )
437
+ return results if results else None
406
438
407
- def _get_outputs (self ) -> List [Tuple [VisualizationOutput , SampleCache ]]:
439
+ def _get_outputs (self ) -> List [Tuple [List [ VisualizationOutput ] , SampleCache ]]:
408
440
batch_data = next (self ._dataset_iter )
409
441
vis_outputs = []
410
442
0 commit comments