From 743887868a74dd3e3f4487a01884bd8cff54239f Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:32:37 +0800 Subject: [PATCH 1/7] fix model-zoo#697 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/handlers/mlflow_handler.py | 6 ++++-- monai/handlers/stats_handler.py | 5 ++--- monai/utils/__init__.py | 1 + monai/utils/misc.py | 13 +++++++++++++ 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index c7e293ea7d..3078d89f97 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -22,7 +22,7 @@ from torch.utils.data import Dataset from monai.apps.utils import get_logger -from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, min_version, optional_import +from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, flatten_dict, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.") @@ -303,7 +303,9 @@ def _log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None run_id = self.cur_run.info.run_id timestamp = int(time.time() * 1000) - metrics_arr = [mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in metrics.items()] + metrics_arr = [ + mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in flatten_dict(metrics).items() + ] self.client.log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[]) def _parse_artifacts(self): diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index ab36d19bd1..d294e448c1 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -19,7 +19,7 @@ import torch from monai.apps import get_logger -from monai.utils import IgniteInfo, is_scalar, min_version, optional_import +from monai.utils import IgniteInfo, flatten_dict, is_scalar, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: @@ -211,8 +211,7 @@ def _default_epoch_print(self, engine: Engine) -> None: """ current_epoch = self.global_epoch_transform(engine.state.epoch) - - prints_dict = engine.state.metrics + prints_dict = flatten_dict(engine.state.metrics) if prints_dict is not None and len(prints_dict) > 0: out_str = f"Epoch[{current_epoch}] Metrics -- " for name in sorted(prints_dict): diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 916c1a6c70..79dc1f2304 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -78,6 +78,7 @@ ensure_tuple_size, fall_back_tuple, first, + flatten_dict, get_seed, has_option, is_immutable, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 6386aae713..07b8896419 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -916,3 +916,16 @@ def unsqueeze_right(arr: NT, ndim: int) -> NT: def unsqueeze_left(arr: NT, ndim: int) -> NT: """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(None,) * (ndim - arr.ndim)] + + +def flatten_dict(metrics: dict[str, Any]) -> dict[str, Any]: + """ + Flatten the nested dictionary to a flat dictionary. + """ + result = {} + for key, value in metrics.items(): + if isinstance(value, dict): + result.update(flatten_dict(value)) + else: + result[f"{key}"] = value + return result From 9178c4afc00231e09c8d4d991188f70236fcd88e Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 24 Oct 2024 11:52:29 +0800 Subject: [PATCH 2/7] address comments and add test Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/utils/misc.py | 2 +- tests/test_handler_mlflow.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 07b8896419..d4b8af0fb1 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -927,5 +927,5 @@ def flatten_dict(metrics: dict[str, Any]) -> dict[str, Any]: if isinstance(value, dict): result.update(flatten_dict(value)) else: - result[f"{key}"] = value + result[key] = value return result diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py index 44adc49fc2..443047f95e 100644 --- a/tests/test_handler_mlflow.py +++ b/tests/test_handler_mlflow.py @@ -122,6 +122,8 @@ def _train_func(engine, batch): def _update_metric(engine): current_metric = engine.state.metrics.get("acc", 0.1) engine.state.metrics["acc"] = current_metric + 0.1 + # log nested metrics + engine.state.metrics["acc_per_label"] = {"label_0": current_metric + 0.1, "label_1": current_metric + 0.2} engine.state.test = current_metric # set up testing handler @@ -138,10 +140,12 @@ def _update_metric(engine): state_attributes=["test"], experiment_param=experiment_param, artifacts=[artifact_path], - close_on_complete=True, + close_on_complete=False, ) handler.attach(engine) engine.run(range(3), max_epochs=2) + cur_run = handler.client.get_run(handler.cur_run.info.run_id) + self.assertTrue("label_0" in cur_run.data.metrics.keys()) handler.close() # check logging output self.assertTrue(len(glob.glob(test_path)) > 0) From f98497972073cfc0a0bf104b280fdc2345afa00e Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 24 Oct 2024 11:53:57 +0800 Subject: [PATCH 3/7] minor fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_handler_mlflow.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py index 443047f95e..36d59ff1bf 100644 --- a/tests/test_handler_mlflow.py +++ b/tests/test_handler_mlflow.py @@ -123,7 +123,10 @@ def _update_metric(engine): current_metric = engine.state.metrics.get("acc", 0.1) engine.state.metrics["acc"] = current_metric + 0.1 # log nested metrics - engine.state.metrics["acc_per_label"] = {"label_0": current_metric + 0.1, "label_1": current_metric + 0.2} + engine.state.metrics["acc_per_label"] = { + "label_0": current_metric + 0.1, + "label_1": current_metric + 0.2, + } engine.state.test = current_metric # set up testing handler From ef6d1c867bddf964e6425613653213e6f5ee260d Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:26:35 +0800 Subject: [PATCH 4/7] try to fix #8149 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 295a055390..88d51408dd 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -15,7 +15,9 @@ from __future__ import annotations import io +import os import re +import tempfile import warnings from collections import OrderedDict from collections.abc import Callable, Mapping, Sequence @@ -688,16 +690,16 @@ def convert_to_onnx( onnx_inputs = (inputs,) else: onnx_inputs = tuple(inputs) - if filename is None: - f = io.BytesIO() + temp_file = tempfile.NamedTemporaryFile(delete=False) + f = f"{temp_file.name}/model.onnx" else: f = filename torch.onnx.export( mode_to_export, onnx_inputs, - f=f, # type: ignore[arg-type] + f=f, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, @@ -705,10 +707,9 @@ def convert_to_onnx( do_constant_folding=do_constant_folding, **torch_versioned_kwargs, ) - if filename is None: - onnx_model = onnx.load_model_from_string(f.getvalue()) - else: - onnx_model = onnx.load(filename) + onnx_model = onnx.load(f) + temp_file.close() + os.remove(temp_file.name) if do_constant_folding and polygraphy_imported: from polygraphy.backend.onnx.loader import fold_constants From f258416a3e43bad2d6137bbc26c78f5ee97a6890 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:31:41 +0800 Subject: [PATCH 5/7] fix ci Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 88d51408dd..14eece2c97 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -692,7 +692,7 @@ def convert_to_onnx( onnx_inputs = tuple(inputs) if filename is None: temp_file = tempfile.NamedTemporaryFile(delete=False) - f = f"{temp_file.name}/model.onnx" + f = temp_file.name else: f = filename From addce552027a3acf717739deadb69cc5670529aa Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 28 Oct 2024 10:56:14 +0800 Subject: [PATCH 6/7] address comments Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 14eece2c97..d4d0f4a3f4 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -690,8 +690,9 @@ def convert_to_onnx( onnx_inputs = (inputs,) else: onnx_inputs = tuple(inputs) + temp_file = None if filename is None: - temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file = tempfile.NamedTemporaryFile() f = temp_file.name else: f = filename @@ -708,8 +709,6 @@ def convert_to_onnx( **torch_versioned_kwargs, ) onnx_model = onnx.load(f) - temp_file.close() - os.remove(temp_file.name) if do_constant_folding and polygraphy_imported: from polygraphy.backend.onnx.loader import fold_constants From 357abd58a74dc99234cbba33127e197408c95f7e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 02:57:45 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index d4d0f4a3f4..cfad0364c3 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -15,7 +15,6 @@ from __future__ import annotations import io -import os import re import tempfile import warnings