@@ -143,7 +143,10 @@ class MetricConflictError(Exception):
143143
144144
145145def extract_accuracy_metrics (
146- artifacts_dir : Path , log_metrics : List [str ] = None , metric_sep : str = "_" , include_task_name : bool = True
146+ artifacts_dir : Path ,
147+ log_metrics : List [str ] = None ,
148+ metric_sep : str = "_" ,
149+ include_task_name : bool = True ,
147150) -> Dict [str , float ]:
148151 """Extract accuracy metrics from job results.
149152 artifacts_dir: Path to the artifacts directory
@@ -155,7 +158,11 @@ def extract_accuracy_metrics(
155158 if not artifacts_dir or not artifacts_dir .exists ():
156159 raise RuntimeError (f"Artifacts directory { artifacts_dir } not found" )
157160
158- metrics = _extract_from_results_yml (artifacts_dir / RESULTS_FILE , metric_sep = metric_sep , include_task_name = include_task_name )
161+ metrics = _extract_from_results_yml (
162+ artifacts_dir / RESULTS_FILE ,
163+ metric_sep = metric_sep ,
164+ include_task_name = include_task_name ,
165+ )
159166 if not log_metrics :
160167 return metrics
161168
@@ -559,14 +566,21 @@ def copy_artifacts(
559566# =============================================================================
560567
561568
562- def _extract_metrics_from_results (results : dict , metric_sep : str = "_" , include_task_name : bool = True ) -> Dict [str , float ]:
569+ def _extract_metrics_from_results (
570+ results : dict , metric_sep : str = "_" , include_task_name : bool = True
571+ ) -> Dict [str , float ]:
563572 """Extract metrics from a 'results' dict (with optional 'groups'/'tasks')."""
564573 metrics : Dict [str , float ] = {}
565574 for section in ["groups" , "tasks" ]:
566575 section_data = results .get (section )
567576 if isinstance (section_data , dict ):
568577 for task_name , task_data in section_data .items ():
569- task_metrics = _extract_task_metrics (task_name , task_data , metric_sep = metric_sep , include_task_name = include_task_name )
578+ task_metrics = _extract_task_metrics (
579+ task_name ,
580+ task_data ,
581+ metric_sep = metric_sep ,
582+ include_task_name = include_task_name ,
583+ )
570584 _safe_update_metrics (
571585 target = metrics ,
572586 source = task_metrics ,
@@ -591,18 +605,28 @@ def _extract_from_results_yml(
591605 if "results" not in data :
592606 raise ValueError (f"Failed to parse { results_yml } - no results section found" )
593607
594- return _extract_metrics_from_results (data ["results" ], metric_sep = metric_sep , include_task_name = include_task_name )
608+ return _extract_metrics_from_results (
609+ data ["results" ], metric_sep = metric_sep , include_task_name = include_task_name
610+ )
595611
596612
597- def _extract_task_metrics (task_name : str , task_data : dict , metric_sep : str = "_" , include_task_name : bool = True ) -> Dict [str , float ]:
613+ def _extract_task_metrics (
614+ task_name : str ,
615+ task_data : dict ,
616+ metric_sep : str = "_" ,
617+ include_task_name : bool = True ,
618+ ) -> Dict [str , float ]:
598619 """Extract metrics from a task's metrics data."""
599620 extracted = {}
600621
601622 metrics_data = task_data .get ("metrics" , {})
602623 if "groups" in task_data and task_data ["groups" ] is not None :
603624 for group_name , group_data in task_data ["groups" ].items ():
604625 group_extracted = _extract_task_metrics (
605- f"{ task_name } { metric_sep } { group_name } " if include_task_name else group_name , group_data
626+ f"{ task_name } { metric_sep } { group_name } "
627+ if include_task_name
628+ else group_name ,
629+ group_data ,
606630 )
607631 _safe_update_metrics (
608632 target = extracted ,
0 commit comments