Skip to content

Fix the logging of a nested dictionary metric in MLflow #8169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions monai/handlers/mlflow_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch.utils.data import Dataset

from monai.apps.utils import get_logger
from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, min_version, optional_import
from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, flatten_dict, min_version, optional_import

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.")
Expand Down Expand Up @@ -303,7 +303,9 @@ def _log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None

run_id = self.cur_run.info.run_id
timestamp = int(time.time() * 1000)
metrics_arr = [mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in metrics.items()]
metrics_arr = [
mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in flatten_dict(metrics).items()
]
self.client.log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[])

def _parse_artifacts(self):
Expand Down
5 changes: 2 additions & 3 deletions monai/handlers/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch

from monai.apps import get_logger
from monai.utils import IgniteInfo, is_scalar, min_version, optional_import
from monai.utils import IgniteInfo, flatten_dict, is_scalar, min_version, optional_import

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
Expand Down Expand Up @@ -211,8 +211,7 @@ def _default_epoch_print(self, engine: Engine) -> None:

"""
current_epoch = self.global_epoch_transform(engine.state.epoch)

prints_dict = engine.state.metrics
prints_dict = flatten_dict(engine.state.metrics)
if prints_dict is not None and len(prints_dict) > 0:
out_str = f"Epoch[{current_epoch}] Metrics -- "
for name in sorted(prints_dict):
Expand Down
15 changes: 8 additions & 7 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from __future__ import annotations

import io
import os
import re
import tempfile
import warnings
from collections import OrderedDict
from collections.abc import Callable, Mapping, Sequence
Expand Down Expand Up @@ -688,27 +690,26 @@ def convert_to_onnx(
onnx_inputs = (inputs,)
else:
onnx_inputs = tuple(inputs)

if filename is None:
f = io.BytesIO()
temp_file = tempfile.NamedTemporaryFile(delete=False)
f = f"{temp_file.name}/model.onnx"
else:
f = filename

torch.onnx.export(
mode_to_export,
onnx_inputs,
f=f, # type: ignore[arg-type]
f=f,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=opset_version,
do_constant_folding=do_constant_folding,
**torch_versioned_kwargs,
)
if filename is None:
onnx_model = onnx.load_model_from_string(f.getvalue())
else:
onnx_model = onnx.load(filename)
onnx_model = onnx.load(f)
temp_file.close()
os.remove(temp_file.name)

if do_constant_folding and polygraphy_imported:
from polygraphy.backend.onnx.loader import fold_constants
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
ensure_tuple_size,
fall_back_tuple,
first,
flatten_dict,
get_seed,
has_option,
is_immutable,
Expand Down
13 changes: 13 additions & 0 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,3 +916,16 @@ def unsqueeze_right(arr: NT, ndim: int) -> NT:
def unsqueeze_left(arr: NT, ndim: int) -> NT:
"""Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
return arr[(None,) * (ndim - arr.ndim)]


def flatten_dict(metrics: dict[str, Any]) -> dict[str, Any]:
"""
Flatten the nested dictionary to a flat dictionary.
"""
result = {}
for key, value in metrics.items():
if isinstance(value, dict):
result.update(flatten_dict(value))
else:
result[key] = value
return result
9 changes: 8 additions & 1 deletion tests/test_handler_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ def _train_func(engine, batch):
def _update_metric(engine):
current_metric = engine.state.metrics.get("acc", 0.1)
engine.state.metrics["acc"] = current_metric + 0.1
# log nested metrics
engine.state.metrics["acc_per_label"] = {
"label_0": current_metric + 0.1,
"label_1": current_metric + 0.2,
}
engine.state.test = current_metric

# set up testing handler
Expand All @@ -138,10 +143,12 @@ def _update_metric(engine):
state_attributes=["test"],
experiment_param=experiment_param,
artifacts=[artifact_path],
close_on_complete=True,
close_on_complete=False,
)
handler.attach(engine)
engine.run(range(3), max_epochs=2)
cur_run = handler.client.get_run(handler.cur_run.info.run_id)
self.assertTrue("label_0" in cur_run.data.metrics.keys())
handler.close()
# check logging output
self.assertTrue(len(glob.glob(test_path)) > 0)
Expand Down
Loading