Skip to content

Commit 9076f54

Browse files
committed
run linter
Signed-off-by: Marta Stepniewska-Dziubinska <martas@nvidia.com>
1 parent 6074973 commit 9076f54

File tree

2 files changed

+36
-8
lines changed
  • packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/exporters

2 files changed

+36
-8
lines changed

packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/exporters/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,11 @@ def prepare_data_for_export(self, job_data: JobData) -> DataForExport | None:
352352
return None
353353

354354
try:
355-
metrics = extract_accuracy_metrics(artifacts_dir, metric_sep=self.config.metric_sep, include_task_name=self.config.include_task_name)
355+
metrics = extract_accuracy_metrics(
356+
artifacts_dir,
357+
metric_sep=self.config.metric_sep,
358+
include_task_name=self.config.include_task_name,
359+
)
356360
harness, task = load_benchmark_info(artifacts_dir)
357361
container = job_data.data.get("eval_image", None)
358362
model_id = get_model_id(artifacts_dir)

packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/exporters/utils.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,10 @@ class MetricConflictError(Exception):
143143

144144

145145
def extract_accuracy_metrics(
146-
artifacts_dir: Path, log_metrics: List[str] = None, metric_sep: str = "_", include_task_name: bool = True
146+
artifacts_dir: Path,
147+
log_metrics: List[str] = None,
148+
metric_sep: str = "_",
149+
include_task_name: bool = True,
147150
) -> Dict[str, float]:
148151
"""Extract accuracy metrics from job results.
149152
artifacts_dir: Path to the artifacts directory
@@ -155,7 +158,11 @@ def extract_accuracy_metrics(
155158
if not artifacts_dir or not artifacts_dir.exists():
156159
raise RuntimeError(f"Artifacts directory {artifacts_dir} not found")
157160

158-
metrics = _extract_from_results_yml(artifacts_dir / RESULTS_FILE, metric_sep=metric_sep, include_task_name=include_task_name)
161+
metrics = _extract_from_results_yml(
162+
artifacts_dir / RESULTS_FILE,
163+
metric_sep=metric_sep,
164+
include_task_name=include_task_name,
165+
)
159166
if not log_metrics:
160167
return metrics
161168

@@ -559,14 +566,21 @@ def copy_artifacts(
559566
# =============================================================================
560567

561568

562-
def _extract_metrics_from_results(results: dict, metric_sep: str = "_", include_task_name: bool = True) -> Dict[str, float]:
569+
def _extract_metrics_from_results(
570+
results: dict, metric_sep: str = "_", include_task_name: bool = True
571+
) -> Dict[str, float]:
563572
"""Extract metrics from a 'results' dict (with optional 'groups'/'tasks')."""
564573
metrics: Dict[str, float] = {}
565574
for section in ["groups", "tasks"]:
566575
section_data = results.get(section)
567576
if isinstance(section_data, dict):
568577
for task_name, task_data in section_data.items():
569-
task_metrics = _extract_task_metrics(task_name, task_data, metric_sep=metric_sep, include_task_name=include_task_name)
578+
task_metrics = _extract_task_metrics(
579+
task_name,
580+
task_data,
581+
metric_sep=metric_sep,
582+
include_task_name=include_task_name,
583+
)
570584
_safe_update_metrics(
571585
target=metrics,
572586
source=task_metrics,
@@ -591,18 +605,28 @@ def _extract_from_results_yml(
591605
if "results" not in data:
592606
raise ValueError(f"Failed to parse {results_yml} - no results section found")
593607

594-
return _extract_metrics_from_results(data["results"], metric_sep=metric_sep, include_task_name=include_task_name)
608+
return _extract_metrics_from_results(
609+
data["results"], metric_sep=metric_sep, include_task_name=include_task_name
610+
)
595611

596612

597-
def _extract_task_metrics(task_name: str, task_data: dict, metric_sep: str = "_", include_task_name: bool = True) -> Dict[str, float]:
613+
def _extract_task_metrics(
614+
task_name: str,
615+
task_data: dict,
616+
metric_sep: str = "_",
617+
include_task_name: bool = True,
618+
) -> Dict[str, float]:
598619
"""Extract metrics from a task's metrics data."""
599620
extracted = {}
600621

601622
metrics_data = task_data.get("metrics", {})
602623
if "groups" in task_data and task_data["groups"] is not None:
603624
for group_name, group_data in task_data["groups"].items():
604625
group_extracted = _extract_task_metrics(
605-
f"{task_name}{metric_sep}{group_name}" if include_task_name else group_name, group_data
626+
f"{task_name}{metric_sep}{group_name}"
627+
if include_task_name
628+
else group_name,
629+
group_data,
606630
)
607631
_safe_update_metrics(
608632
target=extracted,

0 commit comments

Comments
 (0)