From 13bdec5130709456abb3721450961801ab3bb4bc Mon Sep 17 00:00:00 2001 From: simben Date: Tue, 28 Jan 2025 19:16:46 +0000 Subject: [PATCH 01/67] Add nnUNet integration and corresponding unit tests --- monai/bundle/__init__.py | 1 + monai/bundle/nnunet.py | 338 ++++++++++++++++++++++++ tests/test_integration_nnunet_bundle.py | 112 ++++++++ 3 files changed, 451 insertions(+) create mode 100644 monai/bundle/nnunet.py create mode 100644 tests/test_integration_nnunet_bundle.py diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 3f3c8d545e..5d6c70784e 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -13,6 +13,7 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser +from .nnunet import get_nnunet_monai_predictor, get_nnunet_trainer, nnUNetMONAIModelWrapper from .properties import InferProperties, MetaProperties, TrainProperties from .reference_resolver import ReferenceResolver from .scripts import ( diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py new file mode 100644 index 0000000000..f268d43c58 --- /dev/null +++ b/monai/bundle/nnunet.py @@ -0,0 +1,338 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os + +import numpy as np +import torch +from torch._dynamo import OptimizedModule +from torch.backends import cudnn + +from monai.data.meta_tensor import MetaTensor +from monai.utils import optional_import + +join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") +load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") + +__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "nnUNetMONAIModelWrapper"] + + +def get_nnunet_trainer( + dataset_name_or_id, + configuration, + fold, + trainer_class_name="nnUNetTrainer", + plans_identifier="nnUNetPlans", + pretrained_weights=None, + num_gpus=1, + use_compressed_data=False, + export_validation_probabilities=False, + continue_training=False, + only_run_validation=False, + disable_checkpointing=False, + val_with_best=False, + device=torch.device("cuda"), + pretrained_model=None, +): + """ + Get the nnUNet trainer instance based on the provided configuration. + The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network, + optimizer, loss function, DataLoader, etc. 
+ + ```python + from monai.apps import SupervisedTrainer + from monai.bundle.nnunet import get_nnunet_trainer + + dataset_name_or_id = 'Task101_PROSTATE' + fold = 0 + configuration = '3d_fullres' + nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold) + + trainer = SupervisedTrainer( + device=nnunet_trainer.device, + max_epochs=nnunet_trainer.num_epochs, + train_data_loader=nnunet_trainer.dataloader_train, + network=nnunet_trainer.network, + optimizer=nnunet_trainer.optimizer, + loss_function=nnunet_trainer.loss_function, + epoch_length=nnunet_trainer.num_iterations_per_epoch, + + ``` + + Parameters + ---------- + dataset_name_or_id : Union[str, int] + The name or ID of the dataset to be used. + configuration : str + The configuration name for the training. + fold : Union[int, str] + The fold number or 'all' for cross-validation. + trainer_class_name : str, optional + The class name of the trainer to be used. Default is 'nnUNetTrainer'. + plans_identifier : str, optional + Identifier for the plans to be used. Default is 'nnUNetPlans'. + pretrained_weights : str, optional + Path to the pretrained weights file. + num_gpus : int, optional + Number of GPUs to be used. Default is 1. + use_compressed_data : bool, optional + Whether to use compressed data. Default is False. + export_validation_probabilities : bool, optional + Whether to export validation probabilities. Default is False. + continue_training : bool, optional + Whether to continue training from a checkpoint. Default is False. + only_run_validation : bool, optional + Whether to only run validation. Default is False. + disable_checkpointing : bool, optional + Whether to disable checkpointing. Default is False. + val_with_best : bool, optional + Whether to validate with the best model. Default is False. + device : torch.device, optional + The device to be used for training. Default is 'cuda'. + pretrained_model : str, optional + Path to the pretrained model file. + Returns + ------- + nnunet_trainer + The nnUNet trainer instance. + """ + # From nnUNet/nnunetv2/run/run_training.py#run_training + if isinstance(fold, str): + if fold != "all": + try: + fold = int(fold) + except ValueError as e: + print( + f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!' + ) + raise e + + if int(num_gpus) > 1: + ... # Disable for now + else: + from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint + + nnunet_trainer = get_trainer_from_args( + str(dataset_name_or_id), + configuration, + fold, + trainer_class_name, + plans_identifier, + use_compressed_data, + device=device, + ) + if disable_checkpointing: + nnunet_trainer.disable_checkpointing = disable_checkpointing + + assert not (continue_training and only_run_validation), "Cannot set --c and --val flag at the same time. Dummy." + + maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights) + nnunet_trainer.on_train_start() # Added to Initialize Trainer + if torch.cuda.is_available(): + cudnn.deterministic = False + cudnn.benchmark = True + + if pretrained_model is not None: + state_dict = torch.load(pretrained_model) + if "network_weights" in state_dict: + nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"]) + return nnunet_trainer + + +class nnUNetMONAIModelWrapper(torch.nn.Module): + """ + A wrapper class for nnUNet model integration with MONAI framework. 
+ The wrapper can be use to integrate the nnUNet Bundle within MONAI framework for inference. + + Parameters + ---------- + predictor : object + The nnUNet predictor object used for inference. + model_folder : str + The folder path where the model and related files are stored. + model_name : str, optional + The name of the model file, by default "model.pt". + Attributes + ---------- + predictor : object + The predictor object used for inference. + network_weights : torch.nn.Module + The network weights of the model. + Methods + ------- + forward(x) + Perform forward pass and prediction on the input data. + Notes + ----- + This class integrates nnUNet model with MONAI framework by loading necessary configurations, + restoring network architecture, and setting up the predictor for inference. + """ + + def __init__(self, predictor, model_folder, model_name="model.pt"): + super().__init__() + self.predictor = predictor + + model_training_output_dir = model_folder + use_folds = "0" + + from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + + ## Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor + dataset_json = load_json(join(model_training_output_dir, "dataset.json")) + plans = load_json(join(model_training_output_dir, "plans.json")) + plans_manager = PlansManager(plans) + + if isinstance(use_folds, str): + use_folds = [use_folds] + + parameters = [] + for i, f in enumerate(use_folds): + f = int(f) if f != "all" else f + checkpoint = torch.load( + join(model_training_output_dir, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") + ) + monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu")) + if i == 0: + trainer_name = checkpoint["trainer_name"] + configuration_name = checkpoint["init_args"]["configuration"] + inference_allowed_mirroring_axes = ( + checkpoint["inference_allowed_mirroring_axes"] + if "inference_allowed_mirroring_axes" in checkpoint.keys() + else None + ) + + parameters.append(monai_checkpoint["network_weights"]) + + configuration_manager = plans_manager.get_configuration(configuration_name) + # restore network + import nnunetv2 + from nnunetv2.utilities.find_class_by_name import recursive_find_python_class + from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels + + num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) + trainer_class = recursive_find_python_class( + join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, "nnunetv2.training.nnUNetTrainer" + ) + if trainer_class is None: + raise RuntimeError( + f"Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. " + f"Please place it there (in any .py file)!" 
+ ) + network = trainer_class.build_network_architecture( + configuration_manager.network_arch_class_name, + configuration_manager.network_arch_init_kwargs, + configuration_manager.network_arch_init_kwargs_req_import, + num_input_channels, + plans_manager.get_label_manager(dataset_json).num_segmentation_heads, + enable_deep_supervision=False, + ) + + predictor.plans_manager = plans_manager + predictor.configuration_manager = configuration_manager + predictor.list_of_parameters = parameters + predictor.network = network + predictor.dataset_json = dataset_json + predictor.trainer_name = trainer_name + predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes + predictor.label_manager = plans_manager.get_label_manager(dataset_json) + if ( + ("nnUNet_compile" in os.environ.keys()) + and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")) + and not isinstance(predictor.network, OptimizedModule) + ): + print("Using torch.compile") + predictor.network = torch.compile(self.network) + ## End Block + self.network_weights = self.predictor.network + + def forward(self, x): + if type(x) is tuple: + input_files = [img.meta["filename_or_obj"][0] for img in x] + else: + input_files = x.meta["filename_or_obj"] + if type(input_files) is str: + input_files = [input_files] + + output = self.predictor.predict_from_files( + [input_files], + None, + save_probabilities=False, + overwrite=True, + num_processes_preprocessing=2, + num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, + num_parts=1, + part_id=0, + ) + + out_tensors = [] + for out in output: + out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0))) + out_tensor = torch.cat(out_tensors, 0) + + if type(x) is tuple: + return MetaTensor(out_tensor, meta=x[0].meta) + else: + return MetaTensor(out_tensor, meta=x.meta) + + +def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): + """ + Initializes and returns a nnUNetMONAIModelWrapper with a nnUNetPredictor. + The model folder should contain the following files, created during training: + - dataset.json: from the nnUNet results folder. + - plans.json: from the nnUNet results folder. + - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration + (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`). + - model.pt: The checkpoint file containing the model weights. + + The returned wrapper object can be used for inference with MONAI framework: + ```python + from monai.bundle.nnunet import get_nnunet_monai_predictor + + model_folder = 'path/to/monai_bundle/model' + model_name = 'model.pt' + wrapper = get_nnunet_monai_predictor(model_folder, model_name) + + # Perform inference + input_data = ... + output = wrapper(input_data) + + ``` + + Parameters + ---------- + model_folder : str + The folder where the model is stored. + model_name : str, optional + The name of the model file, by default "model.pt". + + Returns + ------- + nnUNetMONAIModelWrapper + A wrapper object that contains the nnUNetPredictor and the loaded model. 
+ """ + + from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor + + predictor = nnUNetPredictor( + tile_step_size=0.5, + use_gaussian=True, + use_mirroring=False, + device=torch.device("cuda", 0), + verbose=False, + verbose_preprocessing=False, + allow_tqdm=True, + ) + # initializes the network architecture, loads the checkpoint + wrapper = nnUNetMONAIModelWrapper(predictor, model_folder, model_name) + return wrapper diff --git a/tests/test_integration_nnunet_bundle.py b/tests/test_integration_nnunet_bundle.py new file mode 100644 index 0000000000..b177a697c5 --- /dev/null +++ b/tests/test_integration_nnunet_bundle.py @@ -0,0 +1,112 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import tempfile +import unittest + +import nibabel as nib +import numpy as np + +from monai.apps.nnunet import nnUNetV2Runner +from monai.bundle.nnunet import get_nnunet_trainer +from monai.bundle.config_parser import ConfigParser +from monai.data import create_test_image_3d +from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick + +_, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter") +_, has_nnunet = optional_import("nnunetv2") + +sim_datalist: dict[str, list[dict]] = { + "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}], + "training": [ + {"fold": 0, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"}, + {"fold": 0, "image": "tr_image_002.fake.nii.gz", "label": "tr_label_002.fake.nii.gz"}, + {"fold": 1, "image": "tr_image_003.fake.nii.gz", "label": "tr_label_003.fake.nii.gz"}, + {"fold": 1, "image": "tr_image_004.fake.nii.gz", "label": "tr_label_004.fake.nii.gz"}, + {"fold": 2, "image": "tr_image_005.fake.nii.gz", "label": "tr_label_005.fake.nii.gz"}, + {"fold": 2, "image": "tr_image_006.fake.nii.gz", "label": "tr_label_006.fake.nii.gz"}, + {"fold": 3, "image": "tr_image_007.fake.nii.gz", "label": "tr_label_007.fake.nii.gz"}, + {"fold": 3, "image": "tr_image_008.fake.nii.gz", "label": "tr_label_008.fake.nii.gz"}, + {"fold": 4, "image": "tr_image_009.fake.nii.gz", "label": "tr_label_009.fake.nii.gz"}, + {"fold": 4, "image": "tr_image_010.fake.nii.gz", "label": "tr_label_010.fake.nii.gz"}, + ], +} + + +@skip_if_quick +@SkipIfBeforePyTorchVersion((1, 13, 0)) +@unittest.skipIf(not has_tb, "no tensorboard summary writer") +@unittest.skipIf(not has_nnunet, "no nnunetv2") +class TestnnUNetBundle(unittest.TestCase): + + def setUp(self) -> None: + self.test_dir = tempfile.TemporaryDirectory() + test_path = self.test_dir.name + + sim_dataroot = os.path.join(test_path, "dataroot") + if not os.path.isdir(sim_dataroot): + os.makedirs(sim_dataroot) + + # Generate a fake dataset + for d in sim_datalist["testing"] + sim_datalist["training"]: + im, seg = create_test_image_3d(24, 24, 24, rad_max=10, num_seg_classes=2) + nib_image = nib.Nifti1Image(im, affine=np.eye(4)) + 
image_fpath = os.path.join(sim_dataroot, d["image"]) + nib.save(nib_image, image_fpath) + + if "label" in d: + nib_image = nib.Nifti1Image(seg, affine=np.eye(4)) + label_fpath = os.path.join(sim_dataroot, d["label"]) + nib.save(nib_image, label_fpath) + + sim_json_datalist = os.path.join(sim_dataroot, "sim_input.json") + ConfigParser.export_config_file(sim_datalist, sim_json_datalist) + + data_src_cfg = os.path.join(test_path, "data_src_cfg.yaml") + data_src = {"modality": "CT", "datalist": sim_json_datalist, "dataroot": sim_dataroot} + + ConfigParser.export_config_file(data_src, data_src_cfg) + self.data_src_cfg = data_src_cfg + self.test_path = test_path + + @skip_if_no_cuda + def test_nnunetBundle(self) -> None: + runner = nnUNetV2Runner(input_config=self.data_src_cfg, trainer_class_name="nnUNetTrainer_1epoch") + with skip_if_downloading_fails(): + runner.run(run_train=False, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False) + + #nnunet_trainer = get_nnunet_trainer(dataset_name_or_id=runner.dataset_name, fold=0,configuration="3d_fullres") + + #print("Max Epochs: ", nnunet_trainer.num_epochs) + #print("Num Iterations: ", nnunet_trainer.num_iterations_per_epoch) + #print("Train Batch dims: ", next(nnunet_trainer.dataloader_train.generator)['data'].shape) + #print("Val Batch dims: ", next(nnunet_trainer.dataloader_val.generator)['data'].shape) + #print("Network: ", nnunet_trainer.network) + #print("Optimizer: ", nnunet_trainer.optimizer) + #print("Loss Function: ", nnunet_trainer.loss) + #print("LR Scheduler: ", nnunet_trainer.lr_scheduler) + #print("Device: ", nnunet_trainer.device) + + runner.train("3d_fullres", 1) + + + + + def tearDown(self) -> None: + self.test_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() From 74aaf73d1e8337c66e0218b50969b6146f2e5b36 Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 5 Feb 2025 12:27:46 +0000 Subject: [PATCH 02/67] Implement nnUNet model conversion to MONAI bundle format and enhance integration tests --- monai/bundle/nnunet.py | 84 ++++++++++++++++++++++--- tests/test_integration_nnunet_bundle.py | 66 ++++++++++++++----- 2 files changed, 125 insertions(+), 25 deletions(-) diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py index f268d43c58..6bcc57130c 100644 --- a/monai/bundle/nnunet.py +++ b/monai/bundle/nnunet.py @@ -17,6 +17,8 @@ from torch._dynamo import OptimizedModule from torch.backends import cudnn +from pathlib import Path +import shutil from monai.data.meta_tensor import MetaTensor from monai.utils import optional_import @@ -255,14 +257,15 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): self.network_weights = self.predictor.network def forward(self, x): - if type(x) is tuple: + if type(x) is tuple: # if batch is decollated (list of tensors) input_files = [img.meta["filename_or_obj"][0] for img in x] - else: + else: # if batch is collated input_files = x.meta["filename_or_obj"] - if type(input_files) is str: - input_files = [input_files] - - output = self.predictor.predict_from_files( + if type(input_files) is str: + input_files = [input_files] + + # input_files should be a list of file paths, one per modality + prediction_output = self.predictor.predict_from_files( [input_files], None, save_probabilities=False, @@ -273,11 +276,12 @@ def forward(self, x): num_parts=1, part_id=0, ) - + # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax + out_tensors = [] - for out in output: + for out in prediction_output: # Add batch and 
channel dimensions out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0))) - out_tensor = torch.cat(out_tensors, 0) + out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension if type(x) is tuple: return MetaTensor(out_tensor, meta=x[0].meta) @@ -336,3 +340,65 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): # initializes the network architecture, loads the checkpoint wrapper = nnUNetMONAIModelWrapper(predictor, model_folder, model_name) return wrapper + + +def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0): + """ + Convert nnUNet model checkpoints and configuration to MONAI bundle format. + + Parameters + ---------- + nnunet_config : dict + Configuration dictionary for nnUNet, containing keys such as 'dataset_name_or_id', 'nnunet_configuration', + 'nnunet_trainer', and 'nnunet_plans'. + bundle_root_folder : str + Root folder where the MONAI bundle will be saved. + fold : int, optional + Fold number of the nnUNet model to be converted, by default 0. + + Returns + ------- + None + """ + + nnunet_trainer = "nnUNetTrainer" + nnunet_plans = "nnUNetPlans" + nnunet_configuration = "3d_fullres" + + if "nnunet_trainer" in nnunet_config: + nnunet_trainer = nnunet_config["nnunet_trainer"] + + if "nnunet_plans" in nnunet_config: + nnunet_plans = nnunet_config["nnunet_plans"] + + if "nnunet_configuration" in nnunet_config: + nnunet_configuration = nnunet_config["nnunet_configuration"] + + from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + + + dataset_name = maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"]) + nnunet_model_folder = Path(os.environ["nnUNet_results"]).joinpath( + dataset_name, + f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}") + + nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}","checkpoint_final.pth")) + nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}","checkpoint_best.pth")) + + nnunet_checkpoint = {} + nnunet_checkpoint['inference_allowed_mirroring_axes'] = nnunet_checkpoint_final['inference_allowed_mirroring_axes'] + nnunet_checkpoint['init_args'] = nnunet_checkpoint_final['init_args'] + nnunet_checkpoint['trainer_name'] = nnunet_checkpoint_final['trainer_name'] + + torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath("models","nnunet_checkpoint.pth")) + + monai_last_checkpoint = {} + monai_last_checkpoint['network_weights'] = nnunet_checkpoint_final['network_weights'] + torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models","model.pt")) + + monai_best_checkpoint = {} + monai_best_checkpoint['network_weights'] = nnunet_checkpoint_best['network_weights'] + torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models","best_model.pt")) + + shutil.copy(Path(nnunet_model_folder).joinpath("plans.json"),Path(bundle_root_folder).joinpath("models","plans.json")) + shutil.copy(Path(nnunet_model_folder).joinpath("dataset.json"),Path(bundle_root_folder).joinpath("models","dataset.json")) \ No newline at end of file diff --git a/tests/test_integration_nnunet_bundle.py b/tests/test_integration_nnunet_bundle.py index b177a697c5..c73d8853d2 100644 --- a/tests/test_integration_nnunet_bundle.py +++ b/tests/test_integration_nnunet_bundle.py @@ -19,7 +19,10 @@ import numpy as np from monai.apps.nnunet import nnUNetV2Runner -from monai.bundle.nnunet import get_nnunet_trainer +from monai.bundle.nnunet import get_nnunet_trainer, 
convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor +from monai.transforms import LoadImaged, SaveImaged, Transposed, EnsureChannelFirstd, Compose, Decollated +from monai.data import DataLoader, Dataset +from pathlib import Path from monai.bundle.config_parser import ConfigParser from monai.data import create_test_image_3d from monai.utils import optional_import @@ -59,6 +62,7 @@ def setUp(self) -> None: if not os.path.isdir(sim_dataroot): os.makedirs(sim_dataroot) + self.sim_dataroot = sim_dataroot # Generate a fake dataset for d in sim_datalist["testing"] + sim_datalist["training"]: im, seg = create_test_image_3d(24, 24, 24, rad_max=10, num_seg_classes=2) @@ -82,27 +86,57 @@ def setUp(self) -> None: self.test_path = test_path @skip_if_no_cuda - def test_nnunetBundle(self) -> None: + def test_nnunetBundle_get_trainer(self) -> None: runner = nnUNetV2Runner(input_config=self.data_src_cfg, trainer_class_name="nnUNetTrainer_1epoch") with skip_if_downloading_fails(): runner.run(run_train=False, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False) - #nnunet_trainer = get_nnunet_trainer(dataset_name_or_id=runner.dataset_name, fold=0,configuration="3d_fullres") + nnunet_trainer = get_nnunet_trainer(dataset_name_or_id=runner.dataset_name, fold=0,configuration="3d_fullres") - #print("Max Epochs: ", nnunet_trainer.num_epochs) - #print("Num Iterations: ", nnunet_trainer.num_iterations_per_epoch) - #print("Train Batch dims: ", next(nnunet_trainer.dataloader_train.generator)['data'].shape) - #print("Val Batch dims: ", next(nnunet_trainer.dataloader_val.generator)['data'].shape) - #print("Network: ", nnunet_trainer.network) - #print("Optimizer: ", nnunet_trainer.optimizer) - #print("Loss Function: ", nnunet_trainer.loss) - #print("LR Scheduler: ", nnunet_trainer.lr_scheduler) - #print("Device: ", nnunet_trainer.device) - - runner.train("3d_fullres", 1) - - + print("Max Epochs: ", nnunet_trainer.num_epochs) + print("Num Iterations: ", nnunet_trainer.num_iterations_per_epoch) + print("Train Batch dims: ", next(nnunet_trainer.dataloader_train.generator)['data'].shape) + print("Val Batch dims: ", next(nnunet_trainer.dataloader_val.generator)['data'].shape) + print("Network: ", nnunet_trainer.network) + print("Optimizer: ", nnunet_trainer.optimizer) + print("Loss Function: ", nnunet_trainer.loss) + print("LR Scheduler: ", nnunet_trainer.lr_scheduler) + print("Device: ", nnunet_trainer.device) + runner.train("3d_fullres") + @skip_if_no_cuda + def test_nnunetBundle_convert_bundle(self) -> None: + + nnunet_config = { + "dataset_name_or_id": "001", + "nnunet_trainer": "nnUNetTrainer_1epoch", + } + self.bundle_root = os.path.join("bundle_root") + + Path(self.bundle_root).joinpath("models").mkdir(parents=True, exist_ok=True) + convert_nnunet_to_monai_bundle(nnunet_config, self.bundle_root, 0) + + + def test_nnunetBundle_predict_from_bundle(self) -> None: + data_transforms = Compose([ + LoadImaged(keys="image"), + EnsureChannelFirstd(keys="image"), + ]) + dataset = Dataset(data=[{"image": os.path.join(self.test_path, "dataroot", "val_001.fake.nii.gz")}], + transform=data_transforms) + data_loader = DataLoader(dataset, batch_size=1) + input = next(iter(data_loader)) + + predictor = get_nnunet_monai_predictor(Path(self.bundle_root).joinpath("models")) + pred_batch = predictor(input["image"]) + Path(self.sim_dataroot).joinpath("predictions").mkdir(parents=True, exist_ok=True) + + post_processing_transforms = Compose([ + Decollated(keys=None, detach=True), + Transposed(keys="pred", 
indices=[0, 3, 2, 1]), + SaveImaged(keys="pred", output_dir=Path(self.sim_dataroot).joinpath("predictions"), output_postfix="pred"), + ]) + post_processing_transforms({"pred": pred_batch}) def tearDown(self) -> None: self.test_dir.cleanup() From b61e4e19a8c402d8c6e9ea9e716657b8a97fa643 Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 5 Feb 2025 12:58:57 +0000 Subject: [PATCH 03/67] Refactor nnUNet bundle integration tests for clarity and remove redundant method --- tests/test_integration_nnunet_bundle.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_integration_nnunet_bundle.py b/tests/test_integration_nnunet_bundle.py index c73d8853d2..2f4ffb4cfc 100644 --- a/tests/test_integration_nnunet_bundle.py +++ b/tests/test_integration_nnunet_bundle.py @@ -86,7 +86,7 @@ def setUp(self) -> None: self.test_path = test_path @skip_if_no_cuda - def test_nnunetBundle_get_trainer(self) -> None: + def test_nnunetBundle(self) -> None: runner = nnUNetV2Runner(input_config=self.data_src_cfg, trainer_class_name="nnUNetTrainer_1epoch") with skip_if_downloading_fails(): runner.run(run_train=False, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False) @@ -103,10 +103,7 @@ def test_nnunetBundle_get_trainer(self) -> None: print("LR Scheduler: ", nnunet_trainer.lr_scheduler) print("Device: ", nnunet_trainer.device) runner.train("3d_fullres") - @skip_if_no_cuda - def test_nnunetBundle_convert_bundle(self) -> None: - nnunet_config = { "dataset_name_or_id": "001", "nnunet_trainer": "nnUNetTrainer_1epoch", @@ -116,8 +113,6 @@ def test_nnunetBundle_convert_bundle(self) -> None: Path(self.bundle_root).joinpath("models").mkdir(parents=True, exist_ok=True) convert_nnunet_to_monai_bundle(nnunet_config, self.bundle_root, 0) - - def test_nnunetBundle_predict_from_bundle(self) -> None: data_transforms = Compose([ LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), From 8e4a66ce70738ba70e049c116dabf5eadf2e6305 Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 5 Feb 2025 15:34:50 +0000 Subject: [PATCH 04/67] Code reformatting --- monai/bundle/__init__.py | 2 +- monai/bundle/nnunet.py | 65 ++++++++++--------- tests/test_integration_nnunet_bundle.py | 83 ++++++++++++------------- 3 files changed, 74 insertions(+), 76 deletions(-) diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 5d6c70784e..305bf9eb6a 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -13,7 +13,7 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser -from .nnunet import get_nnunet_monai_predictor, get_nnunet_trainer, nnUNetMONAIModelWrapper +from .nnunet import ModelnnUNetWrapper, convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer from .properties import InferProperties, MetaProperties, TrainProperties from .reference_resolver import ReferenceResolver from .scripts import ( diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py index 6bcc57130c..cb6107e82a 100644 --- a/monai/bundle/nnunet.py +++ b/monai/bundle/nnunet.py @@ -11,21 +11,21 @@ from __future__ import annotations import os +import shutil +from pathlib import Path import numpy as np import torch from torch._dynamo import OptimizedModule from torch.backends import cudnn -from pathlib import Path -import shutil from monai.data.meta_tensor import MetaTensor from monai.utils import optional_import join, _ = 
optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") -__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "nnUNetMONAIModelWrapper"] +__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "ModelnnUNetWrapper"] def get_nnunet_trainer( @@ -42,7 +42,7 @@ def get_nnunet_trainer( only_run_validation=False, disable_checkpointing=False, val_with_best=False, - device=torch.device("cuda"), + device="cuda", pretrained_model=None, ): """ @@ -98,7 +98,7 @@ def get_nnunet_trainer( Whether to disable checkpointing. Default is False. val_with_best : bool, optional Whether to validate with the best model. Default is False. - device : torch.device, optional + device : str, optional The device to be used for training. Default is 'cuda'. pretrained_model : str, optional Path to the pretrained model file. @@ -130,7 +130,7 @@ def get_nnunet_trainer( trainer_class_name, plans_identifier, use_compressed_data, - device=device, + device=torch.device(device), ) if disable_checkpointing: nnunet_trainer.disable_checkpointing = disable_checkpointing @@ -150,7 +150,7 @@ def get_nnunet_trainer( return nnunet_trainer -class nnUNetMONAIModelWrapper(torch.nn.Module): +class ModelnnUNetWrapper(torch.nn.Module): """ A wrapper class for nnUNet model integration with MONAI framework. The wrapper can be use to integrate the nnUNet Bundle within MONAI framework for inference. @@ -188,7 +188,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): from nnunetv2.utilities.plans_handling.plans_handler import PlansManager - ## Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor + # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor dataset_json = load_json(join(model_training_output_dir, "dataset.json")) plans = load_json(join(model_training_output_dir, "plans.json")) plans_manager = PlansManager(plans) @@ -253,17 +253,17 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): ): print("Using torch.compile") predictor.network = torch.compile(self.network) - ## End Block + # End Block self.network_weights = self.predictor.network def forward(self, x): if type(x) is tuple: # if batch is decollated (list of tensors) input_files = [img.meta["filename_or_obj"][0] for img in x] - else: # if batch is collated + else: # if batch is collated input_files = x.meta["filename_or_obj"] if type(input_files) is str: input_files = [input_files] - + # input_files should be a list of file paths, one per modality prediction_output = self.predictor.predict_from_files( [input_files], @@ -277,11 +277,11 @@ def forward(self, x): part_id=0, ) # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax - + out_tensors = [] - for out in prediction_output: # Add batch and channel dimensions + for out in prediction_output: # Add batch and channel dimensions out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0))) - out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension + out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension if type(x) is tuple: return MetaTensor(out_tensor, meta=x[0].meta) @@ -338,7 +338,7 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): allow_tqdm=True, ) # initializes the network architecture, loads the checkpoint - wrapper = 
nnUNetMONAIModelWrapper(predictor, model_folder, model_name) + wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name) return wrapper @@ -376,29 +376,32 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0): from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name - dataset_name = maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"]) nnunet_model_folder = Path(os.environ["nnUNet_results"]).joinpath( - dataset_name, - f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}") - - nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}","checkpoint_final.pth")) - nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}","checkpoint_best.pth")) + dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}" + ) + + nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth")) + nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth")) nnunet_checkpoint = {} - nnunet_checkpoint['inference_allowed_mirroring_axes'] = nnunet_checkpoint_final['inference_allowed_mirroring_axes'] - nnunet_checkpoint['init_args'] = nnunet_checkpoint_final['init_args'] - nnunet_checkpoint['trainer_name'] = nnunet_checkpoint_final['trainer_name'] + nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"] + nnunet_checkpoint["init_args"] = nnunet_checkpoint_final["init_args"] + nnunet_checkpoint["trainer_name"] = nnunet_checkpoint_final["trainer_name"] - torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath("models","nnunet_checkpoint.pth")) + torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath("models", "nnunet_checkpoint.pth")) monai_last_checkpoint = {} - monai_last_checkpoint['network_weights'] = nnunet_checkpoint_final['network_weights'] - torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models","model.pt")) + monai_last_checkpoint["network_weights"] = nnunet_checkpoint_final["network_weights"] + torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models", "model.pt")) monai_best_checkpoint = {} - monai_best_checkpoint['network_weights'] = nnunet_checkpoint_best['network_weights'] - torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models","best_model.pt")) + monai_best_checkpoint["network_weights"] = nnunet_checkpoint_best["network_weights"] + torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models", "best_model.pt")) - shutil.copy(Path(nnunet_model_folder).joinpath("plans.json"),Path(bundle_root_folder).joinpath("models","plans.json")) - shutil.copy(Path(nnunet_model_folder).joinpath("dataset.json"),Path(bundle_root_folder).joinpath("models","dataset.json")) \ No newline at end of file + shutil.copy( + Path(nnunet_model_folder).joinpath("plans.json"), Path(bundle_root_folder).joinpath("models", "plans.json") + ) + shutil.copy( + Path(nnunet_model_folder).joinpath("dataset.json"), Path(bundle_root_folder).joinpath("models", "dataset.json") + ) diff --git a/tests/test_integration_nnunet_bundle.py b/tests/test_integration_nnunet_bundle.py index c73d8853d2..cb8c2e3d54 100644 --- a/tests/test_integration_nnunet_bundle.py +++ b/tests/test_integration_nnunet_bundle.py @@ -14,17 +14,16 @@ import os import tempfile import unittest +from pathlib import Path import nibabel as nib import numpy as np from 
monai.apps.nnunet import nnUNetV2Runner -from monai.bundle.nnunet import get_nnunet_trainer, convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor -from monai.transforms import LoadImaged, SaveImaged, Transposed, EnsureChannelFirstd, Compose, Decollated -from monai.data import DataLoader, Dataset -from pathlib import Path from monai.bundle.config_parser import ConfigParser -from monai.data import create_test_image_3d +from monai.bundle.nnunet import convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer +from monai.data import DataLoader, Dataset, create_test_image_3d +from monai.transforms import Compose, Decollated, EnsureChannelFirstd, LoadImaged, SaveImaged, Transposed from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick @@ -86,57 +85,53 @@ def setUp(self) -> None: self.test_path = test_path @skip_if_no_cuda - def test_nnunetBundle_get_trainer(self) -> None: - runner = nnUNetV2Runner(input_config=self.data_src_cfg, trainer_class_name="nnUNetTrainer_1epoch") + def test_nnunet_bundle(self) -> None: + runner = nnUNetV2Runner(input_config=self.data_src_cfg, trainer_class_name="nnUNetTrainer_1epoch",work_dir=self.test_path) with skip_if_downloading_fails(): runner.run(run_train=False, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False) - nnunet_trainer = get_nnunet_trainer(dataset_name_or_id=runner.dataset_name, fold=0,configuration="3d_fullres") - + nnunet_trainer = get_nnunet_trainer( + dataset_name_or_id=runner.dataset_name, fold=0, configuration="3d_fullres" + ) + print("Max Epochs: ", nnunet_trainer.num_epochs) print("Num Iterations: ", nnunet_trainer.num_iterations_per_epoch) - print("Train Batch dims: ", next(nnunet_trainer.dataloader_train.generator)['data'].shape) - print("Val Batch dims: ", next(nnunet_trainer.dataloader_val.generator)['data'].shape) + print("Train Batch dims: ", next(nnunet_trainer.dataloader_train.generator)["data"].shape) + print("Val Batch dims: ", next(nnunet_trainer.dataloader_val.generator)["data"].shape) print("Network: ", nnunet_trainer.network) print("Optimizer: ", nnunet_trainer.optimizer) print("Loss Function: ", nnunet_trainer.loss) print("LR Scheduler: ", nnunet_trainer.lr_scheduler) print("Device: ", nnunet_trainer.device) - runner.train("3d_fullres") - @skip_if_no_cuda - def test_nnunetBundle_convert_bundle(self) -> None: - - - nnunet_config = { - "dataset_name_or_id": "001", - "nnunet_trainer": "nnUNetTrainer_1epoch", - } - self.bundle_root = os.path.join("bundle_root") - - Path(self.bundle_root).joinpath("models").mkdir(parents=True, exist_ok=True) - convert_nnunet_to_monai_bundle(nnunet_config, self.bundle_root, 0) - - - def test_nnunetBundle_predict_from_bundle(self) -> None: - data_transforms = Compose([ - LoadImaged(keys="image"), - EnsureChannelFirstd(keys="image"), - ]) - dataset = Dataset(data=[{"image": os.path.join(self.test_path, "dataroot", "val_001.fake.nii.gz")}], - transform=data_transforms) - data_loader = DataLoader(dataset, batch_size=1) - input = next(iter(data_loader)) - - predictor = get_nnunet_monai_predictor(Path(self.bundle_root).joinpath("models")) - pred_batch = predictor(input["image"]) - Path(self.sim_dataroot).joinpath("predictions").mkdir(parents=True, exist_ok=True) - - post_processing_transforms = Compose([ + runner.train_single_model("3d_fullres", fold="0") + + nnunet_config = {"dataset_name_or_id": "001", "nnunet_trainer": "nnUNetTrainer_1epoch"} + self.bundle_root 
= os.path.join("bundle_root") + + Path(self.bundle_root).joinpath("models").mkdir(parents=True, exist_ok=True) + convert_nnunet_to_monai_bundle(nnunet_config, self.bundle_root, 0) + + data_transforms = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image")]) + dataset = Dataset( + data=[{"image": os.path.join(self.test_path, "dataroot", "val_001.fake.nii.gz")}], transform=data_transforms + ) + data_loader = DataLoader(dataset, batch_size=1) + input = next(iter(data_loader)) + + predictor = get_nnunet_monai_predictor(Path(self.bundle_root).joinpath("models")) + pred_batch = predictor(input["image"]) + Path(self.sim_dataroot).joinpath("predictions").mkdir(parents=True, exist_ok=True) + + post_processing_transforms = Compose( + [ Decollated(keys=None, detach=True), Transposed(keys="pred", indices=[0, 3, 2, 1]), - SaveImaged(keys="pred", output_dir=Path(self.sim_dataroot).joinpath("predictions"), output_postfix="pred"), - ]) - post_processing_transforms({"pred": pred_batch}) + SaveImaged( + keys="pred", output_dir=Path(self.sim_dataroot).joinpath("predictions"), output_postfix="pred" + ), + ] + ) + post_processing_transforms({"pred": pred_batch}) def tearDown(self) -> None: self.test_dir.cleanup() From 49dbb5dde83d31325c6affbfed8494c1de06d934 Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 5 Feb 2025 15:47:55 +0000 Subject: [PATCH 05/67] DCO Remediation Commit for simben I, simben , hereby add my Signed-off-by to this commit: 13bdec5130709456abb3721450961801ab3bb4bc I, simben , hereby add my Signed-off-by to this commit: 74aaf73d1e8337c66e0218b50969b6146f2e5b36 I, simben , hereby add my Signed-off-by to this commit: 8e4a66ce70738ba70e049c116dabf5eadf2e6305 Signed-off-by: simben --- monai/bundle/nnunet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py index cb6107e82a..d5411dd531 100644 --- a/monai/bundle/nnunet.py +++ b/monai/bundle/nnunet.py @@ -16,7 +16,6 @@ import numpy as np import torch -from torch._dynamo import OptimizedModule from torch.backends import cudnn from monai.data.meta_tensor import MetaTensor @@ -249,7 +248,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): if ( ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")) - and not isinstance(predictor.network, OptimizedModule) + #and not isinstance(predictor.network, OptimizedModule) ): print("Using torch.compile") predictor.network = torch.compile(self.network) From 5a04fe0b174fa5a519e4ec97196d4fa49e56a511 Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 5 Feb 2025 17:00:01 +0000 Subject: [PATCH 06/67] nibabel importing moved to setUp --- monai/bundle/nnunet.py | 2 +- tests/test_integration_nnunet_bundle.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py index d5411dd531..4a110b9402 100644 --- a/monai/bundle/nnunet.py +++ b/monai/bundle/nnunet.py @@ -248,7 +248,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): if ( ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")) - #and not isinstance(predictor.network, OptimizedModule) + # and not isinstance(predictor.network, OptimizedModule) ): print("Using torch.compile") predictor.network = torch.compile(self.network) diff --git a/tests/test_integration_nnunet_bundle.py b/tests/test_integration_nnunet_bundle.py index cb8c2e3d54..beb6672f10 100644 --- 
a/tests/test_integration_nnunet_bundle.py +++ b/tests/test_integration_nnunet_bundle.py @@ -16,7 +16,6 @@ import unittest from pathlib import Path -import nibabel as nib import numpy as np from monai.apps.nnunet import nnUNetV2Runner @@ -54,6 +53,8 @@ class TestnnUNetBundle(unittest.TestCase): def setUp(self) -> None: + import nibabel as nib + self.test_dir = tempfile.TemporaryDirectory() test_path = self.test_dir.name @@ -86,7 +87,9 @@ def setUp(self) -> None: @skip_if_no_cuda def test_nnunet_bundle(self) -> None: - runner = nnUNetV2Runner(input_config=self.data_src_cfg, trainer_class_name="nnUNetTrainer_1epoch",work_dir=self.test_path) + runner = nnUNetV2Runner( + input_config=self.data_src_cfg, trainer_class_name="nnUNetTrainer_1epoch", work_dir=self.test_path + ) with skip_if_downloading_fails(): runner.run(run_train=False, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False) From 24643b8f436fe0360b2e6eb3cdf328efa3390963 Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 5 Feb 2025 17:03:32 +0000 Subject: [PATCH 07/67] DCO Remediation Commit for simben I, simben , hereby add my Signed-off-by to this commit: 5a04fe0b174fa5a519e4ec97196d4fa49e56a511 Signed-off-by: simben --- tests/test_integration_nnunet_bundle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_integration_nnunet_bundle.py b/tests/test_integration_nnunet_bundle.py index beb6672f10..c10de80317 100644 --- a/tests/test_integration_nnunet_bundle.py +++ b/tests/test_integration_nnunet_bundle.py @@ -53,6 +53,7 @@ class TestnnUNetBundle(unittest.TestCase): def setUp(self) -> None: + import nibabel as nib self.test_dir = tempfile.TemporaryDirectory() From 253dab178cea44d015fb2b9cf2a3363925a69b1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Feb 2025 17:05:05 +0000 Subject: [PATCH 08/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_integration_nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_integration_nnunet_bundle.py b/tests/test_integration_nnunet_bundle.py index c10de80317..553708b61b 100644 --- a/tests/test_integration_nnunet_bundle.py +++ b/tests/test_integration_nnunet_bundle.py @@ -53,7 +53,7 @@ class TestnnUNetBundle(unittest.TestCase): def setUp(self) -> None: - + import nibabel as nib self.test_dir = tempfile.TemporaryDirectory() From 43c694b5403aa07d408130ec1571b0cc3e9ee8c2 Mon Sep 17 00:00:00 2001 From: simben Date: Thu, 6 Feb 2025 11:09:29 +0000 Subject: [PATCH 09/67] Add nnUNet Bundle documentation and related functions to bundle.rst --- docs/source/bundle.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 4e3a32b6fe..49297e1a3d 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -34,6 +34,13 @@ Model Bundle :members: :special-members: + +`nnUNet Bundle` +--------------- +.. autoclass:: ModelnnUNetWrapper + :members: + :special-members: + `Scripts` --------- .. autofunction:: ckpt_export @@ -50,3 +57,6 @@ Model Bundle .. autofunction:: init_bundle .. autofunction:: push_to_hf_hub .. autofunction:: update_kwargs +.. autofunction:: get_nnunet_trainer +.. autofunction:: get_nnunet_monai_predictor +.. 
autofunction:: convert_nnunet_to_monai_bundle From 569df7bc998eb6fa1f61cf0ba313ea8153cf4e25 Mon Sep 17 00:00:00 2001 From: simben Date: Thu, 6 Feb 2025 11:09:52 +0000 Subject: [PATCH 10/67] Refactor nnUNet documentation and examples for clarity; update fold parameter type in tests --- monai/bundle/nnunet.py | 80 +++++++++++++------------ tests/test_integration_nnunet_bundle.py | 2 +- 2 files changed, 43 insertions(+), 39 deletions(-) diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py index 4a110b9402..8b4bccb87a 100644 --- a/monai/bundle/nnunet.py +++ b/monai/bundle/nnunet.py @@ -49,25 +49,25 @@ def get_nnunet_trainer( The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network, optimizer, loss function, DataLoader, etc. - ```python - from monai.apps import SupervisedTrainer - from monai.bundle.nnunet import get_nnunet_trainer - - dataset_name_or_id = 'Task101_PROSTATE' - fold = 0 - configuration = '3d_fullres' - nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold) - - trainer = SupervisedTrainer( - device=nnunet_trainer.device, - max_epochs=nnunet_trainer.num_epochs, - train_data_loader=nnunet_trainer.dataloader_train, - network=nnunet_trainer.network, - optimizer=nnunet_trainer.optimizer, - loss_function=nnunet_trainer.loss_function, - epoch_length=nnunet_trainer.num_iterations_per_epoch, - - ``` + Example:: + + from monai.apps import SupervisedTrainer + from monai.bundle.nnunet import get_nnunet_trainer + + dataset_name_or_id = 'Task101_PROSTATE' + fold = 0 + configuration = '3d_fullres' + nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold) + + trainer = SupervisedTrainer( + device=nnunet_trainer.device, + max_epochs=nnunet_trainer.num_epochs, + train_data_loader=nnunet_trainer.dataloader_train, + network=nnunet_trainer.network, + optimizer=nnunet_trainer.optimizer, + loss_function=nnunet_trainer.loss_function, + epoch_length=nnunet_trainer.num_iterations_per_epoch, + ) Parameters ---------- @@ -162,16 +162,19 @@ class ModelnnUNetWrapper(torch.nn.Module): The folder path where the model and related files are stored. model_name : str, optional The name of the model file, by default "model.pt". + Attributes ---------- - predictor : object - The predictor object used for inference. + predictor : nnUNetPredictor + The nnUNet predictor object used for inference. network_weights : torch.nn.Module The network weights of the model. + Methods ------- forward(x) Perform forward pass and prediction on the input data. + Notes ----- This class integrates nnUNet model with MONAI framework by loading necessary configurations, @@ -183,7 +186,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): self.predictor = predictor model_training_output_dir = model_folder - use_folds = "0" + use_folds = ["0"] from nnunetv2.utilities.plans_handling.plans_handler import PlansManager @@ -290,27 +293,28 @@ def forward(self, x): def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): """ - Initializes and returns a nnUNetMONAIModelWrapper with a nnUNetPredictor. + Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`. The model folder should contain the following files, created during training: - - dataset.json: from the nnUNet results folder. - - plans.json: from the nnUNet results folder. 
- - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration - (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`). - - model.pt: The checkpoint file containing the model weights. - + + - dataset.json: from the nnUNet results folder + - plans.json: from the nnUNet results folder + - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`) + - model.pt: The checkpoint file containing the model weights. + The returned wrapper object can be used for inference with MONAI framework: - ```python - from monai.bundle.nnunet import get_nnunet_monai_predictor + + Example:: + + from monai.bundle.nnunet import get_nnunet_monai_predictor - model_folder = 'path/to/monai_bundle/model' - model_name = 'model.pt' - wrapper = get_nnunet_monai_predictor(model_folder, model_name) + model_folder = 'path/to/monai_bundle/model' + model_name = 'model.pt' + wrapper = get_nnunet_monai_predictor(model_folder, model_name) - # Perform inference - input_data = ... - output = wrapper(input_data) + # Perform inference + input_data = ... + output = wrapper(input_data) - ``` Parameters ---------- diff --git a/tests/test_integration_nnunet_bundle.py b/tests/test_integration_nnunet_bundle.py index 553708b61b..9a3f362eab 100644 --- a/tests/test_integration_nnunet_bundle.py +++ b/tests/test_integration_nnunet_bundle.py @@ -107,7 +107,7 @@ def test_nnunet_bundle(self) -> None: print("Loss Function: ", nnunet_trainer.loss) print("LR Scheduler: ", nnunet_trainer.lr_scheduler) print("Device: ", nnunet_trainer.device) - runner.train_single_model("3d_fullres", fold="0") + runner.train_single_model("3d_fullres", fold=0) nnunet_config = {"dataset_name_or_id": "001", "nnunet_trainer": "nnUNetTrainer_1epoch"} self.bundle_root = os.path.join("bundle_root") From fcf5ac0b000dd139b179d640b3719d01d81d6b3f Mon Sep 17 00:00:00 2001 From: simben Date: Thu, 6 Feb 2025 12:53:29 +0000 Subject: [PATCH 11/67] Clean up whitespace in nnunet.py and add test for nnunet bundle integration --- monai/bundle/nnunet.py | 18 +++++++++--------- tests/min_tests.py | 1 + 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py index 8b4bccb87a..c6bb2aa971 100644 --- a/monai/bundle/nnunet.py +++ b/monai/bundle/nnunet.py @@ -50,7 +50,7 @@ def get_nnunet_trainer( optimizer, loss function, DataLoader, etc. Example:: - + from monai.apps import SupervisedTrainer from monai.bundle.nnunet import get_nnunet_trainer @@ -162,19 +162,19 @@ class ModelnnUNetWrapper(torch.nn.Module): The folder path where the model and related files are stored. model_name : str, optional The name of the model file, by default "model.pt". - + Attributes ---------- predictor : nnUNetPredictor The nnUNet predictor object used for inference. network_weights : torch.nn.Module The network weights of the model. - + Methods ------- forward(x) Perform forward pass and prediction on the input data. - + Notes ----- This class integrates nnUNet model with MONAI framework by loading necessary configurations, @@ -295,16 +295,16 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): """ Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`. 
The model folder should contain the following files, created during training: - + - dataset.json: from the nnUNet results folder - plans.json: from the nnUNet results folder - - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`) + - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration - model.pt: The checkpoint file containing the model weights. - + The returned wrapper object can be used for inference with MONAI framework: - + Example:: - + from monai.bundle.nnunet import get_nnunet_monai_predictor model_folder = 'path/to/monai_bundle/model' diff --git a/tests/min_tests.py b/tests/min_tests.py index f39d3f9843..837294a495 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -126,6 +126,7 @@ def run_testsuit(): "test_integration_bundle_run", "test_integration_autorunner", "test_integration_nnunetv2_runner", + "test_integration_nnunet_bundle", "test_invert", "test_invertd", "test_iterable_dataset", From c846b6d4ab157885b0b7b10c7e04a876478e9a23 Mon Sep 17 00:00:00 2001 From: simben Date: Thu, 6 Feb 2025 13:11:15 +0000 Subject: [PATCH 12/67] Fix type conversion for folds and improve input_files type checking in ModelnnUNetWrapper --- monai/bundle/nnunet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py index c6bb2aa971..b152532399 100644 --- a/monai/bundle/nnunet.py +++ b/monai/bundle/nnunet.py @@ -200,7 +200,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): parameters = [] for i, f in enumerate(use_folds): - f = int(f) if f != "all" else f + f = str(f) if f != "all" else f checkpoint = torch.load( join(model_training_output_dir, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") ) @@ -263,7 +263,7 @@ def forward(self, x): input_files = [img.meta["filename_or_obj"][0] for img in x] else: # if batch is collated input_files = x.meta["filename_or_obj"] - if type(input_files) is str: + if isinstance(input_files, str): input_files = [input_files] # input_files should be a list of file paths, one per modality From 2da5ca97156d95ac130cd333f67694c2db97482d Mon Sep 17 00:00:00 2001 From: simben Date: Thu, 6 Feb 2025 13:26:56 +0000 Subject: [PATCH 13/67] DCO Remediation Commit for simben I, simben , hereby add my Signed-off-by to this commit: 43c694b5403aa07d408130ec1571b0cc3e9ee8c2 I, simben , hereby add my Signed-off-by to this commit: 569df7bc998eb6fa1f61cf0ba313ea8153cf4e25 I, simben , hereby add my Signed-off-by to this commit: fcf5ac0b000dd139b179d640b3719d01d81d6b3f I, simben , hereby add my Signed-off-by to this commit: c846b6d4ab157885b0b7b10c7e04a876478e9a23 Signed-off-by: simben --- monai/bundle/nnunet.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py index b152532399..6bbb9b2333 100644 --- a/monai/bundle/nnunet.py +++ b/monai/bundle/nnunet.py @@ -259,6 +259,28 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): self.network_weights = self.predictor.network def forward(self, x): + """ + Forward pass for the nnUNet model. + + :no-index: + + Args: + x (Union[torch.Tensor, Tuple[MetaTensor]]): Input tensor or a tuple of MetaTensors. If the input is a tuple, + it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch. 
+ + Returns: + MetaTensor: The output tensor with the same metadata as the input. + + Raises: + TypeError: If the input is not a torch.Tensor or a tuple of MetaTensors. + + Notes: + - If the input is a tuple, the filenames are extracted from the metadata of each tensor in the tuple. + - If the input is a collated batch, the filenames are extracted from the metadata of the input tensor. + - The filenames are used to generate predictions using the nnUNet predictor. + - The predictions are converted to torch tensors, with added batch and channel dimensions. + - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata as the input. + """ if type(x) is tuple: # if batch is decollated (list of tensors) input_files = [img.meta["filename_or_obj"][0] for img in x] else: # if batch is collated From 48d53a487bb809e39ecc0b7875b63c3736cd6d42 Mon Sep 17 00:00:00 2001 From: simben Date: Thu, 6 Feb 2025 13:29:27 +0000 Subject: [PATCH 14/67] Fix documentation for output tensor return in ModelnnUNetWrapper Signed-off-by: simben --- monai/bundle/nnunet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py index 6bbb9b2333..5c14ec9eb2 100644 --- a/monai/bundle/nnunet.py +++ b/monai/bundle/nnunet.py @@ -279,7 +279,7 @@ def forward(self, x): - If the input is a collated batch, the filenames are extracted from the metadata of the input tensor. - The filenames are used to generate predictions using the nnUNet predictor. - The predictions are converted to torch tensors, with added batch and channel dimensions. - - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata as the input. + - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata. """ if type(x) is tuple: # if batch is decollated (list of tensors) input_files = [img.meta["filename_or_obj"][0] for img in x] From 230cb9b841c3bb27cd9acced4011a3e455326a40 Mon Sep 17 00:00:00 2001 From: simben Date: Thu, 6 Feb 2025 14:01:03 +0000 Subject: [PATCH 15/67] Remove outdated method documentation for forward pass in ModelnnUNetWrapper Signed-off-by: simben --- monai/bundle/nnunet.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py index 5c14ec9eb2..02dae801b2 100644 --- a/monai/bundle/nnunet.py +++ b/monai/bundle/nnunet.py @@ -170,11 +170,6 @@ class ModelnnUNetWrapper(torch.nn.Module): network_weights : torch.nn.Module The network weights of the model. - Methods - ------- - forward(x) - Perform forward pass and prediction on the input data. 
- Notes ----- This class integrates nnUNet model with MONAI framework by loading necessary configurations, From ea8028fd5fd39593cc9be21965f0c31b5e6a815e Mon Sep 17 00:00:00 2001 From: simben Date: Tue, 18 Feb 2025 07:36:12 +0000 Subject: [PATCH 16/67] Add integration tests for nnUNet bundle functionality --- tests/{ => integration}/test_integration_nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/{ => integration}/test_integration_nnunet_bundle.py (98%) diff --git a/tests/test_integration_nnunet_bundle.py b/tests/integration/test_integration_nnunet_bundle.py similarity index 98% rename from tests/test_integration_nnunet_bundle.py rename to tests/integration/test_integration_nnunet_bundle.py index 9a3f362eab..4e04f3f5cf 100644 --- a/tests/test_integration_nnunet_bundle.py +++ b/tests/integration/test_integration_nnunet_bundle.py @@ -24,7 +24,7 @@ from monai.data import DataLoader, Dataset, create_test_image_3d from monai.transforms import Compose, Decollated, EnsureChannelFirstd, LoadImaged, SaveImaged, Transposed from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick +from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick _, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter") _, has_nnunet = optional_import("nnunetv2") From 9dc6532a18bfcf264971364b8b6453c40ddc9134 Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 24 Feb 2025 16:36:02 +0000 Subject: [PATCH 17/67] DCO Remediation Commit for simben I, simben , hereby add my Signed-off-by to this commit: b61e4e19a8c402d8c6e9ea9e716657b8a97fa643 I, simben , hereby add my Signed-off-by to this commit: ea8028fd5fd39593cc9be21965f0c31b5e6a815e Signed-off-by: simben --- monai/bundle/nnunet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py index 02dae801b2..4eca036c17 100644 --- a/monai/bundle/nnunet.py +++ b/monai/bundle/nnunet.py @@ -152,7 +152,7 @@ def get_nnunet_trainer( class ModelnnUNetWrapper(torch.nn.Module): """ A wrapper class for nnUNet model integration with MONAI framework. - The wrapper can be use to integrate the nnUNet Bundle within MONAI framework for inference. + The wrapper can be used to integrate the nnUNet Bundle within MONAI framework for inference. 
Parameters ---------- From 8bbb63b2bde257337105bb25d8af291846d3f8d6 Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 10 Mar 2025 07:37:44 +0000 Subject: [PATCH 18/67] Refactor nnUNet imports for improved module organization --- monai/apps/nnunet/__init__.py | 1 + monai/apps/nnunet/nnunet_bundle.py | 448 +++++++++++++++++++++++++++++ monai/bundle/__init__.py | 1 - 3 files changed, 449 insertions(+), 1 deletion(-) create mode 100644 monai/apps/nnunet/nnunet_bundle.py diff --git a/monai/apps/nnunet/__init__.py b/monai/apps/nnunet/__init__.py index 405a79fe01..214ed1d45d 100644 --- a/monai/apps/nnunet/__init__.py +++ b/monai/apps/nnunet/__init__.py @@ -11,5 +11,6 @@ from __future__ import annotations +from .nnunet_bundle import ModelnnUNetWrapper, convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer from .nnunetv2_runner import nnUNetV2Runner from .utils import NNUNETMode, analyze_data, create_new_data_copy, create_new_dataset_json diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py new file mode 100644 index 0000000000..711aa81a60 --- /dev/null +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -0,0 +1,448 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import shutil +from pathlib import Path + +import numpy as np +import torch +from torch.backends import cudnn + +from typing import Union, Tuple +from monai.data.meta_tensor import MetaTensor +from monai.utils import optional_import +from nnunetv2.training.logging.nnunet_logger import nnUNetLogger + +join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") +load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") + +__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "ModelnnUNetWrapper"] + + +def get_nnunet_trainer( + dataset_name_or_id, + configuration, + fold, + trainer_class_name="nnUNetTrainer", + plans_identifier="nnUNetPlans", + pretrained_weights=None, + num_gpus=1, + use_compressed_data=False, + export_validation_probabilities=False, + continue_training=False, + only_run_validation=False, + disable_checkpointing=False, + val_with_best=False, + device="cuda", + pretrained_model=None, +): + """ + Get the nnUNet trainer instance based on the provided configuration. + The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network, + optimizer, loss function, DataLoader, etc. 
+ + Example:: + + from monai.apps import SupervisedTrainer + from monai.bundle.nnunet import get_nnunet_trainer + + dataset_name_or_id = 'Task101_PROSTATE' + fold = 0 + configuration = '3d_fullres' + nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold) + + trainer = SupervisedTrainer( + device=nnunet_trainer.device, + max_epochs=nnunet_trainer.num_epochs, + train_data_loader=nnunet_trainer.dataloader_train, + network=nnunet_trainer.network, + optimizer=nnunet_trainer.optimizer, + loss_function=nnunet_trainer.loss_function, + epoch_length=nnunet_trainer.num_iterations_per_epoch, + ) + + Parameters + ---------- + dataset_name_or_id : Union[str, int] + The name or ID of the dataset to be used. + configuration : str + The configuration name for the training. + fold : Union[int, str] + The fold number or 'all' for cross-validation. + trainer_class_name : str, optional + The class name of the trainer to be used. Default is 'nnUNetTrainer'. + plans_identifier : str, optional + Identifier for the plans to be used. Default is 'nnUNetPlans'. + pretrained_weights : str, optional + Path to the pretrained weights file. + num_gpus : int, optional + Number of GPUs to be used. Default is 1. + use_compressed_data : bool, optional + Whether to use compressed data. Default is False. + export_validation_probabilities : bool, optional + Whether to export validation probabilities. Default is False. + continue_training : bool, optional + Whether to continue training from a checkpoint. Default is False. + only_run_validation : bool, optional + Whether to only run validation. Default is False. + disable_checkpointing : bool, optional + Whether to disable checkpointing. Default is False. + val_with_best : bool, optional + Whether to validate with the best model. Default is False. + device : str, optional + The device to be used for training. Default is 'cuda'. + pretrained_model : str, optional + Path to the pretrained model file. + Returns + ------- + nnunet_trainer + The nnUNet trainer instance. + """ + # From nnUNet/nnunetv2/run/run_training.py#run_training + if isinstance(fold, str): + if fold != "all": + try: + fold = int(fold) + except ValueError as e: + print( + f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!' + ) + raise e + + if int(num_gpus) > 1: + ... # Disable for now + else: + from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint + + nnunet_trainer = get_trainer_from_args( + str(dataset_name_or_id), + configuration, + fold, + trainer_class_name, + plans_identifier, + use_compressed_data, + device=torch.device(device), + ) + if disable_checkpointing: + nnunet_trainer.disable_checkpointing = disable_checkpointing + + assert not (continue_training and only_run_validation), "Cannot set --c and --val flag at the same time. Dummy." + + maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights) + nnunet_trainer.on_train_start() # Added to Initialize Trainer + if torch.cuda.is_available(): + cudnn.deterministic = False + cudnn.benchmark = True + + if pretrained_model is not None: + state_dict = torch.load(pretrained_model) + if "network_weights" in state_dict: + nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"]) + return nnunet_trainer + + +class ModelnnUNetWrapper(torch.nn.Module): + """ + A wrapper class for nnUNet model integration with MONAI framework. 
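The `pretrained_model` argument above expects a checkpoint that stores its weights under a `network_weights` key. As a minimal, illustrative sketch (the dataset ID and file path below are placeholders rather than values from this patch), such a checkpoint, for example one written by `convert_nnunet_to_monai_bundle` later in this module, could be used to warm-start the trainer:

```python
from monai.apps.nnunet.nnunet_bundle import get_nnunet_trainer

# Illustrative values only: dataset ID, configuration and checkpoint path are placeholders.
# "model.pt" is assumed to be a dict holding {"network_weights": ...}, which is what the
# pretrained_model branch above loads into the network.
nnunet_trainer = get_nnunet_trainer(
    dataset_name_or_id="001",
    configuration="3d_fullres",
    fold=0,
    pretrained_model="bundle_root/models/fold_0/model.pt",
)
print(nnunet_trainer.network)  # initialized nnUNet network, ready to pass to SupervisedTrainer
```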
+ The wrapper can be use to integrate the nnUNet Bundle within MONAI framework for inference. + + Parameters + ---------- + predictor : object + The nnUNet predictor object used for inference. + model_folder : str + The folder path where the model and related files are stored. + model_name : str, optional + The name of the model file, by default "model.pt". + + Attributes + ---------- + predictor : nnUNetPredictor + The nnUNet predictor object used for inference. + network_weights : torch.nn.Module + The network weights of the model. + + Notes + ----- + This class integrates nnUNet model with MONAI framework by loading necessary configurations, + restoring network architecture, and setting up the predictor for inference. + """ + + def __init__(self, predictor, model_folder, model_name="model.pt"): + super().__init__() + self.predictor = predictor + + model_training_output_dir = model_folder + use_folds = ["0"] + + from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + + # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor + dataset_json = load_json(join(Path(model_training_output_dir).parent, "dataset.json")) + plans = load_json(join(Path(model_training_output_dir).parent, "plans.json")) + plans_manager = PlansManager(plans) + + if isinstance(use_folds, str): + use_folds = [use_folds] + + parameters = [] + for i, f in enumerate(use_folds): + f = str(f) if f != "all" else f + checkpoint = torch.load( + join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") + ) + monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu")) + if i == 0: + trainer_name = checkpoint["trainer_name"] + configuration_name = checkpoint["init_args"]["configuration"] + inference_allowed_mirroring_axes = ( + checkpoint["inference_allowed_mirroring_axes"] + if "inference_allowed_mirroring_axes" in checkpoint.keys() + else None + ) + + if 'network_weights' in monai_checkpoint.keys(): + parameters.append(monai_checkpoint["network_weights"]) + else: + parameters.append(monai_checkpoint) + + configuration_manager = plans_manager.get_configuration(configuration_name) + # restore network + import nnunetv2 + from nnunetv2.utilities.find_class_by_name import recursive_find_python_class + from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels + + num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) + trainer_class = recursive_find_python_class( + join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, "nnunetv2.training.nnUNetTrainer" + ) + if trainer_class is None: + raise RuntimeError( + f"Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. " + f"Please place it there (in any .py file)!" 
+ ) + network = trainer_class.build_network_architecture( + configuration_manager.network_arch_class_name, + configuration_manager.network_arch_init_kwargs, + configuration_manager.network_arch_init_kwargs_req_import, + num_input_channels, + plans_manager.get_label_manager(dataset_json).num_segmentation_heads, + enable_deep_supervision=False, + ) + + predictor.plans_manager = plans_manager + predictor.configuration_manager = configuration_manager + predictor.list_of_parameters = parameters + predictor.network = network + predictor.dataset_json = dataset_json + predictor.trainer_name = trainer_name + predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes + predictor.label_manager = plans_manager.get_label_manager(dataset_json) + if ( + ("nnUNet_compile" in os.environ.keys()) + and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")) + # and not isinstance(predictor.network, OptimizedModule) + ): + print("Using torch.compile") + predictor.network = torch.compile(self.network) + # End Block + self.network_weights = self.predictor.network + + def forward(self, x: Union[MetaTensor, Tuple[MetaTensor]]): + """ + Forward pass for the nnUNet model. + + :no-index: + + Args: + x (Union[MetaTensor, Tuple[MetaTensor]]): Input tensor or a tuple of MetaTensors. If the input is a tuple, + it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch. + + Returns: + MetaTensor: The output tensor with the same metadata as the input. + + Raises: + TypeError: If the input is not a torch.Tensor or a tuple of MetaTensors. + + Notes: + - If the input is a tuple, the filenames are extracted from the metadata of each tensor in the tuple. + - If the input is a collated batch, the filenames are extracted from the metadata of the input tensor. + - The filenames are used to generate predictions using the nnUNet predictor. + - The predictions are converted to torch tensors, with added batch and channel dimensions. + - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata. 
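For orientation, a sketch of the on-disk layout the constructor above assumes: `dataset.json`, `plans.json` and `nnunet_checkpoint.pth` are read from the parent of `model_folder`, while the fold-specific weights file (`model.pt` by default) is read from `model_folder` itself. The bundle root name below is illustrative.

```python
from pathlib import Path

# Illustrative layout expected by ModelnnUNetWrapper:
#   bundle_root/models/dataset.json            # copied from the nnUNet results folder
#   bundle_root/models/plans.json              # copied from the nnUNet results folder
#   bundle_root/models/nnunet_checkpoint.pth   # init_args, trainer_name, inference_allowed_mirroring_axes
#   bundle_root/models/fold_0/model.pt         # {"network_weights": ...}
model_folder = Path("bundle_root") / "models" / "fold_0"   # what gets passed to the wrapper
shared = model_folder.parent
for required in (
    shared / "dataset.json",
    shared / "plans.json",
    shared / "nnunet_checkpoint.pth",
    model_folder / "model.pt",
):
    print(required, "found" if required.exists() else "missing")
```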
+ """ + if isinstance(x, tuple): # if batch is decollated (list of tensors) + properties_or_list_of_properties = [] + image_or_list_of_images = [] + + #for img in x: + #if isinstance(img, MetaTensor): + # properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()}) + # image_or_list_of_images.append(img.cpu().numpy()[0,:]) + #else: + # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") + + else: # if batch is collated + if isinstance(x, MetaTensor): + if 'pixdim' in x.meta: + properties_or_list_of_properties = {"spacing": x.meta['pixdim'][0][1:4].numpy().tolist()} + else: + properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]} + else: + raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") + image_or_list_of_images = x.cpu().numpy()[0,:] + + # input_files should be a list of file paths, one per modality + prediction_output = self.predictor.predict_from_list_of_npy_arrays( + image_or_list_of_images, + None, + properties_or_list_of_properties, + truncated_ofname=None, + save_probabilities=False, + num_processes=2, + num_processes_segmentation_export=2 + ) + # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax + + out_tensors = [] + for out in prediction_output: # Add batch and channel dimensions + out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0))) + out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension + + if type(x) is tuple: + return MetaTensor(out_tensor, meta=x[0].meta) + else: + return MetaTensor(out_tensor, meta=x.meta) + + +def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): + """ + Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`. + The model folder should contain the following files, created during training: + + - dataset.json: from the nnUNet results folder + - plans.json: from the nnUNet results folder + - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration + - model.pt: The checkpoint file containing the model weights. + + The returned wrapper object can be used for inference with MONAI framework: + + Example:: + + from monai.bundle.nnunet import get_nnunet_monai_predictor + + model_folder = 'path/to/monai_bundle/model' + model_name = 'model.pt' + wrapper = get_nnunet_monai_predictor(model_folder, model_name) + + # Perform inference + input_data = ... + output = wrapper(input_data) + + + Parameters + ---------- + model_folder : str + The folder where the model is stored. + model_name : str, optional + The name of the model file, by default "model.pt". + + Returns + ------- + nnUNetMONAIModelWrapper + A wrapper object that contains the nnUNetPredictor and the loaded model. + """ + + from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor + + predictor = nnUNetPredictor( + tile_step_size=0.5, + use_gaussian=True, + use_mirroring=False, + device=torch.device("cuda", 0), + verbose=False, + verbose_preprocessing=False, + allow_tqdm=True, + ) + # initializes the network architecture, loads the checkpoint + wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name) + return wrapper + + +def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0): + """ + Convert nnUNet model checkpoints and configuration to MONAI bundle format. 
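A hedged usage sketch for this conversion utility, mirroring the integration test later in this series: the folder names below are placeholders, and `nnUNet_results` is assumed to already point at a completed nnUNet training run containing `checkpoint_final.pth` and `checkpoint_best.pth` for the requested fold.

```python
import os
from pathlib import Path

from monai.apps.nnunet.nnunet_bundle import convert_nnunet_to_monai_bundle

# Dataset ID and trainer name mirror the integration test; adjust to your own run.
nnunet_config = {"dataset_name_or_id": "001", "nnunet_trainer": "nnUNetTrainer_1epoch"}
bundle_root = "bundle_root"  # illustrative bundle root folder

os.environ.setdefault("nnUNet_results", "path/to/nnUNet_results")  # assumed to exist already
Path(bundle_root, "models").mkdir(parents=True, exist_ok=True)

convert_nnunet_to_monai_bundle(nnunet_config, bundle_root, fold=0)
# Writes models/nnunet_checkpoint.pth, models/fold_0/model.pt and models/fold_0/best_model.pt,
# and copies plans.json and dataset.json into bundle_root/models/.
```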
+ + Parameters + ---------- + nnunet_config : dict + Configuration dictionary for nnUNet, containing keys such as 'dataset_name_or_id', 'nnunet_configuration', + 'nnunet_trainer', and 'nnunet_plans'. + bundle_root_folder : str + Root folder where the MONAI bundle will be saved. + fold : int, optional + Fold number of the nnUNet model to be converted, by default 0. + + Returns + ------- + None + """ + + nnunet_trainer = "nnUNetTrainer" + nnunet_plans = "nnUNetPlans" + nnunet_configuration = "3d_fullres" + + if "nnunet_trainer" in nnunet_config: + nnunet_trainer = nnunet_config["nnunet_trainer"] + + if "nnunet_plans" in nnunet_config: + nnunet_plans = nnunet_config["nnunet_plans"] + + if "nnunet_configuration" in nnunet_config: + nnunet_configuration = nnunet_config["nnunet_configuration"] + + from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + + dataset_name = maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"]) + nnunet_model_folder = Path(os.environ["nnUNet_results"]).joinpath( + dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}" + ) + + nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth")) + nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth")) + + nnunet_checkpoint = {} + nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"] + nnunet_checkpoint["init_args"] = nnunet_checkpoint_final["init_args"] + nnunet_checkpoint["trainer_name"] = nnunet_checkpoint_final["trainer_name"] + + torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath("models", "nnunet_checkpoint.pth")) + + Path(bundle_root_folder).joinpath("models", f"fold_{fold}").mkdir(parents=True, exist_ok=True) + monai_last_checkpoint = {} + monai_last_checkpoint["network_weights"] = nnunet_checkpoint_final["network_weights"] + torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models", f"fold_{fold}", "model.pt")) + + monai_best_checkpoint = {} + monai_best_checkpoint["network_weights"] = nnunet_checkpoint_best["network_weights"] + torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models",f"fold_{fold}", "best_model.pt")) + + if not os.path.exists(os.path.join(bundle_root_folder, "models", "plans.json")): + shutil.copy( + Path(nnunet_model_folder).joinpath("plans.json"), Path(bundle_root_folder).joinpath("models", "plans.json") + ) + + if not os.path.exists(os.path.join(bundle_root_folder, "models", "dataset.json")): + shutil.copy( + Path(nnunet_model_folder).joinpath("dataset.json"), Path(bundle_root_folder).joinpath("models", "dataset.json") + ) \ No newline at end of file diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 305bf9eb6a..3f3c8d545e 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -13,7 +13,6 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser -from .nnunet import ModelnnUNetWrapper, convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer from .properties import InferProperties, MetaProperties, TrainProperties from .reference_resolver import ReferenceResolver from .scripts import ( From 6e97f39f77e18910363ece52f7e496814013f5ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Mar 2025 07:40:31 +0000 
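Once a bundle has been produced this way, the `get_nnunet_monai_predictor` wrapper defined above can consume it through MONAI transforms and a `DataLoader`, much as the integration test in this series does. The image path and bundle folder below are illustrative, and a CUDA device is assumed because the underlying `nnUNetPredictor` is created on `cuda:0`.

```python
from pathlib import Path

from monai.apps.nnunet.nnunet_bundle import get_nnunet_monai_predictor
from monai.data import DataLoader, Dataset
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged

# Illustrative inputs: a converted bundle folder and a single NIfTI image.
bundle_root = Path("bundle_root")
data = [{"image": "sim_dataroot/image_0.nii.gz"}]

transforms = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image")])
loader = DataLoader(Dataset(data, transform=transforms), batch_size=1)

predictor = get_nnunet_monai_predictor(bundle_root / "models" / "fold_0")
batch = next(iter(loader))
pred = predictor(batch["image"])  # MetaTensor of label maps with batch and channel dims added
```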
Subject: [PATCH 19/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/nnunet/nnunet_bundle.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 711aa81a60..7fff574490 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -18,10 +18,9 @@ import torch from torch.backends import cudnn -from typing import Union, Tuple +from typing import Union from monai.data.meta_tensor import MetaTensor from monai.utils import optional_import -from nnunetv2.training.logging.nnunet_logger import nnUNetLogger join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") @@ -258,7 +257,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): # End Block self.network_weights = self.predictor.network - def forward(self, x: Union[MetaTensor, Tuple[MetaTensor]]): + def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]): """ Forward pass for the nnUNet model. @@ -291,7 +290,7 @@ def forward(self, x: Union[MetaTensor, Tuple[MetaTensor]]): # image_or_list_of_images.append(img.cpu().numpy()[0,:]) #else: # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") - + else: # if batch is collated if isinstance(x, MetaTensor): if 'pixdim' in x.meta: @@ -307,7 +306,7 @@ def forward(self, x: Union[MetaTensor, Tuple[MetaTensor]]): image_or_list_of_images, None, properties_or_list_of_properties, - truncated_ofname=None, + truncated_ofname=None, save_probabilities=False, num_processes=2, num_processes_segmentation_export=2 @@ -441,8 +440,8 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0): shutil.copy( Path(nnunet_model_folder).joinpath("plans.json"), Path(bundle_root_folder).joinpath("models", "plans.json") ) - + if not os.path.exists(os.path.join(bundle_root_folder, "models", "dataset.json")): shutil.copy( Path(nnunet_model_folder).joinpath("dataset.json"), Path(bundle_root_folder).joinpath("models", "dataset.json") - ) \ No newline at end of file + ) From a7cad283900b2e22bfda0dfccd4a2c627900ddc7 Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 10 Mar 2025 07:41:02 +0000 Subject: [PATCH 20/67] Update nnUNet import paths and comment out Transposed transform in integration tests --- tests/integration/test_integration_nnunet_bundle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_integration_nnunet_bundle.py b/tests/integration/test_integration_nnunet_bundle.py index 4e04f3f5cf..1023d9a5bf 100644 --- a/tests/integration/test_integration_nnunet_bundle.py +++ b/tests/integration/test_integration_nnunet_bundle.py @@ -20,7 +20,7 @@ from monai.apps.nnunet import nnUNetV2Runner from monai.bundle.config_parser import ConfigParser -from monai.bundle.nnunet import convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer +from monai.apps.nnunet.nnunet_bundle import convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer from monai.data import DataLoader, Dataset, create_test_image_3d from monai.transforms import Compose, Decollated, EnsureChannelFirstd, LoadImaged, SaveImaged, Transposed from monai.utils import optional_import @@ -129,7 +129,7 @@ def test_nnunet_bundle(self) -> None: post_processing_transforms = Compose( [ 
Decollated(keys=None, detach=True), - Transposed(keys="pred", indices=[0, 3, 2, 1]), + #Transposed(keys="pred", indices=[0, 3, 2, 1]), SaveImaged( keys="pred", output_dir=Path(self.sim_dataroot).joinpath("predictions"), output_postfix="pred" ), From 0734eb3186a898a59fbcee16c3d4f37857034601 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Mar 2025 07:41:38 +0000 Subject: [PATCH 21/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/integration/test_integration_nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_integration_nnunet_bundle.py b/tests/integration/test_integration_nnunet_bundle.py index 1023d9a5bf..6b44c14f4b 100644 --- a/tests/integration/test_integration_nnunet_bundle.py +++ b/tests/integration/test_integration_nnunet_bundle.py @@ -22,7 +22,7 @@ from monai.bundle.config_parser import ConfigParser from monai.apps.nnunet.nnunet_bundle import convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer from monai.data import DataLoader, Dataset, create_test_image_3d -from monai.transforms import Compose, Decollated, EnsureChannelFirstd, LoadImaged, SaveImaged, Transposed +from monai.transforms import Compose, Decollated, EnsureChannelFirstd, LoadImaged, SaveImaged from monai.utils import optional_import from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick From c107015d3132ee6b3d1c8b750c56c6a67d1b937a Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 10 Mar 2025 07:45:04 +0000 Subject: [PATCH 22/67] DCO Remediation Commit for simben I, simben , hereby add my Signed-off-by to this commit: 8bbb63b2bde257337105bb25d8af291846d3f8d6 I, simben , hereby add my Signed-off-by to this commit: a7cad283900b2e22bfda0dfccd4a2c627900ddc7 Signed-off-by: simben --- tests/integration/test_integration_nnunet_bundle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/test_integration_nnunet_bundle.py b/tests/integration/test_integration_nnunet_bundle.py index 6b44c14f4b..fc755b9ff5 100644 --- a/tests/integration/test_integration_nnunet_bundle.py +++ b/tests/integration/test_integration_nnunet_bundle.py @@ -129,6 +129,7 @@ def test_nnunet_bundle(self) -> None: post_processing_transforms = Compose( [ Decollated(keys=None, detach=True), + # Not needed after reading the data directly from the MONAI Transform #Transposed(keys="pred", indices=[0, 3, 2, 1]), SaveImaged( keys="pred", output_dir=Path(self.sim_dataroot).joinpath("predictions"), output_postfix="pred" From 3b5e80bacb071bf7eb2e034ffbfbd334f693a41e Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 10 Mar 2025 07:54:21 +0000 Subject: [PATCH 23/67] Update documentation for nnUNet Bundle integration --- docs/source/apps.rst | 11 +++++++++++ docs/source/bundle.rst | 11 +---------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index cc4cea8c1e..f1f444d13f 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -277,3 +277,14 @@ FastMRIReader .. autoclass:: monai.apps.nnunet.nnUNetV2Runner :members: + +`nnUNet Bundle` +--------------- +.. autoclass:: monai.apps.nnunet.ModelnnUNetWrapper + :members: + :special-members: + +.. autofunction:: monai.apps.nnunet.get_nnunet_trainer +.. autofunction:: monai.apps.nnunet.get_nnunet_monai_predictor +.. 
autofunction:: monai.apps.nnunet.convert_nnunet_to_monai_bundle + diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 49297e1a3d..4ad3df9be9 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -35,12 +35,6 @@ Model Bundle :special-members: -`nnUNet Bundle` ---------------- -.. autoclass:: ModelnnUNetWrapper - :members: - :special-members: - `Scripts` --------- .. autofunction:: ckpt_export @@ -56,7 +50,4 @@ Model Bundle .. autofunction:: verify_net_in_out .. autofunction:: init_bundle .. autofunction:: push_to_hf_hub -.. autofunction:: update_kwargs -.. autofunction:: get_nnunet_trainer -.. autofunction:: get_nnunet_monai_predictor -.. autofunction:: convert_nnunet_to_monai_bundle +.. autofunction:: update_kwargs \ No newline at end of file From 224e92477a6a295c803eb03363e62887c858f5d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Mar 2025 07:56:17 +0000 Subject: [PATCH 24/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/apps.rst | 1 - docs/source/bundle.rst | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f1f444d13f..e27e30c0bf 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -287,4 +287,3 @@ FastMRIReader .. autofunction:: monai.apps.nnunet.get_nnunet_trainer .. autofunction:: monai.apps.nnunet.get_nnunet_monai_predictor .. autofunction:: monai.apps.nnunet.convert_nnunet_to_monai_bundle - diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 4ad3df9be9..fdf745e951 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -50,4 +50,4 @@ Model Bundle .. autofunction:: verify_net_in_out .. autofunction:: init_bundle .. autofunction:: push_to_hf_hub -.. autofunction:: update_kwargs \ No newline at end of file +.. 
autofunction:: update_kwargs From 1da18e1700b004f85f79708ef21a858e8a094747 Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 10 Mar 2025 10:41:23 +0000 Subject: [PATCH 25/67] Refactor imports and improve code formatting in nnUNet bundle --- monai/apps/nnunet/__init__.py | 7 ++++- monai/apps/nnunet/nnunet_bundle.py | 31 ++++++++++--------- .../test_integration_nnunet_bundle.py | 10 ++++-- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/monai/apps/nnunet/__init__.py b/monai/apps/nnunet/__init__.py index 214ed1d45d..7467a7d7fa 100644 --- a/monai/apps/nnunet/__init__.py +++ b/monai/apps/nnunet/__init__.py @@ -11,6 +11,11 @@ from __future__ import annotations -from .nnunet_bundle import ModelnnUNetWrapper, convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer +from .nnunet_bundle import ( + ModelnnUNetWrapper, + convert_nnunet_to_monai_bundle, + get_nnunet_monai_predictor, + get_nnunet_trainer, +) from .nnunetv2_runner import nnUNetV2Runner from .utils import NNUNETMode, analyze_data, create_new_data_copy, create_new_dataset_json diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 7fff574490..f3026e89f2 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -13,12 +13,12 @@ import os import shutil from pathlib import Path +from typing import Union import numpy as np import torch from torch.backends import cudnn -from typing import Union from monai.data.meta_tensor import MetaTensor from monai.utils import optional_import @@ -210,7 +210,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): else None ) - if 'network_weights' in monai_checkpoint.keys(): + if "network_weights" in monai_checkpoint.keys(): parameters.append(monai_checkpoint["network_weights"]) else: parameters.append(monai_checkpoint) @@ -284,22 +284,22 @@ def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]): properties_or_list_of_properties = [] image_or_list_of_images = [] - #for img in x: - #if isinstance(img, MetaTensor): - # properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()}) - # image_or_list_of_images.append(img.cpu().numpy()[0,:]) - #else: - # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") + # for img in x: + # if isinstance(img, MetaTensor): + # properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()}) + # image_or_list_of_images.append(img.cpu().numpy()[0,:]) + # else: + # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") else: # if batch is collated if isinstance(x, MetaTensor): - if 'pixdim' in x.meta: - properties_or_list_of_properties = {"spacing": x.meta['pixdim'][0][1:4].numpy().tolist()} + if "pixdim" in x.meta: + properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()} else: properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]} else: raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") - image_or_list_of_images = x.cpu().numpy()[0,:] + image_or_list_of_images = x.cpu().numpy()[0, :] # input_files should be a list of file paths, one per modality prediction_output = self.predictor.predict_from_list_of_npy_arrays( @@ -309,7 +309,7 @@ def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]): truncated_ofname=None, save_probabilities=False, num_processes=2, - num_processes_segmentation_export=2 + num_processes_segmentation_export=2, ) # prediction_output is a list of numpy arrays, with 
dimensions (H, W, D), output from ArgMax @@ -434,14 +434,15 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0): monai_best_checkpoint = {} monai_best_checkpoint["network_weights"] = nnunet_checkpoint_best["network_weights"] - torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models",f"fold_{fold}", "best_model.pt")) + torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models", f"fold_{fold}", "best_model.pt")) if not os.path.exists(os.path.join(bundle_root_folder, "models", "plans.json")): shutil.copy( - Path(nnunet_model_folder).joinpath("plans.json"), Path(bundle_root_folder).joinpath("models", "plans.json") + Path(nnunet_model_folder).joinpath("plans.json"), Path(bundle_root_folder).joinpath("models", "plans.json") ) if not os.path.exists(os.path.join(bundle_root_folder, "models", "dataset.json")): shutil.copy( - Path(nnunet_model_folder).joinpath("dataset.json"), Path(bundle_root_folder).joinpath("models", "dataset.json") + Path(nnunet_model_folder).joinpath("dataset.json"), + Path(bundle_root_folder).joinpath("models", "dataset.json"), ) diff --git a/tests/integration/test_integration_nnunet_bundle.py b/tests/integration/test_integration_nnunet_bundle.py index fc755b9ff5..41ad666f29 100644 --- a/tests/integration/test_integration_nnunet_bundle.py +++ b/tests/integration/test_integration_nnunet_bundle.py @@ -19,8 +19,12 @@ import numpy as np from monai.apps.nnunet import nnUNetV2Runner +from monai.apps.nnunet.nnunet_bundle import ( + convert_nnunet_to_monai_bundle, + get_nnunet_monai_predictor, + get_nnunet_trainer, +) from monai.bundle.config_parser import ConfigParser -from monai.apps.nnunet.nnunet_bundle import convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer from monai.data import DataLoader, Dataset, create_test_image_3d from monai.transforms import Compose, Decollated, EnsureChannelFirstd, LoadImaged, SaveImaged from monai.utils import optional_import @@ -122,7 +126,7 @@ def test_nnunet_bundle(self) -> None: data_loader = DataLoader(dataset, batch_size=1) input = next(iter(data_loader)) - predictor = get_nnunet_monai_predictor(Path(self.bundle_root).joinpath("models")) + predictor = get_nnunet_monai_predictor(Path(self.bundle_root).joinpath("models","fold_0")) pred_batch = predictor(input["image"]) Path(self.sim_dataroot).joinpath("predictions").mkdir(parents=True, exist_ok=True) @@ -130,7 +134,7 @@ def test_nnunet_bundle(self) -> None: [ Decollated(keys=None, detach=True), # Not needed after reading the data directly from the MONAI Transform - #Transposed(keys="pred", indices=[0, 3, 2, 1]), + # Transposed(keys="pred", indices=[0, 3, 2, 1]), SaveImaged( keys="pred", output_dir=Path(self.sim_dataroot).joinpath("predictions"), output_postfix="pred" ), From fe0213684991be27f468a8862238f8816d9ce125 Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 10 Mar 2025 10:43:10 +0000 Subject: [PATCH 26/67] DCO Remediation Commit for simben I, simben , hereby add my Signed-off-by to this commit: 3b5e80bacb071bf7eb2e034ffbfbd334f693a41e I, simben , hereby add my Signed-off-by to this commit: 1da18e1700b004f85f79708ef21a858e8a094747 Signed-off-by: simben --- tests/integration/test_integration_nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_integration_nnunet_bundle.py b/tests/integration/test_integration_nnunet_bundle.py index 41ad666f29..2ef9865c43 100644 --- a/tests/integration/test_integration_nnunet_bundle.py +++ 
b/tests/integration/test_integration_nnunet_bundle.py @@ -133,7 +133,7 @@ def test_nnunet_bundle(self) -> None: post_processing_transforms = Compose( [ Decollated(keys=None, detach=True), - # Not needed after reading the data directly from the MONAI Transform + # Not needed after reading the data directly from the MONAI LoadImaged Transform # Transposed(keys="pred", indices=[0, 3, 2, 1]), SaveImaged( keys="pred", output_dir=Path(self.sim_dataroot).joinpath("predictions"), output_postfix="pred" From 9f91aaf015b55f5009a3d90a5215bbe3c0a8f021 Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 10 Mar 2025 10:55:22 +0000 Subject: [PATCH 27/67] Fix formatting in test_integration_nnunet_bundle.py for improved readability --- tests/integration/test_integration_nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_integration_nnunet_bundle.py b/tests/integration/test_integration_nnunet_bundle.py index 2ef9865c43..1e117aeffe 100644 --- a/tests/integration/test_integration_nnunet_bundle.py +++ b/tests/integration/test_integration_nnunet_bundle.py @@ -126,7 +126,7 @@ def test_nnunet_bundle(self) -> None: data_loader = DataLoader(dataset, batch_size=1) input = next(iter(data_loader)) - predictor = get_nnunet_monai_predictor(Path(self.bundle_root).joinpath("models","fold_0")) + predictor = get_nnunet_monai_predictor(Path(self.bundle_root).joinpath("models", "fold_0")) pred_batch = predictor(input["image"]) Path(self.sim_dataroot).joinpath("predictions").mkdir(parents=True, exist_ok=True) From a2bc247caa45235cc43c3dd850292d2ba56fd979 Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 10 Mar 2025 10:57:49 +0000 Subject: [PATCH 28/67] DCO Remediation Commit for simben I, simben , hereby add my Signed-off-by to this commit: 9f91aaf015b55f5009a3d90a5215bbe3c0a8f021 Signed-off-by: simben --- monai/apps/nnunet/nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index f3026e89f2..8821bc0e90 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -55,7 +55,7 @@ def get_nnunet_trainer( from monai.apps import SupervisedTrainer from monai.bundle.nnunet import get_nnunet_trainer - dataset_name_or_id = 'Task101_PROSTATE' + dataset_name_or_id = 'Task009_Spleen' fold = 0 configuration = '3d_fullres' nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold) From 23de2bcfa02d98c24ec9303a358b75faf3771dbe Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 10 Mar 2025 19:42:32 +0000 Subject: [PATCH 29/67] Refactor forward method in ModelnnUNetWrapper for clarity and type consistency Signed-off-by: simben --- monai/apps/nnunet/nnunet_bundle.py | 35 +-- monai/bundle/nnunet.py | 427 ----------------------------- 2 files changed, 18 insertions(+), 444 deletions(-) delete mode 100644 monai/bundle/nnunet.py diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 8821bc0e90..880da6607f 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -257,14 +257,14 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): # End Block self.network_weights = self.predictor.network - def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]): + def forward(self, x: MetaTensor) -> MetaTensor: """ Forward pass for the nnUNet model. :no-index: Args: - x (Union[MetaTensor, Tuple[MetaTensor]]): Input tensor or a tuple of MetaTensors. 
If the input is a tuple, + x (MetaTensor): Input tensor. If the input is a tuple, it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch. Returns: @@ -280,9 +280,9 @@ def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]): - The predictions are converted to torch tensors, with added batch and channel dimensions. - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata. """ - if isinstance(x, tuple): # if batch is decollated (list of tensors) - properties_or_list_of_properties = [] - image_or_list_of_images = [] + #if isinstance(x, tuple): # if batch is decollated (list of tensors) + # properties_or_list_of_properties = [] + # image_or_list_of_images = [] # for img in x: # if isinstance(img, MetaTensor): @@ -291,15 +291,16 @@ def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]): # else: # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") - else: # if batch is collated - if isinstance(x, MetaTensor): - if "pixdim" in x.meta: - properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()} - else: - properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]} + #else: # if batch is collated + if isinstance(x, MetaTensor): + if "pixdim" in x.meta: + properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()} else: - raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") - image_or_list_of_images = x.cpu().numpy()[0, :] + properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]} + else: + raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") + + image_or_list_of_images = x.cpu().numpy()[0, :] # input_files should be a list of file paths, one per modality prediction_output = self.predictor.predict_from_list_of_npy_arrays( @@ -318,10 +319,10 @@ def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]): out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0))) out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension - if type(x) is tuple: - return MetaTensor(out_tensor, meta=x[0].meta) - else: - return MetaTensor(out_tensor, meta=x.meta) + #if type(x) is tuple: + # return MetaTensor(out_tensor, meta=x[0].meta) + #else: + return MetaTensor(out_tensor, meta=x.meta) def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py deleted file mode 100644 index 4eca036c17..0000000000 --- a/monai/bundle/nnunet.py +++ /dev/null @@ -1,427 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from __future__ import annotations - -import os -import shutil -from pathlib import Path - -import numpy as np -import torch -from torch.backends import cudnn - -from monai.data.meta_tensor import MetaTensor -from monai.utils import optional_import - -join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") -load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") - -__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "ModelnnUNetWrapper"] - - -def get_nnunet_trainer( - dataset_name_or_id, - configuration, - fold, - trainer_class_name="nnUNetTrainer", - plans_identifier="nnUNetPlans", - pretrained_weights=None, - num_gpus=1, - use_compressed_data=False, - export_validation_probabilities=False, - continue_training=False, - only_run_validation=False, - disable_checkpointing=False, - val_with_best=False, - device="cuda", - pretrained_model=None, -): - """ - Get the nnUNet trainer instance based on the provided configuration. - The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network, - optimizer, loss function, DataLoader, etc. - - Example:: - - from monai.apps import SupervisedTrainer - from monai.bundle.nnunet import get_nnunet_trainer - - dataset_name_or_id = 'Task101_PROSTATE' - fold = 0 - configuration = '3d_fullres' - nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold) - - trainer = SupervisedTrainer( - device=nnunet_trainer.device, - max_epochs=nnunet_trainer.num_epochs, - train_data_loader=nnunet_trainer.dataloader_train, - network=nnunet_trainer.network, - optimizer=nnunet_trainer.optimizer, - loss_function=nnunet_trainer.loss_function, - epoch_length=nnunet_trainer.num_iterations_per_epoch, - ) - - Parameters - ---------- - dataset_name_or_id : Union[str, int] - The name or ID of the dataset to be used. - configuration : str - The configuration name for the training. - fold : Union[int, str] - The fold number or 'all' for cross-validation. - trainer_class_name : str, optional - The class name of the trainer to be used. Default is 'nnUNetTrainer'. - plans_identifier : str, optional - Identifier for the plans to be used. Default is 'nnUNetPlans'. - pretrained_weights : str, optional - Path to the pretrained weights file. - num_gpus : int, optional - Number of GPUs to be used. Default is 1. - use_compressed_data : bool, optional - Whether to use compressed data. Default is False. - export_validation_probabilities : bool, optional - Whether to export validation probabilities. Default is False. - continue_training : bool, optional - Whether to continue training from a checkpoint. Default is False. - only_run_validation : bool, optional - Whether to only run validation. Default is False. - disable_checkpointing : bool, optional - Whether to disable checkpointing. Default is False. - val_with_best : bool, optional - Whether to validate with the best model. Default is False. - device : str, optional - The device to be used for training. Default is 'cuda'. - pretrained_model : str, optional - Path to the pretrained model file. - Returns - ------- - nnunet_trainer - The nnUNet trainer instance. - """ - # From nnUNet/nnunetv2/run/run_training.py#run_training - if isinstance(fold, str): - if fold != "all": - try: - fold = int(fold) - except ValueError as e: - print( - f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!' 
- ) - raise e - - if int(num_gpus) > 1: - ... # Disable for now - else: - from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint - - nnunet_trainer = get_trainer_from_args( - str(dataset_name_or_id), - configuration, - fold, - trainer_class_name, - plans_identifier, - use_compressed_data, - device=torch.device(device), - ) - if disable_checkpointing: - nnunet_trainer.disable_checkpointing = disable_checkpointing - - assert not (continue_training and only_run_validation), "Cannot set --c and --val flag at the same time. Dummy." - - maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights) - nnunet_trainer.on_train_start() # Added to Initialize Trainer - if torch.cuda.is_available(): - cudnn.deterministic = False - cudnn.benchmark = True - - if pretrained_model is not None: - state_dict = torch.load(pretrained_model) - if "network_weights" in state_dict: - nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"]) - return nnunet_trainer - - -class ModelnnUNetWrapper(torch.nn.Module): - """ - A wrapper class for nnUNet model integration with MONAI framework. - The wrapper can be used to integrate the nnUNet Bundle within MONAI framework for inference. - - Parameters - ---------- - predictor : object - The nnUNet predictor object used for inference. - model_folder : str - The folder path where the model and related files are stored. - model_name : str, optional - The name of the model file, by default "model.pt". - - Attributes - ---------- - predictor : nnUNetPredictor - The nnUNet predictor object used for inference. - network_weights : torch.nn.Module - The network weights of the model. - - Notes - ----- - This class integrates nnUNet model with MONAI framework by loading necessary configurations, - restoring network architecture, and setting up the predictor for inference. 
- """ - - def __init__(self, predictor, model_folder, model_name="model.pt"): - super().__init__() - self.predictor = predictor - - model_training_output_dir = model_folder - use_folds = ["0"] - - from nnunetv2.utilities.plans_handling.plans_handler import PlansManager - - # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor - dataset_json = load_json(join(model_training_output_dir, "dataset.json")) - plans = load_json(join(model_training_output_dir, "plans.json")) - plans_manager = PlansManager(plans) - - if isinstance(use_folds, str): - use_folds = [use_folds] - - parameters = [] - for i, f in enumerate(use_folds): - f = str(f) if f != "all" else f - checkpoint = torch.load( - join(model_training_output_dir, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") - ) - monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu")) - if i == 0: - trainer_name = checkpoint["trainer_name"] - configuration_name = checkpoint["init_args"]["configuration"] - inference_allowed_mirroring_axes = ( - checkpoint["inference_allowed_mirroring_axes"] - if "inference_allowed_mirroring_axes" in checkpoint.keys() - else None - ) - - parameters.append(monai_checkpoint["network_weights"]) - - configuration_manager = plans_manager.get_configuration(configuration_name) - # restore network - import nnunetv2 - from nnunetv2.utilities.find_class_by_name import recursive_find_python_class - from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels - - num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) - trainer_class = recursive_find_python_class( - join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, "nnunetv2.training.nnUNetTrainer" - ) - if trainer_class is None: - raise RuntimeError( - f"Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. " - f"Please place it there (in any .py file)!" - ) - network = trainer_class.build_network_architecture( - configuration_manager.network_arch_class_name, - configuration_manager.network_arch_init_kwargs, - configuration_manager.network_arch_init_kwargs_req_import, - num_input_channels, - plans_manager.get_label_manager(dataset_json).num_segmentation_heads, - enable_deep_supervision=False, - ) - - predictor.plans_manager = plans_manager - predictor.configuration_manager = configuration_manager - predictor.list_of_parameters = parameters - predictor.network = network - predictor.dataset_json = dataset_json - predictor.trainer_name = trainer_name - predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes - predictor.label_manager = plans_manager.get_label_manager(dataset_json) - if ( - ("nnUNet_compile" in os.environ.keys()) - and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")) - # and not isinstance(predictor.network, OptimizedModule) - ): - print("Using torch.compile") - predictor.network = torch.compile(self.network) - # End Block - self.network_weights = self.predictor.network - - def forward(self, x): - """ - Forward pass for the nnUNet model. - - :no-index: - - Args: - x (Union[torch.Tensor, Tuple[MetaTensor]]): Input tensor or a tuple of MetaTensors. If the input is a tuple, - it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch. - - Returns: - MetaTensor: The output tensor with the same metadata as the input. 
- - Raises: - TypeError: If the input is not a torch.Tensor or a tuple of MetaTensors. - - Notes: - - If the input is a tuple, the filenames are extracted from the metadata of each tensor in the tuple. - - If the input is a collated batch, the filenames are extracted from the metadata of the input tensor. - - The filenames are used to generate predictions using the nnUNet predictor. - - The predictions are converted to torch tensors, with added batch and channel dimensions. - - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata. - """ - if type(x) is tuple: # if batch is decollated (list of tensors) - input_files = [img.meta["filename_or_obj"][0] for img in x] - else: # if batch is collated - input_files = x.meta["filename_or_obj"] - if isinstance(input_files, str): - input_files = [input_files] - - # input_files should be a list of file paths, one per modality - prediction_output = self.predictor.predict_from_files( - [input_files], - None, - save_probabilities=False, - overwrite=True, - num_processes_preprocessing=2, - num_processes_segmentation_export=2, - folder_with_segs_from_prev_stage=None, - num_parts=1, - part_id=0, - ) - # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax - - out_tensors = [] - for out in prediction_output: # Add batch and channel dimensions - out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0))) - out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension - - if type(x) is tuple: - return MetaTensor(out_tensor, meta=x[0].meta) - else: - return MetaTensor(out_tensor, meta=x.meta) - - -def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): - """ - Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`. - The model folder should contain the following files, created during training: - - - dataset.json: from the nnUNet results folder - - plans.json: from the nnUNet results folder - - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration - - model.pt: The checkpoint file containing the model weights. - - The returned wrapper object can be used for inference with MONAI framework: - - Example:: - - from monai.bundle.nnunet import get_nnunet_monai_predictor - - model_folder = 'path/to/monai_bundle/model' - model_name = 'model.pt' - wrapper = get_nnunet_monai_predictor(model_folder, model_name) - - # Perform inference - input_data = ... - output = wrapper(input_data) - - - Parameters - ---------- - model_folder : str - The folder where the model is stored. - model_name : str, optional - The name of the model file, by default "model.pt". - - Returns - ------- - nnUNetMONAIModelWrapper - A wrapper object that contains the nnUNetPredictor and the loaded model. - """ - - from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor - - predictor = nnUNetPredictor( - tile_step_size=0.5, - use_gaussian=True, - use_mirroring=False, - device=torch.device("cuda", 0), - verbose=False, - verbose_preprocessing=False, - allow_tqdm=True, - ) - # initializes the network architecture, loads the checkpoint - wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name) - return wrapper - - -def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0): - """ - Convert nnUNet model checkpoints and configuration to MONAI bundle format. 
- - Parameters - ---------- - nnunet_config : dict - Configuration dictionary for nnUNet, containing keys such as 'dataset_name_or_id', 'nnunet_configuration', - 'nnunet_trainer', and 'nnunet_plans'. - bundle_root_folder : str - Root folder where the MONAI bundle will be saved. - fold : int, optional - Fold number of the nnUNet model to be converted, by default 0. - - Returns - ------- - None - """ - - nnunet_trainer = "nnUNetTrainer" - nnunet_plans = "nnUNetPlans" - nnunet_configuration = "3d_fullres" - - if "nnunet_trainer" in nnunet_config: - nnunet_trainer = nnunet_config["nnunet_trainer"] - - if "nnunet_plans" in nnunet_config: - nnunet_plans = nnunet_config["nnunet_plans"] - - if "nnunet_configuration" in nnunet_config: - nnunet_configuration = nnunet_config["nnunet_configuration"] - - from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name - - dataset_name = maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"]) - nnunet_model_folder = Path(os.environ["nnUNet_results"]).joinpath( - dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}" - ) - - nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth")) - nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth")) - - nnunet_checkpoint = {} - nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"] - nnunet_checkpoint["init_args"] = nnunet_checkpoint_final["init_args"] - nnunet_checkpoint["trainer_name"] = nnunet_checkpoint_final["trainer_name"] - - torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath("models", "nnunet_checkpoint.pth")) - - monai_last_checkpoint = {} - monai_last_checkpoint["network_weights"] = nnunet_checkpoint_final["network_weights"] - torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models", "model.pt")) - - monai_best_checkpoint = {} - monai_best_checkpoint["network_weights"] = nnunet_checkpoint_best["network_weights"] - torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models", "best_model.pt")) - - shutil.copy( - Path(nnunet_model_folder).joinpath("plans.json"), Path(bundle_root_folder).joinpath("models", "plans.json") - ) - shutil.copy( - Path(nnunet_model_folder).joinpath("dataset.json"), Path(bundle_root_folder).joinpath("models", "dataset.json") - ) From 6c324445db3c478c2368979c3041565c4653997c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Mar 2025 19:43:01 +0000 Subject: [PATCH 30/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/nnunet/nnunet_bundle.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 880da6607f..be39138831 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -13,7 +13,6 @@ import os import shutil from pathlib import Path -from typing import Union import numpy as np import torch @@ -299,7 +298,7 @@ def forward(self, x: MetaTensor) -> MetaTensor: properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]} else: raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") - + image_or_list_of_images = x.cpu().numpy()[0, :] # input_files should be a list of file paths, one per modality From 
f43125ae18f53cd679426bd9cf85741cb8252fff Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 10 Mar 2025 19:57:13 +0000 Subject: [PATCH 31/67] Remove commented-out code in ModelnnUNetWrapper for improved readability Signed-off-by: simben --- monai/apps/nnunet/nnunet_bundle.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 880da6607f..cca708bcab 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -280,18 +280,18 @@ def forward(self, x: MetaTensor) -> MetaTensor: - The predictions are converted to torch tensors, with added batch and channel dimensions. - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata. """ - #if isinstance(x, tuple): # if batch is decollated (list of tensors) + # if isinstance(x, tuple): # if batch is decollated (list of tensors) # properties_or_list_of_properties = [] # image_or_list_of_images = [] - # for img in x: - # if isinstance(img, MetaTensor): - # properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()}) - # image_or_list_of_images.append(img.cpu().numpy()[0,:]) - # else: - # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") + # for img in x: + # if isinstance(img, MetaTensor): + # properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()}) + # image_or_list_of_images.append(img.cpu().numpy()[0,:]) + # else: + # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") - #else: # if batch is collated + # else: # if batch is collated if isinstance(x, MetaTensor): if "pixdim" in x.meta: properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()} @@ -299,7 +299,7 @@ def forward(self, x: MetaTensor) -> MetaTensor: properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]} else: raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") - + image_or_list_of_images = x.cpu().numpy()[0, :] # input_files should be a list of file paths, one per modality @@ -319,9 +319,9 @@ def forward(self, x: MetaTensor) -> MetaTensor: out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0))) out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension - #if type(x) is tuple: + # if type(x) is tuple: # return MetaTensor(out_tensor, meta=x[0].meta) - #else: + # else: return MetaTensor(out_tensor, meta=x.meta) From 30cb6c56ffd9d098e0bd4ac80b839f6a89880702 Mon Sep 17 00:00:00 2001 From: simben Date: Mon, 10 Mar 2025 20:08:23 +0000 Subject: [PATCH 32/67] Comment out the torch.compile line in ModelnnUNetWrapper to prevent potential issues during execution Signed-off-by: simben --- monai/apps/nnunet/nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 565cb37238..b976bc944c 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -252,7 +252,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): # and not isinstance(predictor.network, OptimizedModule) ): print("Using torch.compile") - predictor.network = torch.compile(self.network) + # predictor.network = torch.compile(self.network) # End Block self.network_weights = self.predictor.network From 1a30a0bb6f6557c0b9fcb165bd712edc1adfbb0a Mon Sep 17 00:00:00 2001 From: Simone 
Bendazzoli Date: Mon, 24 Mar 2025 09:04:57 +0000 Subject: [PATCH 33/67] Add JSON generation widgets for nnUNet and update requirements --- monai/apps/nnunet/__init__.py | 4 +- monai/apps/nnunet/nnunet_bundle.py | 372 ++++-- monai/nvflare/__init__.py | 10 + monai/nvflare/json_generator.py | 179 +++ monai/nvflare/nnunet_executor.py | 329 +++++ monai/nvflare/nvflare_generate_job_configs.py | 1082 +++++++++++++++++ monai/nvflare/nvflare_nnunet.py | 688 +++++++++++ monai/nvflare/response_processor.py | 342 ++++++ requirements-dev.txt | 2 + 9 files changed, 2894 insertions(+), 114 deletions(-) create mode 100644 monai/nvflare/__init__.py create mode 100644 monai/nvflare/json_generator.py create mode 100644 monai/nvflare/nnunet_executor.py create mode 100644 monai/nvflare/nvflare_generate_job_configs.py create mode 100644 monai/nvflare/nvflare_nnunet.py create mode 100644 monai/nvflare/response_processor.py diff --git a/monai/apps/nnunet/__init__.py b/monai/apps/nnunet/__init__.py index 7467a7d7fa..991de8d281 100644 --- a/monai/apps/nnunet/__init__.py +++ b/monai/apps/nnunet/__init__.py @@ -12,10 +12,12 @@ from __future__ import annotations from .nnunet_bundle import ( - ModelnnUNetWrapper, + convert_monai_bundle_to_nnunet, convert_nnunet_to_monai_bundle, + get_network_from_nnunet_plans, get_nnunet_monai_predictor, get_nnunet_trainer, + nnUNetMONAIModelWrapper, ) from .nnunetv2_runner import nnUNetV2Runner from .utils import NNUNETMode, analyze_data, create_new_data_copy, create_new_dataset_json diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index b976bc944c..2b4b59a5c1 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -16,6 +16,7 @@ import numpy as np import torch +from torch._dynamo import OptimizedModule from torch.backends import cudnn from monai.data.meta_tensor import MetaTensor @@ -24,7 +25,7 @@ join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") -__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "ModelnnUNetWrapper"] +__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "nnUNetMONAIModelWrapper"] def get_nnunet_trainer( @@ -41,7 +42,7 @@ def get_nnunet_trainer( only_run_validation=False, disable_checkpointing=False, val_with_best=False, - device="cuda", + device=torch.device("cuda"), pretrained_model=None, ): """ @@ -49,25 +50,25 @@ def get_nnunet_trainer( The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network, optimizer, loss function, DataLoader, etc. 
- Example:: + ```python + from monai.apps import SupervisedTrainer + from monai.bundle.nnunet import get_nnunet_trainer - from monai.apps import SupervisedTrainer - from monai.bundle.nnunet import get_nnunet_trainer + dataset_name_or_id = 'Task101_PROSTATE' + fold = 0 + configuration = '3d_fullres' + nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold) - dataset_name_or_id = 'Task009_Spleen' - fold = 0 - configuration = '3d_fullres' - nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold) + trainer = SupervisedTrainer( + device=nnunet_trainer.device, + max_epochs=nnunet_trainer.num_epochs, + train_data_loader=nnunet_trainer.dataloader_train, + network=nnunet_trainer.network, + optimizer=nnunet_trainer.optimizer, + loss_function=nnunet_trainer.loss_function, + epoch_length=nnunet_trainer.num_iterations_per_epoch, - trainer = SupervisedTrainer( - device=nnunet_trainer.device, - max_epochs=nnunet_trainer.num_epochs, - train_data_loader=nnunet_trainer.dataloader_train, - network=nnunet_trainer.network, - optimizer=nnunet_trainer.optimizer, - loss_function=nnunet_trainer.loss_function, - epoch_length=nnunet_trainer.num_iterations_per_epoch, - ) + ``` Parameters ---------- @@ -97,7 +98,7 @@ def get_nnunet_trainer( Whether to disable checkpointing. Default is False. val_with_best : bool, optional Whether to validate with the best model. Default is False. - device : str, optional + device : torch.device, optional The device to be used for training. Default is 'cuda'. pretrained_model : str, optional Path to the pretrained model file. @@ -129,7 +130,7 @@ def get_nnunet_trainer( trainer_class_name, plans_identifier, use_compressed_data, - device=torch.device(device), + device=device, ) if disable_checkpointing: nnunet_trainer.disable_checkpointing = disable_checkpointing @@ -149,7 +150,7 @@ def get_nnunet_trainer( return nnunet_trainer -class ModelnnUNetWrapper(torch.nn.Module): +class nnUNetMONAIModelWrapper(torch.nn.Module): """ A wrapper class for nnUNet model integration with MONAI framework. The wrapper can be use to integrate the nnUNet Bundle within MONAI framework for inference. @@ -162,14 +163,16 @@ class ModelnnUNetWrapper(torch.nn.Module): The folder path where the model and related files are stored. model_name : str, optional The name of the model file, by default "model.pt". - Attributes ---------- - predictor : nnUNetPredictor - The nnUNet predictor object used for inference. + predictor : object + The predictor object used for inference. network_weights : torch.nn.Module The network weights of the model. - + Methods + ------- + forward(x) + Perform forward pass and prediction on the input data. 
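To make the `forward(x)` entry above concrete: in this revision the wrapper's forward pass reads `meta["filename_or_obj"]` from the input `MetaTensor` and re-runs nnUNet's file-based predictor, so inputs should come from a loader that preserves that metadata (e.g. `LoadImaged`). The sketch below is illustrative only; the file path and dictionary key are placeholders, not values from this patch.

```python
# Hedged sketch: building an input whose metadata carries "filename_or_obj",
# which is what this revision of forward() hands to predict_from_files(...).
from monai.data import DataLoader, Dataset
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged

data = [{"image": "/path/to/case_0001_0000.nii.gz"}]  # placeholder path
preprocess = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image")])
loader = DataLoader(Dataset(data, transform=preprocess), batch_size=1)

batch = next(iter(loader))
# batch["image"] is a MetaTensor; its meta["filename_or_obj"] points back to the
# source file on disk.
# output = wrapper(batch["image"])  # wrapper built via get_nnunet_monai_predictor(...)
```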
Notes ----- This class integrates nnUNet model with MONAI framework by loading necessary configurations, @@ -181,13 +184,13 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): self.predictor = predictor model_training_output_dir = model_folder - use_folds = ["0"] + use_folds = "0" from nnunetv2.utilities.plans_handling.plans_handler import PlansManager - # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor - dataset_json = load_json(join(Path(model_training_output_dir).parent, "dataset.json")) - plans = load_json(join(Path(model_training_output_dir).parent, "plans.json")) + ## Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor + dataset_json = load_json(join(model_training_output_dir, "dataset.json")) + plans = load_json(join(model_training_output_dir, "plans.json")) plans_manager = PlansManager(plans) if isinstance(use_folds, str): @@ -195,11 +198,10 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): parameters = [] for i, f in enumerate(use_folds): - f = str(f) if f != "all" else f + f = int(f) if f != "all" else f checkpoint = torch.load( - join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") + join(model_training_output_dir, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") ) - monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu")) if i == 0: trainer_name = checkpoint["trainer_name"] configuration_name = checkpoint["init_args"]["configuration"] @@ -208,11 +210,14 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): if "inference_allowed_mirroring_axes" in checkpoint.keys() else None ) - - if "network_weights" in monai_checkpoint.keys(): - parameters.append(monai_checkpoint["network_weights"]) - else: - parameters.append(monai_checkpoint) + if Path(model_training_output_dir).joinpath(f"fold_{f}", model_name).is_file(): + monai_checkpoint = torch.load( + join(model_training_output_dir, model_name), map_location=torch.device("cpu") + ) + if "network_weights" in monai_checkpoint.keys(): + parameters.append(monai_checkpoint["network_weights"]) + else: + parameters.append(monai_checkpoint) configuration_manager = plans_manager.get_configuration(configuration_name) # restore network @@ -249,67 +254,32 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): if ( ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")) - # and not isinstance(predictor.network, OptimizedModule) + and not isinstance(predictor.network, OptimizedModule) ): print("Using torch.compile") - # predictor.network = torch.compile(self.network) - # End Block + predictor.network = torch.compile(self.network) + ## End Block self.network_weights = self.predictor.network - def forward(self, x: MetaTensor) -> MetaTensor: - """ - Forward pass for the nnUNet model. - - :no-index: - - Args: - x (MetaTensor): Input tensor. If the input is a tuple, - it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch. - - Returns: - MetaTensor: The output tensor with the same metadata as the input. - - Raises: - TypeError: If the input is not a torch.Tensor or a tuple of MetaTensors. - - Notes: - - If the input is a tuple, the filenames are extracted from the metadata of each tensor in the tuple. - - If the input is a collated batch, the filenames are extracted from the metadata of the input tensor. 
- - The filenames are used to generate predictions using the nnUNet predictor. - - The predictions are converted to torch tensors, with added batch and channel dimensions. - - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata. - """ - # if isinstance(x, tuple): # if batch is decollated (list of tensors) - # properties_or_list_of_properties = [] - # image_or_list_of_images = [] - - # for img in x: - # if isinstance(img, MetaTensor): - # properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()}) - # image_or_list_of_images.append(img.cpu().numpy()[0,:]) - # else: - # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") - - # else: # if batch is collated - if isinstance(x, MetaTensor): - if "pixdim" in x.meta: - properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()} - else: - properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]} - else: - raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") - - image_or_list_of_images = x.cpu().numpy()[0, :] + def forward(self, x): + if type(x) is tuple: # if batch is decollated (list of tensors) + input_files = [img.meta["filename_or_obj"][0] for img in x] + else: # if batch is collated + input_files = x.meta["filename_or_obj"] + if type(input_files) is str: + input_files = [input_files] # input_files should be a list of file paths, one per modality - prediction_output = self.predictor.predict_from_list_of_npy_arrays( - image_or_list_of_images, + prediction_output = self.predictor.predict_from_files( + [input_files], None, - properties_or_list_of_properties, - truncated_ofname=None, save_probabilities=False, - num_processes=2, + overwrite=True, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, + num_parts=1, + part_id=0, ) # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax @@ -318,36 +288,35 @@ def forward(self, x: MetaTensor) -> MetaTensor: out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0))) out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension - # if type(x) is tuple: - # return MetaTensor(out_tensor, meta=x[0].meta) - # else: - return MetaTensor(out_tensor, meta=x.meta) + if type(x) is tuple: + return MetaTensor(out_tensor, meta=x[0].meta) + else: + return MetaTensor(out_tensor, meta=x.meta) def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): """ - Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`. + Initializes and returns a nnUNetMONAIModelWrapper with a nnUNetPredictor. The model folder should contain the following files, created during training: - - - dataset.json: from the nnUNet results folder - - plans.json: from the nnUNet results folder - - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration - - model.pt: The checkpoint file containing the model weights. + - dataset.json: from the nnUNet results folder. + - plans.json: from the nnUNet results folder. + - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration + (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`). + - model.pt: The checkpoint file containing the model weights. 
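A quick, hedged sanity check derived from the file list above can catch an incomplete bundle model folder before the predictor is constructed; the folder path below is a placeholder.

```python
# Sketch (not part of the patch): verify the bundle "models" folder contains the
# files listed in the docstring before calling get_nnunet_monai_predictor(...).
from pathlib import Path

model_folder = Path("path/to/monai_bundle/models")  # placeholder
expected = ["dataset.json", "plans.json", "nnunet_checkpoint.pth", "model.pt"]
missing = [name for name in expected if not (model_folder / name).is_file()]
if missing:
    raise FileNotFoundError(f"bundle model folder is missing: {missing}")

# wrapper = get_nnunet_monai_predictor(str(model_folder), model_name="model.pt")
```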
The returned wrapper object can be used for inference with MONAI framework: + ```python + from monai.bundle.nnunet import get_nnunet_monai_predictor - Example:: - - from monai.bundle.nnunet import get_nnunet_monai_predictor + model_folder = 'path/to/monai_bundle/model' + model_name = 'model.pt' + wrapper = get_nnunet_monai_predictor(model_folder, model_name) - model_folder = 'path/to/monai_bundle/model' - model_name = 'model.pt' - wrapper = get_nnunet_monai_predictor(model_folder, model_name) - - # Perform inference - input_data = ... - output = wrapper(input_data) + # Perform inference + input_data = ... + output = wrapper(input_data) + ``` Parameters ---------- @@ -374,7 +343,7 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): allow_tqdm=True, ) # initializes the network architecture, loads the checkpoint - wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name) + wrapper = nnUNetMONAIModelWrapper(predictor, model_folder, model_name) return wrapper @@ -427,14 +396,13 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0): torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath("models", "nnunet_checkpoint.pth")) - Path(bundle_root_folder).joinpath("models", f"fold_{fold}").mkdir(parents=True, exist_ok=True) monai_last_checkpoint = {} monai_last_checkpoint["network_weights"] = nnunet_checkpoint_final["network_weights"] - torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models", f"fold_{fold}", "model.pt")) + torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models", "model.pt")) monai_best_checkpoint = {} monai_best_checkpoint["network_weights"] = nnunet_checkpoint_best["network_weights"] - torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models", f"fold_{fold}", "best_model.pt")) + torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models", "best_model.pt")) if not os.path.exists(os.path.join(bundle_root_folder, "models", "plans.json")): shutil.copy( @@ -446,3 +414,181 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0): Path(nnunet_model_folder).joinpath("dataset.json"), Path(bundle_root_folder).joinpath("models", "dataset.json"), ) + + +def get_network_from_nnunet_plans(plans_file, dataset_file, configuration, model_ckpt=None, model_key_in_ckpt="model"): + """ + Load and initialize a neural network based on nnUNet plans and configuration. + + Parameters + ---------- + plans_file : str + Path to the JSON file containing the nnUNet plans. + dataset_file : str + Path to the JSON file containing the dataset information. + configuration : str + The configuration name to be used from the plans. + model_ckpt : str, optional + Path to the model checkpoint file. If None, the network is returned without loading weights (default is None). + model_key_in_ckpt : str, optional + The key in the checkpoint file that contains the model state dictionary (default is "model"). + + Returns + ------- + network : torch.nn.Module + The initialized neural network, with weights loaded if `model_ckpt` is provided. 
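A minimal usage sketch for `get_network_from_nnunet_plans`, assuming the exported name from `monai.apps.nnunet` as added in this patch's `__init__.py`; the plans/dataset paths are placeholders and the dummy tensor shape is illustrative, since the true input shape depends on the plans.

```python
# Hedged sketch: instantiate the network from nnUNet plans without loading weights.
import torch

from monai.apps.nnunet import get_network_from_nnunet_plans

network = get_network_from_nnunet_plans(
    plans_file="path/to/models/plans.json",      # placeholder
    dataset_file="path/to/models/dataset.json",  # placeholder
    configuration="3d_fullres",
    model_ckpt=None,  # return the architecture only; pass a checkpoint path to load weights
)
with torch.no_grad():
    dummy = torch.zeros(1, 1, 64, 64, 64)  # illustrative shape only
    # out = network(dummy)
```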
+ """ + from batchgenerators.utilities.file_and_folder_operations import load_json + from nnunetv2.utilities.get_network_from_plans import get_network_from_plans + from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels + from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + + plans = load_json(plans_file) + dataset_json = load_json(dataset_file) + + plans_manager = PlansManager(plans) + configuration_manager = plans_manager.get_configuration(configuration) + num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) + label_manager = plans_manager.get_label_manager(dataset_json) + + enable_deep_supervision = True + + network = get_network_from_plans( + configuration_manager.network_arch_class_name, + configuration_manager.network_arch_init_kwargs, + configuration_manager.network_arch_init_kwargs_req_import, + num_input_channels, + label_manager.num_segmentation_heads, + allow_init=True, + deep_supervision=enable_deep_supervision, + ) + + if model_ckpt is None: + return network + else: + state_dict = torch.load(model_ckpt) + network.load_state_dict(state_dict[model_key_in_ckpt]) + return network + + +def convert_monai_bundle_to_nnunet(nnunet_config, bundle_root_folder, fold=0): + """ + Convert a MONAI bundle to nnU-Net format. + + Parameters + ---------- + nnunet_config : dict + Configuration dictionary for nnU-Net. Expected keys are: + - "dataset_name_or_id": str, name or ID of the dataset. + - "nnunet_trainer": str, optional, name of the nnU-Net trainer (default is "nnUNetTrainer"). + - "nnunet_plans": str, optional, name of the nnU-Net plans (default is "nnUNetPlans"). + bundle_root_folder : str + Path to the root folder of the MONAI bundle. + fold : int, optional + Fold number for cross-validation (default is 0). 
+ + Returns + ------- + None + """ + from odict import odict + + nnunet_trainer = "nnUNetTrainer" + nnunet_plans = "nnUNetPlans" + + if "nnunet_trainer" in nnunet_config: + nnunet_trainer = nnunet_config["nnunet_trainer"] + + if "nnunet_plans" in nnunet_config: + nnunet_plans = nnunet_config["nnunet_plans"] + + from nnunetv2.training.logging.nnunet_logger import nnUNetLogger + from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + + def subfiles(folder, join: bool = True, prefix: str = None, suffix: str = None, sort: bool = True): + + if join: + l = os.path.join # noqa: E741 + else: + l = lambda x, y: y # noqa: E741, E731 + res = [ + l(folder, i.name) + for i in Path(folder).iterdir() + if i.is_file() + and (prefix is None or i.name.startswith(prefix)) + and (suffix is None or i.name.endswith(suffix)) + ] + if sort: + res.sort() + return res + + nnunet_model_folder = Path(os.environ["nnUNet_results"]).joinpath( + maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"]), + f"{nnunet_trainer}__{nnunet_plans}__3d_fullres", + ) + + nnunet_preprocess_model_folder = Path(os.environ["nnUNet_preprocessed"]).joinpath( + maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"]) + ) + + Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True) + + nnunet_checkpoint = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth") + latest_checkpoints = subfiles( + Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True, join=False + ) + epochs = [] + for latest_checkpoint in latest_checkpoints: + epochs.append(int(latest_checkpoint[len("checkpoint_epoch=") : -len(".pt")])) + + epochs.sort() + final_epoch = epochs[-1] + monai_last_checkpoint = torch.load(f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt") + + best_checkpoints = subfiles( + Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), + prefix="checkpoint_key_metric", + sort=True, + join=False, + ) + key_metrics = [] + for best_checkpoint in best_checkpoints: + key_metrics.append(str(best_checkpoint[len("checkpoint_key_metric=") : -len(".pt")])) + + key_metrics.sort() + best_key_metric = key_metrics[-1] + monai_best_checkpoint = torch.load( + f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt" + ) + + nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"] + + nnunet_checkpoint["network_weights"] = odict() + + for key in monai_last_checkpoint["network_weights"]: + nnunet_checkpoint["network_weights"][key] = monai_last_checkpoint["network_weights"][key] + + nnunet_checkpoint["current_epoch"] = final_epoch + nnunet_checkpoint["logging"] = nnUNetLogger().get_checkpoint() + nnunet_checkpoint["_best_ema"] = 0 + nnunet_checkpoint["grad_scaler_state"] = None + + torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth")) + + nnunet_checkpoint["network_weights"] = odict() + + nnunet_checkpoint["optimizer_state"] = monai_best_checkpoint["optimizer_state"] + + for key in monai_best_checkpoint["network_weights"]: + nnunet_checkpoint["network_weights"][key] = monai_best_checkpoint["network_weights"][key] + + torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth")) + + if not os.path.exists(os.path.join(nnunet_model_folder, "dataset.json")): + shutil.copy(f"{bundle_root_folder}/models/dataset.json", nnunet_model_folder) + if not 
os.path.exists(os.path.join(nnunet_model_folder, "plans.json")): + shutil.copy(f"{bundle_root_folder}/models/plans.json", nnunet_model_folder) + if not os.path.exists(os.path.join(nnunet_model_folder, "dataset_fingerprint.json")): + shutil.copy(f"{nnunet_preprocess_model_folder}/dataset_fingerprint.json", nnunet_model_folder) + if not os.path.exists(os.path.join(nnunet_model_folder, "nnunet_checkpoint.pth")): + shutil.copy(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", nnunet_model_folder) diff --git a/monai/nvflare/__init__.py b/monai/nvflare/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/nvflare/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/nvflare/json_generator.py b/monai/nvflare/json_generator.py new file mode 100644 index 0000000000..9326a35837 --- /dev/null +++ b/monai/nvflare/json_generator.py @@ -0,0 +1,179 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import json +import os.path + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_context import FLContext +from nvflare.widgets.widget import Widget + + +class PrepareJsonGenerator(Widget): + """ + A widget class to prepare and generate a JSON file containing data preparation configurations. + + Parameters + ---------- + results_dir : str, optional + The directory where the results will be stored (default is "prepare"). + json_file_name : str, optional + The name of the JSON file to be generated (default is "data_dict.json"). + + Methods + ------- + handle_event(event_type: str, fl_ctx: FLContext) + Handles events during the federated learning process. Clears the data preparation configuration + at the start of a run and saves the configuration to a JSON file at the end of a run. 
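For orientation, this widget is registered as a server-side component of the job; the snippet below mirrors the `components` entry that `prepare_config` (added later in this patch) serializes into `config_fed_server.conf`, expressed as the Python dict it comes from.

```python
# Server-side component entries for the "prepare" task, as generated by
# prepare_config() in monai/nvflare/nvflare_generate_job_configs.py.
server_components = [
    {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPrepareProcessor", "args": {}},
    {"id": "json_generator", "path": "monai.nvflare.json_generator.PrepareJsonGenerator", "args": {}},
]
# At END_RUN the widget writes <run_dir>/prepare/data_dict.json from the
# "client_data_dict" property stored in the FL context.
```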
+ """ + + def __init__(self, results_dir="prepare", json_file_name="data_dict.json"): + super(PrepareJsonGenerator, self).__init__() + + self._results_dir = results_dir + self._data_prepare_config = {} + self._json_file_name = json_file_name + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self._data_prepare_config.clear() + elif event_type == EventType.END_RUN: + self._data_prepare_config = fl_ctx.get_prop("client_data_dict", None) + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + data_prepare_res_dir = os.path.join(run_dir, self._results_dir) + if not os.path.exists(data_prepare_res_dir): + os.makedirs(data_prepare_res_dir) + + res_file_path = os.path.join(data_prepare_res_dir, self._json_file_name) + with open(res_file_path, "w") as f: + json.dump(self._data_prepare_config, f) + + +class nnUNetPackageReportJsonGenerator(Widget): + """ + A class to generate JSON reports for nnUNet package. + + Parameters + ---------- + results_dir : str, optional + Directory where the report will be saved (default is "package_report"). + json_file_name : str, optional + Name of the JSON file to save the report (default is "package_report.json"). + + Methods + ------- + handle_event(event_type: str, fl_ctx: FLContext) + Handles events to clear the report at the start of a run and save the report at the end of a run. + """ + + def __init__(self, results_dir="package_report", json_file_name="package_report.json"): + super(nnUNetPackageReportJsonGenerator, self).__init__() + + self._results_dir = results_dir + self._report = {} + self._json_file_name = json_file_name + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self._report.clear() + elif event_type == EventType.END_RUN: + datasets = fl_ctx.get_prop("package_report", None) + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + cross_val_res_dir = os.path.join(run_dir, self._results_dir) + if not os.path.exists(cross_val_res_dir): + os.makedirs(cross_val_res_dir) + + res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) + with open(res_file_path, "w") as f: + json.dump(datasets, f) + + +class nnUNetPlansJsonGenerator(Widget): + """ + A class to generate JSON files for nnUNet plans. + + Parameters + ---------- + results_dir : str, optional + Directory where the preprocessing results will be stored (default is "nnUNet_preprocessing"). + json_file_name : str, optional + Name of the JSON file to be generated (default is "nnUNetPlans.json"). + + Methods + ------- + handle_event(event_type: str, fl_ctx: FLContext) + Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves + the plans to a JSON file at the end of a run. 
+ """ + + def __init__(self, results_dir="nnUNet_preprocessing", json_file_name="nnUNetPlans.json"): + + super(nnUNetPlansJsonGenerator, self).__init__() + + self._results_dir = results_dir + self._nnUNetPlans = {} + self._json_file_name = json_file_name + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self._nnUNetPlans.clear() + elif event_type == EventType.END_RUN: + datasets = fl_ctx.get_prop("nnunet_plans", None) + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + cross_val_res_dir = os.path.join(run_dir, self._results_dir) + if not os.path.exists(cross_val_res_dir): + os.makedirs(cross_val_res_dir) + + res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) + with open(res_file_path, "w") as f: + json.dump(datasets, f) + + +class nnUNetValSummaryJsonGenerator(Widget): + """ + A widget to generate a JSON summary for nnUNet validation results. + + Parameters + ---------- + results_dir : str, optional + Directory where the nnUNet training results are stored (default is "nnUNet_train"). + json_file_name : str, optional + Name of the JSON file to save the validation summary (default is "val_summary.json"). + + Methods + ------- + handle_event(event_type: str, fl_ctx: FLContext) + Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves + the validation summary to a JSON file at the end of a run. + """ + + def __init__(self, results_dir="nnUNet_train", json_file_name="val_summary.json"): + + super(nnUNetValSummaryJsonGenerator, self).__init__() + + self._results_dir = results_dir + self._nnUNetPlans = {} + self._json_file_name = json_file_name + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self._nnUNetPlans.clear() + elif event_type == EventType.END_RUN: + datasets = fl_ctx.get_prop("val_summary_dict", None) + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + cross_val_res_dir = os.path.join(run_dir, self._results_dir) + if not os.path.exists(cross_val_res_dir): + os.makedirs(cross_val_res_dir) + + res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) + with open(res_file_path, "w") as f: + json.dump(datasets, f) diff --git a/monai/nvflare/nnunet_executor.py b/monai/nvflare/nnunet_executor.py new file mode 100644 index 0000000000..12f21f678c --- /dev/null +++ b/monai/nvflare/nnunet_executor.py @@ -0,0 +1,329 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import subprocess +import sys +from pathlib import Path + +from nvflare.apis.dxo import DXO, DataKind +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal + +from monai.nvflare.nvflare_nnunet import ( # check_host_config, + check_packages, + plan_and_preprocess, + prepare_bundle, + prepare_data_folder, + preprocess, + train, +) + + +class nnUNetExecutor(Executor): + """ + nnUNetExecutor is a class that handles the execution of various tasks related to nnUNet training and preprocessing + within the NVFlare framework. + + Parameters + ---------- + data_dir : str, optional + Directory where the data is stored. + modality_dict : dict, optional + Dictionary containing modality information. + prepare_task_name : str, optional + Name of the task for preparing the dataset. + check_client_packages_task_name : str, optional + Name of the task for checking client packages. + plan_and_preprocess_task_name : str, optional + Name of the task for planning and preprocessing. + preprocess_task_name : str, optional + Name of the task for preprocessing. + training_task_name : str, optional + Name of the task for training. + prepare_bundle_name : str, optional + Name of the task for preparing the bundle. + subfolder_suffix : str, optional + Suffix for subfolders. + dataset_format : str, optional + Format of the dataset, default is "subfolders". + patient_id_in_file_identifier : bool, optional + Whether patient ID is in file identifier, default is True. + nnunet_config : dict, optional + Configuration dictionary for nnUNet. + nnunet_root_folder : str, optional + Root folder for nnUNet. + client_name : str, optional + Name of the client. + tracking_uri : str, optional + URI for tracking. + mlflow_token : str, optional + Token for MLflow. + bundle_root : str, optional + Root directory for the bundle. + train_extra_configs : dict, optional + Extra configurations for training. + exclude_vars : list, optional + List of variables to exclude. + + Methods + ------- + handle_event(event_type: str, fl_ctx: FLContext) + Handles events triggered during the federated learning process. + initialize(fl_ctx: FLContext) + Initializes the executor with the given federated learning context. + execute(task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable + Executes the specified task. + prepare_dataset() -> Shareable + Prepares the dataset for training. + check_packages_installed() -> Shareable + Checks if the required packages are installed. + plan_and_preprocess() -> Shareable + Plans and preprocesses the dataset. + preprocess() -> Shareable + Preprocesses the dataset. + train() -> Shareable + Trains the model. + prepare_bundle() -> Shareable + Prepares the bundle for deployment. 
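To connect the parameter list above with how the executor is deployed: each client site receives a `config_fed_client.conf` whose single executor entry points at this class, as produced by the job-config generators later in this patch. The dict below is a hedged sketch with placeholder values only.

```python
# Hedged sketch of a per-client executor entry (all values are placeholders);
# the structure follows the client config emitted by prepare_config().
client_executor = {
    "tasks": ["prepare"],
    "executor": {
        "path": "monai.nvflare.nnunet_executor.nnUNetExecutor",
        "args": {
            "data_dir": "/data/site1",                       # placeholder
            "patient_id_in_file_identifier": True,
            "modality_dict": {"image": "_0000.nii.gz"},      # placeholder mapping
            "dataset_format": "subfolders",
            "nnunet_root_folder": "/data/site1/nnunet",      # placeholder
            "nnunet_config": {"dataset_name_or_id": "009", "experiment_name": "demo"},  # placeholders
            "client_name": "site-1",
            "tracking_uri": "http://mlflow.example:5000",    # placeholder
        },
    },
}
```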
+ """ + + def __init__( + self, + data_dir=None, + modality_dict=None, + prepare_task_name="prepare", + check_client_packages_task_name="check_client_packages", + plan_and_preprocess_task_name="plan_and_preprocess", + preprocess_task_name="preprocess", + training_task_name="train", + prepare_bundle_name="prepare_bundle", + subfolder_suffix=None, + dataset_format="subfolders", + patient_id_in_file_identifier=True, + nnunet_config=None, + nnunet_root_folder=None, + client_name=None, + tracking_uri=None, + mlflow_token=None, + bundle_root=None, + train_extra_configs=None, + exclude_vars=None, + ): + super().__init__() + + self.exclude_vars = exclude_vars + self.prepare_task_name = prepare_task_name + self.data_dir = data_dir + self.subfolder_suffix = subfolder_suffix + self.patient_id_in_file_identifier = patient_id_in_file_identifier + self.dataset_format = dataset_format + self.modality_dict = modality_dict + self.nnunet_config = nnunet_config + self.nnunet_root_folder = nnunet_root_folder + self.client_name = client_name + self.tracking_uri = tracking_uri + self.mlflow_token = mlflow_token + self.check_client_packages_task_name = check_client_packages_task_name + self.plan_and_preprocess_task_name = plan_and_preprocess_task_name + self.preprocess_task_name = preprocess_task_name + self.training_task_name = training_task_name + self.prepare_bundle_name = prepare_bundle_name + self.bundle_root = bundle_root + self.train_extra_configs = train_extra_configs + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self.initialize(fl_ctx) + + def initialize(self, fl_ctx: FLContext): + self.run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + self.root_dir = fl_ctx.get_engine().get_workspace().root_dir + self.custom_app_dir = fl_ctx.get_engine().get_workspace().get_app_custom_dir(fl_ctx.get_job_id()) + + with open("init_logfile_out.log", "w") as f_o: + with open("init_logfile_err.log", "w") as f_e: + subprocess.call( + [ + sys.executable, + "-m", + "pip", + "install", + "--user", + "-r", + str(Path(self.custom_app_dir).joinpath("requirements.txt")), + ], + stdout=f_o, + stderr=f_e, + ) + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + self.run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + self.root_dir = fl_ctx.get_engine().get_workspace().root_dir + self.custom_app_dir = fl_ctx.get_engine().get_workspace().get_app_custom_dir(fl_ctx.get_job_id()) + try: + if task_name == self.prepare_task_name: + return self.prepare_dataset() + elif task_name == self.check_client_packages_task_name: + return self.check_packages_installed() + elif task_name == self.plan_and_preprocess_task_name: + return self.plan_and_preprocess() + elif task_name == self.preprocess_task_name: + return self.preprocess() + elif task_name == self.training_task_name: + return self.train() + elif task_name == self.prepare_bundle_name: + return self.prepare_bundle() + else: + return make_reply(ReturnCode.TASK_UNKNOWN) + except Exception as e: + self.log_exception(fl_ctx, f"Exception in simple trainer: {e}.") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + def prepare_dataset(self) -> Shareable: + if "nnunet_trainer" not in self.nnunet_config: + nnunet_trainer_name = "nnUNetTrainer" + else: + nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] + + data_list = prepare_data_folder( + data_dir=self.data_dir, + nnunet_root_dir=self.nnunet_root_folder, + 
dataset_name_or_id=self.nnunet_config["dataset_name_or_id"], + modality_dict=self.modality_dict, + experiment_name=self.nnunet_config["experiment_name"], + client_name=self.client_name, + dataset_format=self.dataset_format, + patient_id_in_file_identifier=self.patient_id_in_file_identifier, + tracking_uri=self.tracking_uri, + mlflow_token=self.mlflow_token, + subfolder_suffix=self.subfolder_suffix, + trainer_class_name=nnunet_trainer_name, + ) + + outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=data_list, meta={}) + return outgoing_dxo.to_shareable() + + def check_packages_installed(self): + packages = [ + "nvflare", + # {"package_name":'pymaia-learn',"import_name":"PyMAIA"}, + "torch", + "monai", + "numpy", + "nnunetv2", + ] + package_report = check_packages(packages) + + # host_config = check_host_config() + # package_report.update(host_config) + + outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=package_report, meta={}) + + return outgoing_dxo.to_shareable() + + def plan_and_preprocess(self): + if "nnunet_plans" not in self.nnunet_config: + nnunet_plans_name = "nnUNetPlans" + else: + nnunet_plans_name = self.nnunet_config["nnunet_plans"] + + if "nnunet_trainer" not in self.nnunet_config: + nnunet_trainer_name = "nnUNetTrainer" + else: + nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] + + nnunet_plans = plan_and_preprocess( + self.nnunet_root_folder, + self.nnunet_config["dataset_name_or_id"], + self.client_name, + self.nnunet_config["experiment_name"], + self.tracking_uri, + nnunet_plans_name=nnunet_plans_name, + trainer_class_name=nnunet_trainer_name, + ) + + outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=nnunet_plans, meta={}) + return outgoing_dxo.to_shareable() + + def preprocess(self): + if "nnunet_plans" not in self.nnunet_config: + nnunet_plans_name = "nnUNetPlans" + else: + nnunet_plans_name = self.nnunet_config["nnunet_plans"] + + if "nnunet_trainer" not in self.nnunet_config: + nnunet_trainer_name = "nnUNetTrainer" + else: + nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] + + nnunet_plans = preprocess( + self.nnunet_root_folder, + self.nnunet_config["dataset_name_or_id"], + nnunet_plans_file_path=Path(self.custom_app_dir).joinpath(f"{nnunet_plans_name}.json"), + trainer_class_name=nnunet_trainer_name, + ) + outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=nnunet_plans, meta={}) + return outgoing_dxo.to_shareable() + + def train(self): + if "nnunet_trainer" not in self.nnunet_config: + nnunet_trainer_name = "nnUNetTrainer" + else: + nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] + + if "nnunet_plans" not in self.nnunet_config: + nnunet_plans_name = "nnUNetPlans" + else: + nnunet_plans_name = self.nnunet_config["nnunet_plans"] + + validation_summary = train( + self.nnunet_root_folder, + trainer_class_name=nnunet_trainer_name, + fold=0, + experiment_name=self.nnunet_config["experiment_name"], + client_name=self.client_name, + tracking_uri=self.tracking_uri, + nnunet_plans_name=nnunet_plans_name, + dataset_name_or_id=self.nnunet_config["dataset_name_or_id"], + run_with_bundle=True if self.bundle_root is not None else False, + bundle_root=self.bundle_root, + ) + outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=validation_summary, meta={}) + return outgoing_dxo.to_shareable() + + def prepare_bundle(self): + if "nnunet_trainer" not in self.nnunet_config: + nnunet_trainer_name = "nnUNetTrainer" + else: + nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] + + if "nnunet_plans" not in self.nnunet_config: + 
nnunet_plans_name = "nnUNetPlans" + else: + nnunet_plans_name = self.nnunet_config["nnunet_plans"] + + bundle_config = { + "bundle_root": self.bundle_root, + "tracking_uri": self.tracking_uri, + "mlflow_experiment_name": "FedLearning-" + self.nnunet_config["experiment_name"], + "mlflow_run_name": self.client_name, + "nnunet_plans_identifier": nnunet_plans_name, + "nnunet_trainer_class_name": nnunet_trainer_name, + } + + prepare_bundle(bundle_config, self.train_extra_configs) + + return make_reply(ReturnCode.OK) diff --git a/monai/nvflare/nvflare_generate_job_configs.py b/monai/nvflare/nvflare_generate_job_configs.py new file mode 100644 index 0000000000..b8c6e709d9 --- /dev/null +++ b/monai/nvflare/nvflare_generate_job_configs.py @@ -0,0 +1,1082 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import subprocess +from pathlib import Path + +import yaml +from pyhocon import ConfigFactory +from pyhocon.converter import HOCONConverter + + +def prepare_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Prepare configuration files for nnUNet dataset preparation using NVFlare. + + Parameters + ---------- + clients : dict + Dictionary containing client-specific configurations. Each key is a client ID and the value is a dictionary + with the following keys: + - "data_dir": str, path to the client's data directory. + - "patient_id_in_file_identifier": str, identifier for patient ID in file. + - "modality_dict": dict, dictionary mapping modalities. + - "dataset_format": str, format of the dataset. + - "nnunet_root_folder": str, path to the nnUNet root folder. + - "client_name": str, name of the client. + - "subfolder_suffix": str, optional, suffix for subfolders. + experiment : dict + Dictionary containing experiment-specific configurations with the following keys: + - "dataset_name_or_id": str, name or ID of the dataset. + - "experiment_name": str, name of the experiment. + - "tracking_uri": str, URI for tracking. + - "mlflow_token": str, optional, token for MLflow. + root_dir : str + Root directory where the configuration files will be generated. + script_dir : str + Directory containing the scripts. + nvflare_exec : str + Path to the NVFlare executable. 
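A hedged call sketch for `prepare_config`, using the client and experiment keys documented above; every path, name, and URI is a placeholder.

```python
# Sketch: generate the "prepare" job configuration for one client site.
from monai.nvflare.nvflare_generate_job_configs import prepare_config

clients = {
    "site-1": {
        "data_dir": "/data/site1",                 # placeholder
        "patient_id_in_file_identifier": True,
        "modality_dict": {"image": ".nii.gz"},     # placeholder
        "dataset_format": "subfolders",
        "nnunet_root_folder": "/data/site1/nnunet",
        "client_name": "site-1",
    }
}
experiment = {
    "dataset_name_or_id": "009",
    "experiment_name": "demo",
    "tracking_uri": "http://mlflow.example:5000",  # placeholder
}
prepare_config(clients, experiment, root_dir="/tmp/nnunet_jobs", script_dir="/tmp/scripts", nvflare_exec="nvflare")
```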
+ + Returns + ------- + None + """ + task_name = "prepare" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = {"description": "Prepare nnUNet Dataset", "client_category": "Executor", "controller_type": "server"} + + meta = { + "name": f"{task_name}_nnUNet", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPrepareProcessor", "args": {}}, + {"id": "json_generator", "path": "monai.nvflare.json_generator.PrepareJsonGenerator", "args": {}}, + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 6000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + { + "tasks": [task_name], + "executor": { + "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", + "args": { + "data_dir": clients[client_id]["data_dir"], + "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], + "modality_dict": clients[client_id]["modality_dict"], + "dataset_format": clients[client_id]["dataset_format"], + "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], + "nnunet_config": { + "dataset_name_or_id": experiment["dataset_name_or_id"], + "experiment_name": experiment["experiment_name"], + }, + "client_name": clients[client_id]["client_name"], + "tracking_uri": experiment["tracking_uri"], + }, + }, + } + ], + } + + if "subfolder_suffix" in clients[client_id]: + client["executors"][0]["executor"]["args"]["subfolder_suffix"] = clients[client_id]["subfolder_suffix"] + if "mlflow_token" in experiment: + client["executors"][0]["executor"]["args"]["mlflow_token"] = experiment["mlflow_token"] + + if "nnunet_plans" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] + + if "nnunet_trainer" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + 
f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def check_client_packages_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Generate job configuration files for checking client packages in an NVFlare experiment. + + Parameters + ---------- + clients : dict + A dictionary where keys are client IDs and values are client details. + experiment : str + The name of the experiment. + root_dir : str + The root directory where the configuration files will be generated. + script_dir : str + The directory containing the necessary scripts for NVFlare. + nvflare_exec : str + The NVFlare executable path. + + Returns + ------- + None + """ + task_name = "check_client_packages" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = { + "description": "Check Python Packages and Report", + "client_category": "Executor", + "controller_type": "server", + } + + meta = { + "name": f"{task_name}", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "nnunet_processor", + "path": "monai.nvflare.response_processor.nnUNetPackageReportProcessor", + "args": {}, + }, + { + "id": "json_generator", + "path": "monai.nvflare.json_generator.nnUNetPackageReportJsonGenerator", + "args": {}, + }, + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 6000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + {"tasks": [task_name], "executor": {"path": "monai.nvflare.nnunet_executor.nnUNetExecutor", "args": {}}} + ], + } + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + 
Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def plan_and_preprocess_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Generates and writes configuration files for the plan and preprocess task in the nnUNet experiment. + + Parameters + ---------- + clients : dict + A dictionary containing client-specific configurations. Each key is a client ID, and the value is + another dictionary with client-specific settings. + experiment : dict + A dictionary containing experiment-specific configurations such as dataset name, experiment name, + tracking URI, and optional nnUNet plans and trainer. + root_dir : str + The root directory where the configuration files will be generated. + script_dir : str + The directory containing the scripts to be used in the NVFlare job. + nvflare_exec : str + The path to the NVFlare executable. + + Returns + ------- + None + """ + task_name = "plan_and_preprocess" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = {"description": "Plan and Preprocess nnUNet", "client_category": "Executor", "controller_type": "server"} + + meta = { + "name": f"{task_name}_nnUNet", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPlanProcessor", "args": {}}, + {"id": "json_generator", "path": "monai.nvflare.json_generator.nnUNetPlansJsonGenerator", "args": {}}, + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 6000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + { + "tasks": [task_name], + "executor": { + "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", + "args": { + "data_dir": clients[client_id]["data_dir"], + "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], + "modality_dict": clients[client_id]["modality_dict"], + "dataset_format": clients[client_id]["dataset_format"], + "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], + "nnunet_config": { + "dataset_name_or_id": experiment["dataset_name_or_id"], + "experiment_name": experiment["experiment_name"], + }, + "client_name": clients[client_id]["client_name"], + "tracking_uri": 
experiment["tracking_uri"], + }, + }, + } + ], + } + + if "nnunet_plans" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] + + if "nnunet_trainer" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def preprocess_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Generate job configuration files for the preprocessing task in NVFlare. + + Parameters + ---------- + clients : dict + A dictionary containing client-specific configurations. Each key is a client ID, and the value is a dictionary + with the following keys: + - 'data_dir': str, path to the client's data directory. + - 'patient_id_in_file_identifier': str, identifier for patient ID in the file. + - 'modality_dict': dict, dictionary mapping modalities. + - 'dataset_format': str, format of the dataset. + - 'nnunet_root_folder': str, root folder for nnUNet. + - 'client_name': str, name of the client. + experiment : dict + A dictionary containing experiment-specific configurations with the following keys: + - 'dataset_name_or_id': str, name or ID of the dataset. + - 'experiment_name': str, name of the experiment. + - 'tracking_uri': str, URI for tracking. + - 'nnunet_plans' (optional): str, nnUNet plans. + - 'nnunet_trainer' (optional): str, nnUNet trainer. + root_dir : str + The root directory where the configuration files will be generated. + script_dir : str + The directory containing the scripts to be used in the job. + nvflare_exec : str + The NVFlare executable to be used for creating the job. 
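+        Typically this is the "nvflare" console command, which is the default
+        used by generate_configs.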
+ + Returns + ------- + None + """ + task_name = "preprocess" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = {"description": "Preprocess nnUNet", "client_category": "Executor", "controller_type": "server"} + + meta = { + "name": f"{task_name}_nnUNet", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPlanProcessor", "args": {}} + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 6000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + { + "tasks": [task_name], + "executor": { + "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", + "args": { + "data_dir": clients[client_id]["data_dir"], + "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], + "modality_dict": clients[client_id]["modality_dict"], + "dataset_format": clients[client_id]["dataset_format"], + "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], + "nnunet_config": { + "dataset_name_or_id": experiment["dataset_name_or_id"], + "experiment_name": experiment["experiment_name"], + }, + "client_name": clients[client_id]["client_name"], + "tracking_uri": experiment["tracking_uri"], + }, + }, + } + ], + } + + if "nnunet_plans" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] + + if "nnunet_trainer" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def train_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Generate 
training configuration files for nnUNet using NVFlare. + + Parameters + ---------- + clients : dict + Dictionary containing client-specific configurations. Each key is a client ID, and the value is a dictionary + with the following keys: + - 'data_dir': str, path to the client's data directory. + - 'patient_id_in_file_identifier': str, identifier for patient ID in file. + - 'modality_dict': dict, dictionary mapping modalities. + - 'dataset_format': str, format of the dataset. + - 'nnunet_root_folder': str, path to the nnUNet root folder. + - 'client_name': str, name of the client. + - 'bundle_root': str, optional, path to the bundle root directory. + experiment : dict + Dictionary containing experiment-specific configurations with the following keys: + - 'dataset_name_or_id': str, name or ID of the dataset. + - 'experiment_name': str, name of the experiment. + - 'tracking_uri': str, URI for tracking. + - 'nnunet_plans': str, optional, nnUNet plans. + - 'nnunet_trainer': str, optional, nnUNet trainer. + root_dir : str + Root directory where the configuration files will be generated. + script_dir : str + Directory containing the scripts to be used. + nvflare_exec : str + Path to the NVFlare executable. + + Returns + ------- + None + """ + task_name = "train" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = {"description": "Train nnUNet", "client_category": "Executor", "controller_type": "server"} + + meta = { + "name": f"{task_name}_nnUNet", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetTrainProcessor", "args": {}}, + {"id": "json_generator", "path": "monai.nvflare.json_generator.nnUNetValSummaryJsonGenerator", "args": {}}, + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 6000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + { + "tasks": [task_name], + "executor": { + "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", + "args": { + "data_dir": clients[client_id]["data_dir"], + "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], + "modality_dict": clients[client_id]["modality_dict"], + "dataset_format": 
clients[client_id]["dataset_format"], + "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], + "nnunet_config": { + "dataset_name_or_id": experiment["dataset_name_or_id"], + "experiment_name": experiment["experiment_name"], + }, + "client_name": clients[client_id]["client_name"], + "tracking_uri": experiment["tracking_uri"], + }, + }, + } + ], + } + + if "nnunet_plans" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] + + if "nnunet_trainer" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] + + if "bundle_root" in clients[client_id]: + client["executors"][0]["executor"]["args"]["bundle_root"] = clients[client_id]["bundle_root"] + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def prepare_bundle_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Prepare the configuration files for the nnUNet bundle and generate the job configurations for NVFlare. + + Parameters + ---------- + clients : dict + A dictionary containing client information. Keys are client IDs and values are dictionaries with client details. + experiment : dict + A dictionary containing experiment details such as 'experiment_name', 'tracking_uri', and optional + configurations like 'bundle_extra_config', 'nnunet_plans', and 'nnunet_trainer'. + root_dir : str + The root directory where the configuration files and job directories will be created. + script_dir : str + The directory containing the necessary scripts for NVFlare. + nvflare_exec : str + The path to the NVFlare executable. 
+ + Returns + ------- + None + """ + task_name = "prepare_bundle" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = {"description": "Prepare nnUNet Bundle", "client_category": "Executor", "controller_type": "server"} + + meta = { + "name": f"{task_name}_nnUNet", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "nnunet_processor", + "path": "monai.nvflare.response_processor.nnUNetBundlePrepareProcessor", + "args": {}, + } + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 600000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + { + "tasks": [task_name], + "executor": { + "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", + "args": { + "nnunet_config": {"experiment_name": experiment["experiment_name"]}, + "client_name": clients[client_id]["client_name"], + "tracking_uri": experiment["tracking_uri"], + }, + }, + } + ], + } + + if "bundle_extra_config" in experiment: + client["executors"][0]["executor"]["args"]["train_extra_configs"] = experiment["bundle_extra_config"] + if "nnunet_plans" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] + + if "nnunet_trainer" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] + + if "bundle_root" in clients[client_id]: + client["executors"][0]["executor"]["args"]["bundle_root"] = clients[client_id]["bundle_root"] + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def train_fl_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Generate federated learning job configurations for NVFlare. 
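+
+    The generated server configuration runs a ScatterAndGather workflow, with a
+    MonaiBundlePersistor restoring the initial weights from the bundle at
+    experiment["server_bundle_root"], followed by a CrossSiteModelEval workflow.
+    Each client wraps its MONAI bundle in a MonaiAlgo-based executor and trains
+    locally for experiment["local_epochs"] epochs per round. The experiment
+    dictionary is therefore expected to provide "num_rounds", "start_round",
+    "local_epochs" and "server_bundle_root", and every client entry must provide
+    "bundle_root" and "nnunet_root_folder".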
+ + Parameters + ---------- + clients : dict + Dictionary containing client names and their configurations. + experiment : dict + Dictionary containing experiment parameters such as number of rounds and local epochs. + root_dir : str + Root directory where the job configurations will be saved. + script_dir : str + Directory containing the necessary scripts for NVFlare. + nvflare_exec : str + Path to the NVFlare executable. + + Returns + ------- + None + """ + task_name = "train_fl_nnunet_bundle" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = { + "description": "Federated Learning with nnUNet-MONAI Bundle", + "client_category": "Executor", + "controller_type": "server", + } + + meta = { + "name": f"{task_name}", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": len(list(clients.keys())), + "mandatory_clients": list(clients.keys()), + } + + for client_name, client_config in clients.items(): + meta["deploy_map"][f"{task_name}-{client_name}"] = [client_name] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "min_clients": len(list(clients.keys())), + "num_rounds": experiment["num_rounds"], + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "persistor", + "path": "monai_nvflare.monai_bundle_persistor.MonaiBundlePersistor", + "args": { + "bundle_root": experiment["server_bundle_root"], + "config_train_filename": "configs/train.yaml", + "network_def_key": "network_def_fl", + }, + }, + {"id": "shareable_generator", "name": "FullModelShareableGenerator", "args": {}}, + { + "id": "aggregator", + "name": "InTimeAccumulateWeightedAggregator", + "args": {"expected_data_kind": "WEIGHT_DIFF"}, + }, + {"id": "model_selector", "name": "IntimeModelSelector", "args": {}}, + {"id": "model_locator", "name": "PTFileModelLocator", "args": {"pt_persistor_id": "persistor"}}, + {"id": "json_generator", "name": "ValidationJsonGenerator", "args": {}}, + ], + "workflows": [ + { + "id": "scatter_gather_ctl", + "name": "ScatterAndGather", + "args": { + "min_clients": "{min_clients}", + "num_rounds": "{num_rounds}", + "start_round": experiment["start_round"], + "wait_time_after_min_received": 10, + "aggregator_id": "aggregator", + "persistor_id": "persistor", + "shareable_generator_id": "shareable_generator", + "train_task_name": "train", + "train_timeout": 0, + }, + }, + { + "id": "cross_site_model_eval", + "name": "CrossSiteModelEval", + "args": { + "model_locator_id": "model_locator", + "submit_model_timeout": 600, + "validation_timeout": 6000, + "cleanup_models": True, + }, + }, + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_name, client_config in clients.items(): + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "executors": [ + { + "tasks": ["train", "submit_model", "validate"], + "executor": { + "id": "executor", + 
# "path": "monai_algo.ClientnnUNetAlgoExecutor", + "path": "monai_nvflare.client_algo_executor.ClientAlgoExecutor", + "args": {"client_algo_id": "client_algo", "key_metric": "Val_Dice"}, + }, + } + ], + "components": [ + { + "id": "client_algo", + # "path": "monai_algo.MonaiAlgonnUNet", + "path": "monai.fl.client.monai_algo.MonaiAlgo", + "args": { + "bundle_root": client_config["bundle_root"], + "config_train_filename": "configs/train.yaml", + "save_dict_key": "network_weights", + "local_epochs": experiment["local_epochs"], + "train_kwargs": {"nnunet_root_folder": client_config["nnunet_root_folder"]}, + }, + } + ], + } + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-{client_name}").mkdir(parents=True, exist_ok=True) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-{client_name}", "config_fed_client.conf"), "w" + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def generate_configs(client_files, experiment_file, script_dir, job_dir, nvflare_exec="nvflare"): + """ + Generate configuration files for NVFlare job. + + Parameters + ---------- + client_files : list of str + List of file paths to client configuration files. + experiment_file : str + File path to the experiment configuration file. + script_dir : str + Directory path where the scripts are located. + job_dir : str + Directory path where the job configurations will be saved. + nvflare_exec : str, optional + NVFlare executable command, by default "nvflare". + + Returns + ------- + None + """ + clients = {} + for client_id in client_files: + with open(client_id) as f: + client_name = Path(client_id).name + clients[client_name.split(".")[0]] = yaml.safe_load(f) + + with open(experiment_file) as f: + experiment = yaml.safe_load(f) + + check_client_packages_config(clients, experiment, job_dir, script_dir, nvflare_exec) + prepare_config(clients, experiment, job_dir, script_dir, nvflare_exec) + plan_and_preprocess_config(clients, experiment, job_dir, script_dir, nvflare_exec) + preprocess_config(clients, experiment, job_dir, script_dir, nvflare_exec) + train_config(clients, experiment, job_dir, script_dir, nvflare_exec) + prepare_bundle_config(clients, experiment, job_dir, script_dir, nvflare_exec) + train_fl_config(clients, experiment, job_dir, script_dir, nvflare_exec) diff --git a/monai/nvflare/nvflare_nnunet.py b/monai/nvflare/nvflare_nnunet.py new file mode 100644 index 0000000000..724c6c64df --- /dev/null +++ b/monai/nvflare/nvflare_nnunet.py @@ -0,0 +1,688 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
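+
+# Client-side helpers for the nnUNet NVFlare integration: dataset folder
+# preparation, nnUNet planning and preprocessing, training with MLflow logging,
+# package/host checks and MONAI bundle preparation. These functions are
+# designed to be called from the nnUNetExecutor defined in
+# monai/nvflare/nnunet_executor.py.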
+from __future__ import annotations + +import json +import logging +import multiprocessing +import os +import pathlib +import random +import re +import shutil +import subprocess +from importlib.metadata import version +from pathlib import Path + +import mlflow +import numpy as np +import pandas as pd +import psutil +import yaml + +import monai +from monai.apps.nnunet import nnUNetV2Runner +from monai.apps.nnunet.nnunet_bundle import convert_monai_bundle_to_nnunet +from monai.bundle import ConfigParser + + +def train( + nnunet_root_dir, + experiment_name, + client_name, + tracking_uri, + dataset_name_or_id, + trainer_class_name="nnUNetTrainer", + nnunet_plans_name="nnUNetPlans", + run_with_bundle=False, + fold=0, + bundle_root=None, + mlflow_token=None, +): + """ + + Train a nnUNet model and log metrics to MLflow. + + Parameters + ---------- + nnunet_root_dir : str + Root directory for nnUNet. + experiment_name : str + Name of the MLflow experiment. + client_name : str + Name of the client. + tracking_uri : str + URI for MLflow tracking server. + dataset_name_or_id : str + Name or ID of the dataset. + trainer_class_name : str, optional + Name of the nnUNet trainer class, by default "nnUNetTrainer". + nnunet_plans_name : str, optional + Name of the nnUNet plans, by default "nnUNetPlans". + run_with_bundle : bool, optional + Whether to run with MONAI bundle, by default False. + fold : int, optional + Fold number for cross-validation, by default 0. + bundle_root : str, optional + Root directory for MONAI bundle, by default None. + mlflow_token : str, optional + Token for MLflow authentication, by default None. + + Returns + ------- + dict + Dictionary containing validation summary metrics. + """ + data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") + runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) + + if not run_with_bundle: + runner.train_single_model(config="3d_fullres", fold=fold) + else: + os.environ["BUNDLE_ROOT"] = bundle_root + os.environ["PYTHONPATH"] = os.environ["PYTHONPATH"] + ":" + bundle_root + monai.bundle.run( + config_file=Path(bundle_root).joinpath("configs/train.yaml"), + bundle_root=bundle_root, + nnunet_trainer_class_name=trainer_class_name, + mlflow_experiment_name=experiment_name, + mlflow_run_name="run_" + client_name, + tracking_uri=tracking_uri, + fold_id=fold, + ) + nnunet_config = {"dataset_name_or_id": dataset_name_or_id, "nnunet_trainer": trainer_class_name} + convert_monai_bundle_to_nnunet(nnunet_config, bundle_root) + runner.train_single_model(config="3d_fullres", fold=fold, val="") + + if mlflow_token is not None: + os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token + if tracking_uri is not None: + mlflow.set_tracking_uri(tracking_uri) + + try: + mlflow.create_experiment(experiment_name) + except Exception as e: + print(e) + mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) + + filter = f""" + tags."client" = "{client_name}" + """ + + runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) + + validation_summary = os.path.join( + runner.nnunet_results, + runner.dataset_name, + f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", + f"fold_{fold}", + "validation", + "summary.json", + ) + + dataset_file = os.path.join( + runner.nnunet_results, + runner.dataset_name, + f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", + "dataset.json", + ) + + with open(dataset_file, 
"r") as f: + dataset_dict = json.load(f) + labels = dataset_dict["labels"] + labels = {str(v): k for k, v in labels.items()} + + with open(validation_summary, "r") as f: + validation_summary_dict = json.load(f) + + if len(runs) == 0: + with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): + for label in validation_summary_dict["mean"]: + for metric in validation_summary_dict["mean"][label]: + label_name = labels[label] + mlflow.log_metric(f"{label_name}_{metric}", float(validation_summary_dict["mean"][label][metric])) + + else: + with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): + for label in validation_summary_dict["mean"]: + for metric in validation_summary_dict["mean"][label]: + label_name = labels[label] + mlflow.log_metric(f"{label_name}_{metric}", float(validation_summary_dict["mean"][label][metric])) + + return validation_summary_dict + + +def preprocess(nnunet_root_dir, dataset_name_or_id, nnunet_plans_file_path=None, trainer_class_name="nnUNetTrainer"): + """ + Preprocess the dataset for nnUNet training. + + Parameters + ---------- + nnunet_root_dir : str + The root directory of the nnUNet project. + dataset_name_or_id : str or int + The name or ID of the dataset to preprocess. + nnunet_plans_file_path : Path, optional + The file path to the nnUNet plans file. If None, default plans will be used. Default is None. + trainer_class_name : str, optional + The name of the trainer class to use. Default is "nnUNetTrainer". + + Returns + ------- + dict + The nnUNet plans dictionary. + """ + + data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") + runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) + + nnunet_plans_name = nnunet_plans_file_path.name.split(".")[0] + from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + + dataset_name = maybe_convert_to_dataset_name(int(dataset_name_or_id)) + + Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name).mkdir(parents=True, exist_ok=True) + + shutil.copy( + Path(nnunet_root_dir).joinpath("nnUNet_raw_data_base", dataset_name, "dataset.json"), + Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name, "dataset.json"), + ) + if nnunet_plans_file_path is not None: + with open(nnunet_plans_file_path, "r") as f: + nnunet_plans = json.load(f) + nnunet_plans["dataset_name"] = dataset_name + json.dump( + nnunet_plans, + open( + Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name, f"{nnunet_plans_name}.json"), + "w", + ), + indent=4, + ) + + runner.extract_fingerprints(npfp=2, verify_dataset_integrity=True) + runner.preprocess(c=["3d_fullres"], n_proc=[2], overwrite_plans_name=nnunet_plans_name) + + return nnunet_plans + + +def plan_and_preprocess( + nnunet_root_dir, + dataset_name_or_id, + client_name, + experiment_name, + tracking_uri, + mlflow_token=None, + nnunet_plans_name="nnUNetPlans", + trainer_class_name="nnUNetTrainer", +): + """ + Plan and preprocess the dataset using nnUNetV2Runner and log the plans to MLflow. + + Parameters + ---------- + nnunet_root_dir : str + The root directory of nnUNet. + dataset_name_or_id : str or int + The name or ID of the dataset to be processed. + client_name : str + The name of the client. + experiment_name : str + The name of the MLflow experiment. + tracking_uri : str + The URI of the MLflow tracking server. + mlflow_token : str, optional + The token for MLflow authentication (default is None). 
+ nnunet_plans_name : str, optional + The name of the nnUNet plans (default is "nnUNetPlans"). + trainer_class_name : str, optional + The name of the nnUNet trainer class (default is "nnUNetTrainer"). + + Returns + ------- + dict + The nnUNet plans as a dictionary. + """ + + data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") + + runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) + + runner.plan_and_process( + npfp=2, verify_dataset_integrity=True, c=["3d_fullres"], n_proc=[2], overwrite_plans_name=nnunet_plans_name + ) + + preprocessed_folder = runner.nnunet_preprocessed + + from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + + dataset_name = maybe_convert_to_dataset_name(int(dataset_name_or_id)) + + with open(Path(preprocessed_folder).joinpath(f"{dataset_name}", nnunet_plans_name + ".json"), "r") as f: + nnunet_plans = json.load(f) + + if mlflow_token is not None: + os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token + if tracking_uri is not None: + mlflow.set_tracking_uri(tracking_uri) + + try: + mlflow.create_experiment(experiment_name) + except Exception as e: + print(e) + mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) + + filter = f""" + tags."client" = "{client_name}" + """ + + runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) + + if len(runs) == 0: + with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): + mlflow.log_dict(nnunet_plans, nnunet_plans_name + ".json") + + else: + with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): + mlflow.log_dict(nnunet_plans, nnunet_plans_name + ".json") + + return nnunet_plans + + +def prepare_data_folder( + data_dir, + nnunet_root_dir, + dataset_name_or_id, + modality_dict, + experiment_name, + client_name, + dataset_format, + tracking_uri=None, + mlflow_token=None, + subfolder_suffix=None, + patient_id_in_file_identifier=True, + trainer_class_name="nnUNetTrainer", +): + """ + Prepare the data folder for nnUNet training and log the data to MLflow. + + Parameters + ---------- + data_dir : str + Directory containing the dataset. + nnunet_root_dir : str + Root directory for nnUNet. + dataset_name_or_id : str + Name or ID of the dataset. + modality_dict : dict + Dictionary mapping modality IDs to file suffixes. + experiment_name : str + Name of the MLflow experiment. + client_name : str + Name of the client. + dataset_format : str + Format of the dataset. Supported formats are "subfolders", "decathlon", and "nnunet". + tracking_uri : str, optional + URI for MLflow tracking server. + mlflow_token : str, optional + Token for MLflow authentication. + subfolder_suffix : str, optional + Suffix for subfolder names. + patient_id_in_file_identifier : bool, optional + Whether patient ID is included in file identifier. Default is True. + trainer_class_name : str, optional + Name of the nnUNet trainer class. Default is "nnUNetTrainer". + + Returns + ------- + dict + Dictionary containing the training and testing data lists. 
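+
+    Examples
+    --------
+    A minimal, hypothetical call for a dataset stored as one folder per patient,
+    where a folder such as "case_001" contains "case_001_image.nii.gz" and
+    "case_001_label.nii.gz" (paths, names and suffixes are illustrative only).
+
+    >>> data_list = prepare_data_folder(
+    ...     data_dir="/data/site-1",
+    ...     nnunet_root_dir="/workspace/nnunet",
+    ...     dataset_name_or_id="009",
+    ...     modality_dict={"image": "_image.nii.gz", "label": "_label.nii.gz"},
+    ...     experiment_name="nnunet_fl_experiment",
+    ...     client_name="site-1",
+    ...     dataset_format="subfolders",
+    ... )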
+ """ + if dataset_format == "subfolders": + if subfolder_suffix is not None: + data_list = { + "training": [ + { + modality_id: ( + str( + pathlib.Path(f.name).joinpath( + f.name[: -len(subfolder_suffix)] + modality_dict[modality_id] + ) + ) + if patient_id_in_file_identifier + else str(pathlib.Path(f.name).joinpath(modality_dict[modality_id])) + ) + for modality_id in modality_dict + } + for f in os.scandir(data_dir) + if f.is_dir() + ], + "testing": [], + } + else: + data_list = { + "training": [ + { + modality_id: ( + str(pathlib.Path(f.name).joinpath(f.name + modality_dict[modality_id])) + if patient_id_in_file_identifier + else str(pathlib.Path(f.name).joinpath(modality_dict[modality_id])) + ) + for modality_id in modality_dict + } + for f in os.scandir(data_dir) + if f.is_dir() + ], + "testing": [], + } + elif dataset_format == "decathlon" or dataset_format == "nnunet": + cases = [] + + for f in os.scandir(Path(data_dir).joinpath("imagesTr")): + if f.is_file(): + for modality_suffix in list(modality_dict.values()): + if f.name.endswith(modality_suffix) and modality_suffix != ".nii.gz": + cases.append(f.name[: -len(modality_suffix)]) + if len(np.unique(list(modality_dict.values()))) == 1 and ".nii.gz" in list(modality_dict.values()): + cases.append(f.name[: -len(".nii.gz")]) + cases = np.unique(cases) + data_list = { + "training": [ + { + modality_id: str(Path("imagesTr").joinpath(case + modality_dict[modality_id])) + for modality_id in modality_dict + if modality_id != "label" + } + for case in cases + ], + "testing": [], + } + for idx, case in enumerate(data_list["training"]): + modality_id = list(modality_dict.keys())[0] + case_id = Path(case[modality_id]).name[: -len(modality_dict[modality_id])] + data_list["training"][idx]["label"] = str(Path("labelsTr").joinpath(case_id + modality_dict["label"])) + else: + raise ValueError("Dataset format not supported") + + for idx, train_case in enumerate(data_list["training"]): + for modality_id in modality_dict: + data_list["training"][idx][modality_id + "_is_file"] = ( + Path(data_dir).joinpath(data_list["training"][idx][modality_id]).is_file() + ) + if "image" not in data_list["training"][idx] and modality_id != "label": + data_list["training"][idx]["image"] = data_list["training"][idx][modality_id] + data_list["training"][idx]["fold"] = 0 + + random.seed(42) + random.shuffle(data_list["training"]) + + data_list["testing"] = [data_list["training"][0]] + + num_folds = 5 + fold_size = len(data_list["training"]) // num_folds + for i in range(num_folds): + for j in range(fold_size): + data_list["training"][i * fold_size + j]["fold"] = i + + datalist_file = Path(data_dir).joinpath(f"{experiment_name}_folds.json") + with open(datalist_file, "w", encoding="utf-8") as f: + json.dump(data_list, f, ensure_ascii=False, indent=4) + + os.makedirs(nnunet_root_dir, exist_ok=True) + + data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") + data_src = { + "modality": [k for k in modality_dict.keys() if k != "label"], + "dataset_name_or_id": dataset_name_or_id, + "datalist": str(datalist_file), + "dataroot": str(data_dir), + } + + ConfigParser.export_config_file(data_src, data_src_cfg) + + if dataset_format != "nnunet": + runner = nnUNetV2Runner( + input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir + ) + runner.convert_dataset() + else: + ... 
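+    # Datasets already stored in the nnUNet layout need no conversion; for the
+    # other formats the datalist created above is converted by nnUNetV2Runner.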
+ + if mlflow_token is not None: + os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token + if tracking_uri is not None: + mlflow.set_tracking_uri(tracking_uri) + + try: + mlflow.create_experiment(experiment_name) + mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) + except Exception as e: + print(e) + mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) + + filter = f""" + tags."client" = "{client_name}" + """ + + runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) + + try: + if len(runs) == 0: + with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): + mlflow.log_table(pd.DataFrame.from_records(data_list["training"]), f"{client_name}_train.json") + else: + with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): + mlflow.log_table(pd.DataFrame.from_records(data_list["training"]), f"{client_name}_train.json") + except (BrokenPipeError, ConnectionError) as e: + logging.error(f"Failed to log data to MLflow: {e}") + + return data_list + + +def check_packages(packages): + """ + Check if the specified packages are installed and return a report. + + Parameters + ---------- + packages : list + A list of package names (str) or dictionaries with keys "import_name" and "package_name". + + Returns + ------- + dict + A dictionary where the keys are package names and the values are strings indicating whether + the package is installed and its version if applicable. + + Examples + -------- + >>> check_packages(["numpy", "nonexistent_package"]) + {'numpy': 'numpy 1.21.0 is installed.', 'nonexistent_package': 'nonexistent_package is not installed.'} + >>> check_packages([{"import_name": "torch", "package_name": "torch"}]) + {'torch': 'torch 1.9.0 is installed.'} + """ + report = {} + for package in packages: + try: + if isinstance(package, dict): + __import__(package["import_name"]) + package_version = version(package["package_name"]) + name = package["package_name"] + print(f"{name} {package_version} is installed.") + report[name] = f"{name} {package_version} is installed." + else: + + __import__(package) + package_version = version(package) + print(f"{package} {package_version} is installed.") + report[package] = f"{package} {package_version} is installed." + + except ImportError: + print(f"{package} is not installed.") + report[package] = f"{package} is not installed." + + return report + + +def check_host_config(): + """ + Collects and returns the host configuration details including GPU, CPU, and memory information. 
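+
+    Besides the returned dictionary, the output of nvidia-smi is written to a
+    "nvidia-smi.log" file in the current working directory.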
+ + Returns + ------- + dict + A dictionary containing the following keys and their corresponding values: + - Config values from `monai.config.deviceconfig.get_config_values()` + - Optional config values from `monai.config.deviceconfig.get_optional_config_values()` + - GPU information including number of GPUs, CUDA version, cuDNN version, and GPU names and memory + - CPU core count + - Total memory in GB + - Memory usage percentage + """ + params_dict = {} + config_values = monai.config.deviceconfig.get_config_values() + for k in config_values: + params_dict[re.sub("[()]", " ", str(k))] = config_values[k] + optional_config_values = monai.config.deviceconfig.get_optional_config_values() + + for k in optional_config_values: + params_dict[re.sub("[()]", " ", str(k))] = optional_config_values[k] + + gpu_info = monai.config.deviceconfig.get_gpu_info() + allowed_keys = ["Num GPUs", "Has Cuda", "CUDA Version", "cuDNN enabled", "cuDNN Version"] + for i in range(gpu_info["Num GPUs"]): + allowed_keys.append(f"GPU {i} Name") + allowed_keys.append(f"GPU {i} Total memory GB ") + + for k in gpu_info: + if re.sub("[()]", " ", str(k)) in allowed_keys: + params_dict[re.sub("[()]", " ", str(k))] = str(gpu_info[k]) + + with open("nvidia-smi.log", "w") as f_e: + subprocess.run("nvidia-smi", stderr=f_e, stdout=f_e) + + params_dict["CPU_Cores"] = multiprocessing.cpu_count() + + vm = psutil.virtual_memory() + + params_dict["Total Memory"] = vm.total / (1024 * 1024 * 1024) + params_dict["Memory Used %"] = vm.percent + + return params_dict + + +def prepare_bundle(bundle_config, train_extra_configs=None): + """ + Prepare the bundle configuration for training and evaluation. + + Parameters + ---------- + bundle_config : dict + Dictionary containing the bundle configuration. Expected keys are: + - "bundle_root": str, root directory of the bundle. + - "tracking_uri": str, URI for tracking. + - "mlflow_experiment_name": str, name of the MLflow experiment. + - "mlflow_run_name": str, name of the MLflow run. + - "nnunet_plans_identifier": str, optional, identifier for nnUNet plans. + - "nnunet_trainer_class_name": str, optional, class name for nnUNet trainer. + train_extra_configs : dict, optional + Additional configurations for training. If provided, expected keys are: + - "resume_epoch": int, epoch to resume training from. + - Any other key-value pairs to be added to the training configuration. 
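+
+        When "resume_epoch" is given, the generated training config calls
+        src.trainer.reload_checkpoint from the bundle scripts to restore the
+        corresponding checkpoint before training continues.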
+ + Returns + ------- + None + """ + + with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml")) as f: + train_config = yaml.safe_load(f) + train_config["bundle_root"] = bundle_config["bundle_root"] + train_config["tracking_uri"] = bundle_config["tracking_uri"] + train_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"] + train_config["mlflow_run_name"] = bundle_config["mlflow_run_name"] + + train_config["data_src_cfg"] = "$@nnunet_root_folder+'/data_src_cfg.yaml'" + train_config["runner"] = { + "_target_": "nnUNetV2Runner", + "input_config": "$@data_src_cfg", + "trainer_class_name": "@nnunet_trainer_class_name", + "work_dir": "@nnunet_root_folder", + } + + train_config["network"] = "$@nnunet_trainer.network._orig_mod" + + train_handlers = train_config["train_handlers"]["handlers"] + + for idx, handler in enumerate(train_handlers): + if handler["_target_"] == "ValidationHandler": + train_handlers.pop(idx) + break + + train_config["train_handlers"]["handlers"] = train_handlers + + if train_extra_configs is not None and "resume_epoch" in train_extra_configs: + resume_epoch = train_extra_configs["resume_epoch"] + train_config["initialize"] = [ + "$monai.utils.set_determinism(seed=123)", + "$@runner.dataset_name_or_id", + f"$src.trainer.reload_checkpoint(@train#trainer, {resume_epoch}, @iterations, @ckpt_dir, @lr_scheduler)", + ] + else: + train_config["initialize"] = ["$monai.utils.set_determinism(seed=123)", "$@runner.dataset_name_or_id"] + + if "Val_Dice" in train_config["val_key_metric"]: + train_config["val_key_metric"] = {"Val_Dice_Local": train_config["val_key_metric"]["Val_Dice"]} + + if "Val_Dice_per_class" in train_config["val_additional_metrics"]: + train_config["val_additional_metrics"] = { + "Val_Dice_per_class_Local": train_config["val_additional_metrics"]["Val_Dice_per_class"] + } + if "nnunet_plans_identifier" in bundle_config: + train_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"] + + if "nnunet_trainer_class_name" in bundle_config: + train_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"] + + if train_extra_configs is not None: + for key in train_extra_configs: + train_config[key] = train_extra_configs[key] + + with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.json"), "w") as f: + json.dump(train_config, f) + + with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml"), "w") as f: + yaml.dump(train_config, f) + + if not Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml").exists(): + shutil.copy( + Path(bundle_config["bundle_root"]).joinpath("nnUNet", "evaluator", "evaluator.yaml"), + Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"), + ) + + with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml")) as f: + evaluate_config = yaml.safe_load(f) + evaluate_config["bundle_root"] = bundle_config["bundle_root"] + + evaluate_config["tracking_uri"] = bundle_config["tracking_uri"] + evaluate_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"] + evaluate_config["mlflow_run_name"] = bundle_config["mlflow_run_name"] + + if "nnunet_plans_identifier" in bundle_config: + evaluate_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"] + if "nnunet_trainer_class_name" in bundle_config: + evaluate_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"] + + with 
open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.json"), "w") as f: + json.dump(evaluate_config, f) + + with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"), "w") as f: + yaml.dump(evaluate_config, f) diff --git a/monai/nvflare/response_processor.py b/monai/nvflare/response_processor.py new file mode 100644 index 0000000000..a02d307220 --- /dev/null +++ b/monai/nvflare/response_processor.py @@ -0,0 +1,342 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from nvflare.apis.client import Client +from nvflare.apis.dxo import DataKind, from_shareable +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.app_common.abstract.response_processor import ResponseProcessor + + +class nnUNetPrepareProcessor(ResponseProcessor): + """ + A processor class for preparing nnUNet data in a federated learning context. + + Methods + ------- + __init__(): + Initializes the nnUNetPrepareProcessor with an empty data dictionary. + create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable: + Creates and returns a Shareable object for the given task name. + process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + Processes the response from a client. Validates the response and updates the data dictionary if valid. + final_process(fl_ctx: FLContext) -> bool: + Finalizes the processing by setting the client data dictionary in the federated learning context. + """ + + def __init__(self): + ResponseProcessor.__init__(self) + self.data_dict = {} + + def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: + return Shareable() + + def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + if not isinstance(response, Shareable): + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", + ) + return False + + try: + dxo = from_shareable(response) + + except Exception: + self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") + return False + + if dxo.data_kind != DataKind.COLLECTION: + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " + f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", + ) + return False + + data_dict = dxo.data + + if not data_dict: + self.log_error(fl_ctx, f"No dataset_dict found from client {client.name}") + return False + + self.data_dict[client.name] = data_dict + + return True + + def final_process(self, fl_ctx: FLContext) -> bool: + if not self.data_dict: + self.log_error(fl_ctx, "no data_prepare_dict from clients") + return False + + # must set sticky to True so other controllers can get it! 
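+        # "client_data_dict" maps each client name to the dataset list it prepared.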
+ fl_ctx.set_prop("client_data_dict", self.data_dict, private=True, sticky=True) + return True + + +class nnUNetPackageReportProcessor(ResponseProcessor): + """ + A processor for handling nnUNet package reports in a federated learning context. + + Attributes + ---------- + package_report : dict + A dictionary to store package reports from clients. + + Methods + ------- + create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable + Creates task data for a given task name and federated learning context. + process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool + Processes the response from a client for a given task name and federated learning context. + final_process(fl_ctx: FLContext) -> bool + Final processing step to handle the collected package reports. + """ + + def __init__(self): + ResponseProcessor.__init__(self) + self.package_report = {} + + def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: + return Shareable() + + def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + if not isinstance(response, Shareable): + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", + ) + return False + + try: + dxo = from_shareable(response) + + except Exception: + self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") + return False + + if dxo.data_kind != DataKind.COLLECTION: + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " + f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", + ) + return False + + package_report = dxo.data + + if not package_report: + self.log_error(fl_ctx, f"No package_report found from client {client.name}") + return False + + self.package_report[client.name] = package_report + return True + + def final_process(self, fl_ctx: FLContext) -> bool: + if not self.package_report: + self.log_error(fl_ctx, "no plan_dict from client") + return False + + # must set sticky to True so other controllers can get it! + fl_ctx.set_prop("package_report", self.package_report, private=True, sticky=True) + return True + + +class nnUNetPlanProcessor(ResponseProcessor): + """ + nnUNetPlanProcessor is a class that processes responses from clients in a federated learning context. + It inherits from the ResponseProcessor class and is responsible for handling and validating the + responses, extracting the necessary data, and storing it for further use. + + Attributes + ---------- + plan_dict : dict + A dictionary to store the plan data received from clients. + + Methods + ------- + create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable + Creates and returns a Shareable object for the given task name. + process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool + Processes the response from a client, validates it, and stores the plan data if valid. + final_process(fl_ctx: FLContext) -> bool + Finalizes the processing by setting the plan data in the federated learning context. 
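+
+    Notes
+    -----
+    The collected plans are stored in the FL context under the "nnunet_plans"
+    property (sticky) so that subsequent controllers can retrieve them.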
+ """ + + def __init__(self): + ResponseProcessor.__init__(self) + self.plan_dict = {} + + def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: + return Shareable() + + def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + if not isinstance(response, Shareable): + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", + ) + return False + + try: + dxo = from_shareable(response) + + except Exception: + self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") + return False + + if dxo.data_kind != DataKind.COLLECTION: + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " + f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", + ) + return False + + plan_dict = dxo.data + + if not plan_dict: + self.log_error(fl_ctx, f"No plan_dict found from client {client.name}") + return False + + self.plan_dict[client.name] = plan_dict + + return True + + def final_process(self, fl_ctx: FLContext) -> bool: + if not self.plan_dict: + self.log_error(fl_ctx, "no plan_dict from client") + return False + + # must set sticky to True so other controllers can get it! + fl_ctx.set_prop("nnunet_plans", self.plan_dict, private=True, sticky=True) + return True + + +class nnUNetTrainProcessor(ResponseProcessor): + """ + A processor class for handling training responses in the nnUNet framework. + + Attributes + ---------- + val_summary_dict : dict + A dictionary to store validation summaries from clients. + Methods + ------- + create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable + Creates task data for a given task name and FLContext. + process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool + Processes the response from a client for a given task name and FLContext. + final_process(fl_ctx: FLContext) -> bool + Final processing step to handle the collected validation summaries. + """ + + def __init__(self): + ResponseProcessor.__init__(self) + self.val_summary_dict = {} + + def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: + return Shareable() + + def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + if not isinstance(response, Shareable): + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", + ) + return False + + try: + dxo = from_shareable(response) + + except Exception: + self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") + return False + + if dxo.data_kind != DataKind.COLLECTION: + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " + f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", + ) + return False + + val_summary_dict = dxo.data + + if not val_summary_dict: + self.log_error(fl_ctx, f"No val_summary_dict found from client {client.name}") + return False + + self.val_summary_dict[client.name] = val_summary_dict + + return True + + def final_process(self, fl_ctx: FLContext) -> bool: + if not self.val_summary_dict: + self.log_error(fl_ctx, "no val_summary_dict from client") + return False + + # must set sticky to True so other controllers can get it! 
+ fl_ctx.set_prop("val_summary_dict", self.val_summary_dict, private=True, sticky=True) + return True + + +class nnUNetBundlePrepareProcessor(ResponseProcessor): + """ + A processor class for preparing nnUNet bundles in a federated learning context. + + Methods + ------- + __init__(): + Initializes the nnUNetBundlePrepareProcessor instance. + create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable: + Creates task data for a given task name and federated learning context. + process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + Processes the response from a client and validates it. + final_process(fl_ctx: FLContext) -> bool: + Final processing step after all client responses have been processed. + """ + + def __init__(self): + ResponseProcessor.__init__(self) + + def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: + return Shareable() + + def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + if not isinstance(response, Shareable): + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", + ) + return False + + try: + dxo = from_shareable(response) + + except Exception: + self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") + return False + + if dxo.data_kind != DataKind.COLLECTION: + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " + f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", + ) + return False + + return True + + def final_process(self, fl_ctx: FLContext) -> bool: + + return True diff --git a/requirements-dev.txt b/requirements-dev.txt index c9730ee651..a31b83a59e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -62,3 +62,5 @@ pyamg>=5.0.0 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 onnx_graphsurgeon polygraphy +pyhocon +odict \ No newline at end of file From ca851cdbc22bae81b32d99090e9dfa2b81668f1c Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Tue, 25 Mar 2025 17:15:59 +0000 Subject: [PATCH 34/67] Add modality_list parameter to nnUNetExecutor and related functions --- monai/nvflare/nnunet_executor.py | 5 +++++ monai/nvflare/nvflare_generate_job_configs.py | 3 +++ monai/nvflare/nvflare_nnunet.py | 8 +++++++- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/monai/nvflare/nnunet_executor.py b/monai/nvflare/nnunet_executor.py index 12f21f678c..c00d2245aa 100644 --- a/monai/nvflare/nnunet_executor.py +++ b/monai/nvflare/nnunet_executor.py @@ -77,6 +77,8 @@ class nnUNetExecutor(Executor): Extra configurations for training. exclude_vars : list, optional List of variables to exclude. + modality_list : list, optional + List of modalities. 
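+        If not provided, it is derived from modality_dict (all keys except "label").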
Methods ------- @@ -119,6 +121,7 @@ def __init__( tracking_uri=None, mlflow_token=None, bundle_root=None, + modality_list=None, train_extra_configs=None, exclude_vars=None, ): @@ -143,6 +146,7 @@ def __init__( self.prepare_bundle_name = prepare_bundle_name self.bundle_root = bundle_root self.train_extra_configs = train_extra_configs + self.modality_list = modality_list def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.START_RUN: @@ -211,6 +215,7 @@ def prepare_dataset(self) -> Shareable: mlflow_token=self.mlflow_token, subfolder_suffix=self.subfolder_suffix, trainer_class_name=nnunet_trainer_name, + modality_list=self.modality_list, ) outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=data_list, meta={}) diff --git a/monai/nvflare/nvflare_generate_job_configs.py b/monai/nvflare/nvflare_generate_job_configs.py index b8c6e709d9..5d4f2ec226 100644 --- a/monai/nvflare/nvflare_generate_job_configs.py +++ b/monai/nvflare/nvflare_generate_job_configs.py @@ -135,6 +135,9 @@ def prepare_config(clients, experiment, root_dir, script_dir, nvflare_exec): ], } + if "modality_list" in experiment["modality_list"]: + client["executors"][0]["executor"]["args"]["modality_list"] = experiment["modality_list"] + if "subfolder_suffix" in clients[client_id]: client["executors"][0]["executor"]["args"]["subfolder_suffix"] = clients[client_id]["subfolder_suffix"] if "mlflow_token" in experiment: diff --git a/monai/nvflare/nvflare_nnunet.py b/monai/nvflare/nvflare_nnunet.py index 724c6c64df..d2255d7dca 100644 --- a/monai/nvflare/nvflare_nnunet.py +++ b/monai/nvflare/nvflare_nnunet.py @@ -305,6 +305,7 @@ def prepare_data_folder( experiment_name, client_name, dataset_format, + modality_list = None, tracking_uri=None, mlflow_token=None, subfolder_suffix=None, @@ -332,6 +333,8 @@ def prepare_data_folder( Format of the dataset. Supported formats are "subfolders", "decathlon", and "nnunet". tracking_uri : str, optional URI for MLflow tracking server. + modality_list : list, optional + List of modalities. Default is None. mlflow_token : str, optional Token for MLflow authentication. 
subfolder_suffix : str, optional @@ -438,9 +441,12 @@ def prepare_data_folder( os.makedirs(nnunet_root_dir, exist_ok=True) + if modality_list is None: + modality_list = [k for k in modality_dict.keys() if k != "label"] + data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") data_src = { - "modality": [k for k in modality_dict.keys() if k != "label"], + "modality": modality_list, "dataset_name_or_id": dataset_name_or_id, "datalist": str(datalist_file), "dataroot": str(data_dir), From fee1bb06f20183c318dc310a2a95b82c3d9d4573 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Tue, 25 Mar 2025 19:55:15 +0000 Subject: [PATCH 35/67] Fix modality_list check in prepare_config and add debug print statement --- monai/nvflare/nvflare_generate_job_configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/nvflare/nvflare_generate_job_configs.py b/monai/nvflare/nvflare_generate_job_configs.py index 5d4f2ec226..130f47e309 100644 --- a/monai/nvflare/nvflare_generate_job_configs.py +++ b/monai/nvflare/nvflare_generate_job_configs.py @@ -135,7 +135,7 @@ def prepare_config(clients, experiment, root_dir, script_dir, nvflare_exec): ], } - if "modality_list" in experiment["modality_list"]: + if "modality_list" in experiment: client["executors"][0]["executor"]["args"]["modality_list"] = experiment["modality_list"] if "subfolder_suffix" in clients[client_id]: From 1972504ea62afd9060f899e4743d3b745a0f3643 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Wed, 26 Mar 2025 14:35:26 +0000 Subject: [PATCH 36/67] Rename nnUNetMONAIModelWrapper to ModelnnUNetWrapper for consistency --- monai/apps/nnunet/__init__.py | 2 +- monai/apps/nnunet/nnunet_bundle.py | 180 +++++++++++++++++------------ 2 files changed, 108 insertions(+), 74 deletions(-) diff --git a/monai/apps/nnunet/__init__.py b/monai/apps/nnunet/__init__.py index 991de8d281..cdf96e0ce2 100644 --- a/monai/apps/nnunet/__init__.py +++ b/monai/apps/nnunet/__init__.py @@ -17,7 +17,7 @@ get_network_from_nnunet_plans, get_nnunet_monai_predictor, get_nnunet_trainer, - nnUNetMONAIModelWrapper, + ModelnnUNetWrapper ) from .nnunetv2_runner import nnUNetV2Runner from .utils import NNUNETMode, analyze_data, create_new_data_copy, create_new_dataset_json diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 2b4b59a5c1..dba5a60bc1 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -16,7 +16,6 @@ import numpy as np import torch -from torch._dynamo import OptimizedModule from torch.backends import cudnn from monai.data.meta_tensor import MetaTensor @@ -25,7 +24,7 @@ join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") -__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "nnUNetMONAIModelWrapper"] +__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "convert_monai_bundle_to_nnunet","ModelnnUNetWrapper"] def get_nnunet_trainer( @@ -42,7 +41,7 @@ def get_nnunet_trainer( only_run_validation=False, disable_checkpointing=False, val_with_best=False, - device=torch.device("cuda"), + device="cuda", pretrained_model=None, ): """ @@ -50,25 +49,25 @@ def get_nnunet_trainer( The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network, optimizer, loss function, DataLoader, etc. 
- ```python - from monai.apps import SupervisedTrainer - from monai.bundle.nnunet import get_nnunet_trainer + Example:: - dataset_name_or_id = 'Task101_PROSTATE' - fold = 0 - configuration = '3d_fullres' - nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold) + from monai.apps import SupervisedTrainer + from monai.bundle.nnunet import get_nnunet_trainer - trainer = SupervisedTrainer( - device=nnunet_trainer.device, - max_epochs=nnunet_trainer.num_epochs, - train_data_loader=nnunet_trainer.dataloader_train, - network=nnunet_trainer.network, - optimizer=nnunet_trainer.optimizer, - loss_function=nnunet_trainer.loss_function, - epoch_length=nnunet_trainer.num_iterations_per_epoch, + dataset_name_or_id = 'Task009_Spleen' + fold = 0 + configuration = '3d_fullres' + nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold) - ``` + trainer = SupervisedTrainer( + device=nnunet_trainer.device, + max_epochs=nnunet_trainer.num_epochs, + train_data_loader=nnunet_trainer.dataloader_train, + network=nnunet_trainer.network, + optimizer=nnunet_trainer.optimizer, + loss_function=nnunet_trainer.loss_function, + epoch_length=nnunet_trainer.num_iterations_per_epoch, + ) Parameters ---------- @@ -98,7 +97,7 @@ def get_nnunet_trainer( Whether to disable checkpointing. Default is False. val_with_best : bool, optional Whether to validate with the best model. Default is False. - device : torch.device, optional + device : str, optional The device to be used for training. Default is 'cuda'. pretrained_model : str, optional Path to the pretrained model file. @@ -130,7 +129,7 @@ def get_nnunet_trainer( trainer_class_name, plans_identifier, use_compressed_data, - device=device, + device=torch.device(device), ) if disable_checkpointing: nnunet_trainer.disable_checkpointing = disable_checkpointing @@ -150,7 +149,7 @@ def get_nnunet_trainer( return nnunet_trainer -class nnUNetMONAIModelWrapper(torch.nn.Module): +class ModelnnUNetWrapper(torch.nn.Module): """ A wrapper class for nnUNet model integration with MONAI framework. The wrapper can be use to integrate the nnUNet Bundle within MONAI framework for inference. @@ -163,16 +162,14 @@ class nnUNetMONAIModelWrapper(torch.nn.Module): The folder path where the model and related files are stored. model_name : str, optional The name of the model file, by default "model.pt". + Attributes ---------- - predictor : object - The predictor object used for inference. + predictor : nnUNetPredictor + The nnUNet predictor object used for inference. network_weights : torch.nn.Module The network weights of the model. - Methods - ------- - forward(x) - Perform forward pass and prediction on the input data. 
+ Notes ----- This class integrates nnUNet model with MONAI framework by loading necessary configurations, @@ -184,13 +181,13 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): self.predictor = predictor model_training_output_dir = model_folder - use_folds = "0" + use_folds = ["0"] from nnunetv2.utilities.plans_handling.plans_handler import PlansManager - ## Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor - dataset_json = load_json(join(model_training_output_dir, "dataset.json")) - plans = load_json(join(model_training_output_dir, "plans.json")) + # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor + dataset_json = load_json(join(Path(model_training_output_dir).parent, "dataset.json")) + plans = load_json(join(Path(model_training_output_dir).parent, "plans.json")) plans_manager = PlansManager(plans) if isinstance(use_folds, str): @@ -198,9 +195,9 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): parameters = [] for i, f in enumerate(use_folds): - f = int(f) if f != "all" else f + f = str(f) if f != "all" else f checkpoint = torch.load( - join(model_training_output_dir, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") + join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") ) if i == 0: trainer_name = checkpoint["trainer_name"] @@ -254,32 +251,67 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): if ( ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")) - and not isinstance(predictor.network, OptimizedModule) + # and not isinstance(predictor.network, OptimizedModule) ): print("Using torch.compile") - predictor.network = torch.compile(self.network) - ## End Block + # predictor.network = torch.compile(self.network) + # End Block self.network_weights = self.predictor.network - def forward(self, x): - if type(x) is tuple: # if batch is decollated (list of tensors) - input_files = [img.meta["filename_or_obj"][0] for img in x] - else: # if batch is collated - input_files = x.meta["filename_or_obj"] - if type(input_files) is str: - input_files = [input_files] + def forward(self, x: MetaTensor) -> MetaTensor: + """ + Forward pass for the nnUNet model. + + :no-index: + + Args: + x (MetaTensor): Input tensor. If the input is a tuple, + it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch. + + Returns: + MetaTensor: The output tensor with the same metadata as the input. + + Raises: + TypeError: If the input is not a torch.Tensor or a tuple of MetaTensors. + + Notes: + - If the input is a tuple, the filenames are extracted from the metadata of each tensor in the tuple. + - If the input is a collated batch, the filenames are extracted from the metadata of the input tensor. + - The filenames are used to generate predictions using the nnUNet predictor. + - The predictions are converted to torch tensors, with added batch and channel dimensions. + - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata. 
+ """ + # if isinstance(x, tuple): # if batch is decollated (list of tensors) + # properties_or_list_of_properties = [] + # image_or_list_of_images = [] + + # for img in x: + # if isinstance(img, MetaTensor): + # properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()}) + # image_or_list_of_images.append(img.cpu().numpy()[0,:]) + # else: + # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") + + # else: # if batch is collated + if isinstance(x, MetaTensor): + if "pixdim" in x.meta: + properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()} + else: + properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]} + else: + raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") + + image_or_list_of_images = x.cpu().numpy()[0, :] # input_files should be a list of file paths, one per modality - prediction_output = self.predictor.predict_from_files( - [input_files], + prediction_output = self.predictor.predict_from_list_of_npy_arrays( + image_or_list_of_images, None, + properties_or_list_of_properties, + truncated_ofname=None, save_probabilities=False, - overwrite=True, - num_processes_preprocessing=2, + num_processes=2, num_processes_segmentation_export=2, - folder_with_segs_from_prev_stage=None, - num_parts=1, - part_id=0, ) # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax @@ -288,35 +320,36 @@ def forward(self, x): out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0))) out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension - if type(x) is tuple: - return MetaTensor(out_tensor, meta=x[0].meta) - else: - return MetaTensor(out_tensor, meta=x.meta) + # if type(x) is tuple: + # return MetaTensor(out_tensor, meta=x[0].meta) + # else: + return MetaTensor(out_tensor, meta=x.meta) def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): """ - Initializes and returns a nnUNetMONAIModelWrapper with a nnUNetPredictor. + Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`. The model folder should contain the following files, created during training: - - dataset.json: from the nnUNet results folder. - - plans.json: from the nnUNet results folder. - - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration - (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`). - - model.pt: The checkpoint file containing the model weights. + + - dataset.json: from the nnUNet results folder + - plans.json: from the nnUNet results folder + - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration + - model.pt: The checkpoint file containing the model weights. The returned wrapper object can be used for inference with MONAI framework: - ```python - from monai.bundle.nnunet import get_nnunet_monai_predictor - model_folder = 'path/to/monai_bundle/model' - model_name = 'model.pt' - wrapper = get_nnunet_monai_predictor(model_folder, model_name) + Example:: + + from monai.bundle.nnunet import get_nnunet_monai_predictor + + model_folder = 'path/to/monai_bundle/model' + model_name = 'model.pt' + wrapper = get_nnunet_monai_predictor(model_folder, model_name) - # Perform inference - input_data = ... - output = wrapper(input_data) + # Perform inference + input_data = ... 
+ output = wrapper(input_data) - ``` Parameters ---------- @@ -343,7 +376,7 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): allow_tqdm=True, ) # initializes the network architecture, loads the checkpoint - wrapper = nnUNetMONAIModelWrapper(predictor, model_folder, model_name) + wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name) return wrapper @@ -396,13 +429,14 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0): torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath("models", "nnunet_checkpoint.pth")) + Path(bundle_root_folder).joinpath("models", f"fold_{fold}").mkdir(parents=True, exist_ok=True) monai_last_checkpoint = {} monai_last_checkpoint["network_weights"] = nnunet_checkpoint_final["network_weights"] - torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models", "model.pt")) + torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models", f"fold_{fold}", "model.pt")) monai_best_checkpoint = {} monai_best_checkpoint["network_weights"] = nnunet_checkpoint_best["network_weights"] - torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models", "best_model.pt")) + torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models", f"fold_{fold}", "best_model.pt")) if not os.path.exists(os.path.join(bundle_root_folder, "models", "plans.json")): shutil.copy( From 5c633f21ffadd2fe224613c890e9f7331c6eab58 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Fri, 28 Mar 2025 08:18:52 +0000 Subject: [PATCH 37/67] --- monai/apps/nnunet/nnunet_bundle.py | 217 ++++++++++++++--------------- 1 file changed, 101 insertions(+), 116 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index dba5a60bc1..1581e325f1 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -21,6 +21,7 @@ from monai.data.meta_tensor import MetaTensor from monai.utils import optional_import +from typing import Union, Optional join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") @@ -28,22 +29,18 @@ def get_nnunet_trainer( - dataset_name_or_id, - configuration, - fold, - trainer_class_name="nnUNetTrainer", - plans_identifier="nnUNetPlans", - pretrained_weights=None, - num_gpus=1, - use_compressed_data=False, - export_validation_probabilities=False, - continue_training=False, - only_run_validation=False, - disable_checkpointing=False, - val_with_best=False, - device="cuda", - pretrained_model=None, -): + dataset_name_or_id: Union[str, int], + configuration: str, + fold: Union[int, str], + trainer_class_name: str = "nnUNetTrainer", + plans_identifier: str = "nnUNetPlans", + use_compressed_data: bool = False, + continue_training: bool = False, + only_run_validation: bool = False, + disable_checkpointing: bool = False, + device: str = "cuda", + pretrained_model: Optional[str] = None, +) -> object: """ Get the nnUNet trainer instance based on the provided configuration. The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network, @@ -81,29 +78,22 @@ def get_nnunet_trainer( The class name of the trainer to be used. Default is 'nnUNetTrainer'. plans_identifier : str, optional Identifier for the plans to be used. Default is 'nnUNetPlans'. - pretrained_weights : str, optional - Path to the pretrained weights file. 
- num_gpus : int, optional - Number of GPUs to be used. Default is 1. use_compressed_data : bool, optional Whether to use compressed data. Default is False. - export_validation_probabilities : bool, optional - Whether to export validation probabilities. Default is False. continue_training : bool, optional Whether to continue training from a checkpoint. Default is False. only_run_validation : bool, optional Whether to only run validation. Default is False. disable_checkpointing : bool, optional Whether to disable checkpointing. Default is False. - val_with_best : bool, optional - Whether to validate with the best model. Default is False. device : str, optional The device to be used for training. Default is 'cuda'. - pretrained_model : str, optional + pretrained_model : Optional[str], optional Path to the pretrained model file. + Returns ------- - nnunet_trainer + nnunet_trainer : object The nnUNet trainer instance. """ # From nnUNet/nnunetv2/run/run_training.py#run_training @@ -117,36 +107,34 @@ def get_nnunet_trainer( ) raise e - if int(num_gpus) > 1: - ... # Disable for now - else: - from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint - - nnunet_trainer = get_trainer_from_args( - str(dataset_name_or_id), - configuration, - fold, - trainer_class_name, - plans_identifier, - use_compressed_data, - device=torch.device(device), - ) - if disable_checkpointing: - nnunet_trainer.disable_checkpointing = disable_checkpointing - assert not (continue_training and only_run_validation), "Cannot set --c and --val flag at the same time. Dummy." + from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint + + nnunet_trainer = get_trainer_from_args( + str(dataset_name_or_id), + configuration, + fold, + trainer_class_name, + plans_identifier, + use_compressed_data, + device=torch.device(device), + ) + if disable_checkpointing: + nnunet_trainer.disable_checkpointing = disable_checkpointing + + assert not (continue_training and only_run_validation), "Cannot set --c and --val flag at the same time. Dummy." - maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights) - nnunet_trainer.on_train_start() # Added to Initialize Trainer - if torch.cuda.is_available(): - cudnn.deterministic = False - cudnn.benchmark = True + maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation) + nnunet_trainer.on_train_start() # Added to Initialize Trainer + if torch.cuda.is_available(): + cudnn.deterministic = False + cudnn.benchmark = True - if pretrained_model is not None: - state_dict = torch.load(pretrained_model) - if "network_weights" in state_dict: - nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"]) - return nnunet_trainer + if pretrained_model is not None: + state_dict = torch.load(pretrained_model) + if "network_weights" in state_dict: + nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"]) + return nnunet_trainer class ModelnnUNetWrapper(torch.nn.Module): @@ -176,12 +164,11 @@ class ModelnnUNetWrapper(torch.nn.Module): restoring network architecture, and setting up the predictor for inference. 
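 A minimal end-to-end sketch is shown below; the bundle path and fold_0 layout are
 assumptions based on convert_nnunet_to_monai_bundle, and the wrapper expects a collated
 MetaTensor batch carrying "pixdim" (or "affine") metadata, e.g. as produced by
 LoadImaged/EnsureChannelFirstd.

 Example::

     from monai.apps.nnunet import get_nnunet_monai_predictor
     from monai.data import DataLoader, Dataset
     from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged

     # hypothetical folder produced by convert_nnunet_to_monai_bundle
     predictor = get_nnunet_monai_predictor("path/to/monai_bundle/models/fold_0")

     transforms = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image")])
     dataset = Dataset(data=[{"image": "case_0000.nii.gz"}], transform=transforms)

     for batch in DataLoader(dataset, batch_size=1):
         prediction = predictor(batch["image"])  # MetaTensor carrying the input's metadata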
""" - def __init__(self, predictor, model_folder, model_name="model.pt"): + def __init__(self, predictor: object, model_folder: str, model_name: str = "model.pt"): super().__init__() self.predictor = predictor model_training_output_dir = model_folder - use_folds = ["0"] from nnunetv2.utilities.plans_handling.plans_handler import PlansManager @@ -190,31 +177,26 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): plans = load_json(join(Path(model_training_output_dir).parent, "plans.json")) plans_manager = PlansManager(plans) - if isinstance(use_folds, str): - use_folds = [use_folds] - parameters = [] - for i, f in enumerate(use_folds): - f = str(f) if f != "all" else f - checkpoint = torch.load( - join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") + + checkpoint = torch.load( + join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") + ) + trainer_name = checkpoint["trainer_name"] + configuration_name = checkpoint["init_args"]["configuration"] + inference_allowed_mirroring_axes = ( + checkpoint["inference_allowed_mirroring_axes"] + if "inference_allowed_mirroring_axes" in checkpoint.keys() + else None + ) + if Path(model_training_output_dir).joinpath(model_name).is_file(): + monai_checkpoint = torch.load( + join(model_training_output_dir, model_name), map_location=torch.device("cpu") ) - if i == 0: - trainer_name = checkpoint["trainer_name"] - configuration_name = checkpoint["init_args"]["configuration"] - inference_allowed_mirroring_axes = ( - checkpoint["inference_allowed_mirroring_axes"] - if "inference_allowed_mirroring_axes" in checkpoint.keys() - else None - ) - if Path(model_training_output_dir).joinpath(f"fold_{f}", model_name).is_file(): - monai_checkpoint = torch.load( - join(model_training_output_dir, model_name), map_location=torch.device("cpu") - ) - if "network_weights" in monai_checkpoint.keys(): - parameters.append(monai_checkpoint["network_weights"]) - else: - parameters.append(monai_checkpoint) + if "network_weights" in monai_checkpoint.keys(): + parameters.append(monai_checkpoint["network_weights"]) + else: + parameters.append(monai_checkpoint) configuration_manager = plans_manager.get_configuration(configuration_name) # restore network @@ -251,10 +233,8 @@ def __init__(self, predictor, model_folder, model_name="model.pt"): if ( ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")) - # and not isinstance(predictor.network, OptimizedModule) ): print("Using torch.compile") - # predictor.network = torch.compile(self.network) # End Block self.network_weights = self.predictor.network @@ -281,21 +261,12 @@ def forward(self, x: MetaTensor) -> MetaTensor: - The predictions are converted to torch tensors, with added batch and channel dimensions. - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata. 
""" - # if isinstance(x, tuple): # if batch is decollated (list of tensors) - # properties_or_list_of_properties = [] - # image_or_list_of_images = [] - - # for img in x: - # if isinstance(img, MetaTensor): - # properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()}) - # image_or_list_of_images.append(img.cpu().numpy()[0,:]) - # else: - # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.") - - # else: # if batch is collated if isinstance(x, MetaTensor): if "pixdim" in x.meta: properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()} + elif "affine" in x.meta: + spacing = [abs(x.meta['affine'][0][0].item()), abs(x.meta['affine'][1][1].item()), abs(x.meta['affine'][2][2].item())] + properties_or_list_of_properties = {"spacing": spacing} else: properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]} else: @@ -320,13 +291,10 @@ def forward(self, x: MetaTensor) -> MetaTensor: out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0))) out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension - # if type(x) is tuple: - # return MetaTensor(out_tensor, meta=x[0].meta) - # else: return MetaTensor(out_tensor, meta=x.meta) -def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): +def get_nnunet_monai_predictor(model_folder: str, model_name: str = "model.pt") -> ModelnnUNetWrapper: """ Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`. The model folder should contain the following files, created during training: @@ -360,7 +328,7 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): Returns ------- - nnUNetMONAIModelWrapper + ModelnnUNetWrapper A wrapper object that contains the nnUNetPredictor and the loaded model. """ @@ -380,7 +348,9 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"): return wrapper -def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0): +def convert_nnunet_to_monai_bundle( + nnunet_config: dict, bundle_root_folder: str, fold: int = 0 +) -> None: """ Convert nnUNet model checkpoints and configuration to MONAI bundle format. @@ -450,7 +420,13 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0): ) -def get_network_from_nnunet_plans(plans_file, dataset_file, configuration, model_ckpt=None, model_key_in_ckpt="model"): +def get_network_from_nnunet_plans( + plans_file: str, + dataset_file: str, + configuration: str, + model_ckpt: Optional[str] = None, + model_key_in_ckpt: str = "model" +) -> torch.nn.Module: """ Load and initialize a neural network based on nnUNet plans and configuration. @@ -462,7 +438,7 @@ def get_network_from_nnunet_plans(plans_file, dataset_file, configuration, model Path to the JSON file containing the dataset information. configuration : str The configuration name to be used from the plans. - model_ckpt : str, optional + model_ckpt : Optional[str], optional Path to the model checkpoint file. If None, the network is returned without loading weights (default is None). model_key_in_ckpt : str, optional The key in the checkpoint file that contains the model state dictionary (default is "model"). 
@@ -505,7 +481,11 @@ def get_network_from_nnunet_plans(plans_file, dataset_file, configuration, model return network -def convert_monai_bundle_to_nnunet(nnunet_config, bundle_root_folder, fold=0): +def convert_monai_bundle_to_nnunet( + nnunet_config: dict, + bundle_root_folder: str, + fold: int = 0 +) -> None: """ Convert a MONAI bundle to nnU-Net format. @@ -527,8 +507,8 @@ def convert_monai_bundle_to_nnunet(nnunet_config, bundle_root_folder, fold=0): """ from odict import odict - nnunet_trainer = "nnUNetTrainer" - nnunet_plans = "nnUNetPlans" + nnunet_trainer: str = "nnUNetTrainer" + nnunet_plans: str = "nnUNetPlans" if "nnunet_trainer" in nnunet_config: nnunet_trainer = nnunet_config["nnunet_trainer"] @@ -539,8 +519,13 @@ def convert_monai_bundle_to_nnunet(nnunet_config, bundle_root_folder, fold=0): from nnunetv2.training.logging.nnunet_logger import nnUNetLogger from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name - def subfiles(folder, join: bool = True, prefix: str = None, suffix: str = None, sort: bool = True): - + def subfiles( + folder: str, + join: bool = True, + prefix: Optional[str] = None, + suffix: Optional[str] = None, + sort: bool = True + ) -> list[str]: if join: l = os.path.join # noqa: E741 else: @@ -556,42 +541,42 @@ def subfiles(folder, join: bool = True, prefix: str = None, suffix: str = None, res.sort() return res - nnunet_model_folder = Path(os.environ["nnUNet_results"]).joinpath( + nnunet_model_folder: Path = Path(os.environ["nnUNet_results"]).joinpath( maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"]), f"{nnunet_trainer}__{nnunet_plans}__3d_fullres", ) - nnunet_preprocess_model_folder = Path(os.environ["nnUNet_preprocessed"]).joinpath( + nnunet_preprocess_model_folder: Path = Path(os.environ["nnUNet_preprocessed"]).joinpath( maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"]) ) Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True) - nnunet_checkpoint = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth") - latest_checkpoints = subfiles( + nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth") + latest_checkpoints: list[str] = subfiles( Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True, join=False ) - epochs = [] + epochs: list[int] = [] for latest_checkpoint in latest_checkpoints: epochs.append(int(latest_checkpoint[len("checkpoint_epoch=") : -len(".pt")])) epochs.sort() - final_epoch = epochs[-1] - monai_last_checkpoint = torch.load(f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt") + final_epoch: int = epochs[-1] + monai_last_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt") - best_checkpoints = subfiles( + best_checkpoints: list[str] = subfiles( Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_key_metric", sort=True, join=False, ) - key_metrics = [] + key_metrics: list[str] = [] for best_checkpoint in best_checkpoints: key_metrics.append(str(best_checkpoint[len("checkpoint_key_metric=") : -len(".pt")])) key_metrics.sort() - best_key_metric = key_metrics[-1] - monai_best_checkpoint = torch.load( + best_key_metric: str = key_metrics[-1] + monai_best_checkpoint: dict = torch.load( f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt" ) From 052ef648f4e22a77a358a794cd22a53aef760345 Mon Sep 17 00:00:00 2001 From: Simone 
Bendazzoli Date: Fri, 28 Mar 2025 08:19:36 +0000 Subject: [PATCH 38/67] Add original_dataset_name to nnunet_plans in preprocess function --- monai/nvflare/nvflare_nnunet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/nvflare/nvflare_nnunet.py b/monai/nvflare/nvflare_nnunet.py index d2255d7dca..72dc062ccd 100644 --- a/monai/nvflare/nvflare_nnunet.py +++ b/monai/nvflare/nvflare_nnunet.py @@ -198,6 +198,7 @@ def preprocess(nnunet_root_dir, dataset_name_or_id, nnunet_plans_file_path=None, if nnunet_plans_file_path is not None: with open(nnunet_plans_file_path, "r") as f: nnunet_plans = json.load(f) + nnunet_plans["original_dataset_name"] = nnunet_plans["dataset_name"] nnunet_plans["dataset_name"] = dataset_name json.dump( nnunet_plans, From 1c41164fe73a2612b70caaca7ce97a8ccad30f94 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 11:27:46 +0000 Subject: [PATCH 39/67] Remove unused nvflare module files and restore polygraphy in requirements-dev.txt --- monai/nvflare/__init__.py | 10 - monai/nvflare/json_generator.py | 179 --- monai/nvflare/nnunet_executor.py | 334 ----- monai/nvflare/nvflare_generate_job_configs.py | 1085 ----------------- monai/nvflare/nvflare_nnunet.py | 695 ----------- monai/nvflare/response_processor.py | 342 ------ requirements-dev.txt | 4 +- 7 files changed, 1 insertion(+), 2648 deletions(-) delete mode 100644 monai/nvflare/__init__.py delete mode 100644 monai/nvflare/json_generator.py delete mode 100644 monai/nvflare/nnunet_executor.py delete mode 100644 monai/nvflare/nvflare_generate_job_configs.py delete mode 100644 monai/nvflare/nvflare_nnunet.py delete mode 100644 monai/nvflare/response_processor.py diff --git a/monai/nvflare/__init__.py b/monai/nvflare/__init__.py deleted file mode 100644 index 1e97f89407..0000000000 --- a/monai/nvflare/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/monai/nvflare/json_generator.py b/monai/nvflare/json_generator.py deleted file mode 100644 index 9326a35837..0000000000 --- a/monai/nvflare/json_generator.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from __future__ import annotations - -import json -import os.path - -from nvflare.apis.event_type import EventType -from nvflare.apis.fl_context import FLContext -from nvflare.widgets.widget import Widget - - -class PrepareJsonGenerator(Widget): - """ - A widget class to prepare and generate a JSON file containing data preparation configurations. - - Parameters - ---------- - results_dir : str, optional - The directory where the results will be stored (default is "prepare"). - json_file_name : str, optional - The name of the JSON file to be generated (default is "data_dict.json"). - - Methods - ------- - handle_event(event_type: str, fl_ctx: FLContext) - Handles events during the federated learning process. Clears the data preparation configuration - at the start of a run and saves the configuration to a JSON file at the end of a run. - """ - - def __init__(self, results_dir="prepare", json_file_name="data_dict.json"): - super(PrepareJsonGenerator, self).__init__() - - self._results_dir = results_dir - self._data_prepare_config = {} - self._json_file_name = json_file_name - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self._data_prepare_config.clear() - elif event_type == EventType.END_RUN: - self._data_prepare_config = fl_ctx.get_prop("client_data_dict", None) - run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - data_prepare_res_dir = os.path.join(run_dir, self._results_dir) - if not os.path.exists(data_prepare_res_dir): - os.makedirs(data_prepare_res_dir) - - res_file_path = os.path.join(data_prepare_res_dir, self._json_file_name) - with open(res_file_path, "w") as f: - json.dump(self._data_prepare_config, f) - - -class nnUNetPackageReportJsonGenerator(Widget): - """ - A class to generate JSON reports for nnUNet package. - - Parameters - ---------- - results_dir : str, optional - Directory where the report will be saved (default is "package_report"). - json_file_name : str, optional - Name of the JSON file to save the report (default is "package_report.json"). - - Methods - ------- - handle_event(event_type: str, fl_ctx: FLContext) - Handles events to clear the report at the start of a run and save the report at the end of a run. - """ - - def __init__(self, results_dir="package_report", json_file_name="package_report.json"): - super(nnUNetPackageReportJsonGenerator, self).__init__() - - self._results_dir = results_dir - self._report = {} - self._json_file_name = json_file_name - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self._report.clear() - elif event_type == EventType.END_RUN: - datasets = fl_ctx.get_prop("package_report", None) - run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - cross_val_res_dir = os.path.join(run_dir, self._results_dir) - if not os.path.exists(cross_val_res_dir): - os.makedirs(cross_val_res_dir) - - res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) - with open(res_file_path, "w") as f: - json.dump(datasets, f) - - -class nnUNetPlansJsonGenerator(Widget): - """ - A class to generate JSON files for nnUNet plans. - - Parameters - ---------- - results_dir : str, optional - Directory where the preprocessing results will be stored (default is "nnUNet_preprocessing"). - json_file_name : str, optional - Name of the JSON file to be generated (default is "nnUNetPlans.json"). 
- - Methods - ------- - handle_event(event_type: str, fl_ctx: FLContext) - Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves - the plans to a JSON file at the end of a run. - """ - - def __init__(self, results_dir="nnUNet_preprocessing", json_file_name="nnUNetPlans.json"): - - super(nnUNetPlansJsonGenerator, self).__init__() - - self._results_dir = results_dir - self._nnUNetPlans = {} - self._json_file_name = json_file_name - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self._nnUNetPlans.clear() - elif event_type == EventType.END_RUN: - datasets = fl_ctx.get_prop("nnunet_plans", None) - run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - cross_val_res_dir = os.path.join(run_dir, self._results_dir) - if not os.path.exists(cross_val_res_dir): - os.makedirs(cross_val_res_dir) - - res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) - with open(res_file_path, "w") as f: - json.dump(datasets, f) - - -class nnUNetValSummaryJsonGenerator(Widget): - """ - A widget to generate a JSON summary for nnUNet validation results. - - Parameters - ---------- - results_dir : str, optional - Directory where the nnUNet training results are stored (default is "nnUNet_train"). - json_file_name : str, optional - Name of the JSON file to save the validation summary (default is "val_summary.json"). - - Methods - ------- - handle_event(event_type: str, fl_ctx: FLContext) - Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves - the validation summary to a JSON file at the end of a run. - """ - - def __init__(self, results_dir="nnUNet_train", json_file_name="val_summary.json"): - - super(nnUNetValSummaryJsonGenerator, self).__init__() - - self._results_dir = results_dir - self._nnUNetPlans = {} - self._json_file_name = json_file_name - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self._nnUNetPlans.clear() - elif event_type == EventType.END_RUN: - datasets = fl_ctx.get_prop("val_summary_dict", None) - run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - cross_val_res_dir = os.path.join(run_dir, self._results_dir) - if not os.path.exists(cross_val_res_dir): - os.makedirs(cross_val_res_dir) - - res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) - with open(res_file_path, "w") as f: - json.dump(datasets, f) diff --git a/monai/nvflare/nnunet_executor.py b/monai/nvflare/nnunet_executor.py deleted file mode 100644 index c00d2245aa..0000000000 --- a/monai/nvflare/nnunet_executor.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from __future__ import annotations - -import subprocess -import sys -from pathlib import Path - -from nvflare.apis.dxo import DXO, DataKind -from nvflare.apis.event_type import EventType -from nvflare.apis.executor import Executor -from nvflare.apis.fl_constant import ReturnCode -from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable, make_reply -from nvflare.apis.signal import Signal - -from monai.nvflare.nvflare_nnunet import ( # check_host_config, - check_packages, - plan_and_preprocess, - prepare_bundle, - prepare_data_folder, - preprocess, - train, -) - - -class nnUNetExecutor(Executor): - """ - nnUNetExecutor is a class that handles the execution of various tasks related to nnUNet training and preprocessing - within the NVFlare framework. - - Parameters - ---------- - data_dir : str, optional - Directory where the data is stored. - modality_dict : dict, optional - Dictionary containing modality information. - prepare_task_name : str, optional - Name of the task for preparing the dataset. - check_client_packages_task_name : str, optional - Name of the task for checking client packages. - plan_and_preprocess_task_name : str, optional - Name of the task for planning and preprocessing. - preprocess_task_name : str, optional - Name of the task for preprocessing. - training_task_name : str, optional - Name of the task for training. - prepare_bundle_name : str, optional - Name of the task for preparing the bundle. - subfolder_suffix : str, optional - Suffix for subfolders. - dataset_format : str, optional - Format of the dataset, default is "subfolders". - patient_id_in_file_identifier : bool, optional - Whether patient ID is in file identifier, default is True. - nnunet_config : dict, optional - Configuration dictionary for nnUNet. - nnunet_root_folder : str, optional - Root folder for nnUNet. - client_name : str, optional - Name of the client. - tracking_uri : str, optional - URI for tracking. - mlflow_token : str, optional - Token for MLflow. - bundle_root : str, optional - Root directory for the bundle. - train_extra_configs : dict, optional - Extra configurations for training. - exclude_vars : list, optional - List of variables to exclude. - modality_list : list, optional - List of modalities. - - Methods - ------- - handle_event(event_type: str, fl_ctx: FLContext) - Handles events triggered during the federated learning process. - initialize(fl_ctx: FLContext) - Initializes the executor with the given federated learning context. - execute(task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable - Executes the specified task. - prepare_dataset() -> Shareable - Prepares the dataset for training. - check_packages_installed() -> Shareable - Checks if the required packages are installed. - plan_and_preprocess() -> Shareable - Plans and preprocesses the dataset. - preprocess() -> Shareable - Preprocesses the dataset. - train() -> Shareable - Trains the model. - prepare_bundle() -> Shareable - Prepares the bundle for deployment. 
- """ - - def __init__( - self, - data_dir=None, - modality_dict=None, - prepare_task_name="prepare", - check_client_packages_task_name="check_client_packages", - plan_and_preprocess_task_name="plan_and_preprocess", - preprocess_task_name="preprocess", - training_task_name="train", - prepare_bundle_name="prepare_bundle", - subfolder_suffix=None, - dataset_format="subfolders", - patient_id_in_file_identifier=True, - nnunet_config=None, - nnunet_root_folder=None, - client_name=None, - tracking_uri=None, - mlflow_token=None, - bundle_root=None, - modality_list=None, - train_extra_configs=None, - exclude_vars=None, - ): - super().__init__() - - self.exclude_vars = exclude_vars - self.prepare_task_name = prepare_task_name - self.data_dir = data_dir - self.subfolder_suffix = subfolder_suffix - self.patient_id_in_file_identifier = patient_id_in_file_identifier - self.dataset_format = dataset_format - self.modality_dict = modality_dict - self.nnunet_config = nnunet_config - self.nnunet_root_folder = nnunet_root_folder - self.client_name = client_name - self.tracking_uri = tracking_uri - self.mlflow_token = mlflow_token - self.check_client_packages_task_name = check_client_packages_task_name - self.plan_and_preprocess_task_name = plan_and_preprocess_task_name - self.preprocess_task_name = preprocess_task_name - self.training_task_name = training_task_name - self.prepare_bundle_name = prepare_bundle_name - self.bundle_root = bundle_root - self.train_extra_configs = train_extra_configs - self.modality_list = modality_list - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self.initialize(fl_ctx) - - def initialize(self, fl_ctx: FLContext): - self.run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - self.root_dir = fl_ctx.get_engine().get_workspace().root_dir - self.custom_app_dir = fl_ctx.get_engine().get_workspace().get_app_custom_dir(fl_ctx.get_job_id()) - - with open("init_logfile_out.log", "w") as f_o: - with open("init_logfile_err.log", "w") as f_e: - subprocess.call( - [ - sys.executable, - "-m", - "pip", - "install", - "--user", - "-r", - str(Path(self.custom_app_dir).joinpath("requirements.txt")), - ], - stdout=f_o, - stderr=f_e, - ) - - def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: - self.run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - self.root_dir = fl_ctx.get_engine().get_workspace().root_dir - self.custom_app_dir = fl_ctx.get_engine().get_workspace().get_app_custom_dir(fl_ctx.get_job_id()) - try: - if task_name == self.prepare_task_name: - return self.prepare_dataset() - elif task_name == self.check_client_packages_task_name: - return self.check_packages_installed() - elif task_name == self.plan_and_preprocess_task_name: - return self.plan_and_preprocess() - elif task_name == self.preprocess_task_name: - return self.preprocess() - elif task_name == self.training_task_name: - return self.train() - elif task_name == self.prepare_bundle_name: - return self.prepare_bundle() - else: - return make_reply(ReturnCode.TASK_UNKNOWN) - except Exception as e: - self.log_exception(fl_ctx, f"Exception in simple trainer: {e}.") - return make_reply(ReturnCode.EXECUTION_EXCEPTION) - - def prepare_dataset(self) -> Shareable: - if "nnunet_trainer" not in self.nnunet_config: - nnunet_trainer_name = "nnUNetTrainer" - else: - nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] - - data_list = prepare_data_folder( - 
data_dir=self.data_dir, - nnunet_root_dir=self.nnunet_root_folder, - dataset_name_or_id=self.nnunet_config["dataset_name_or_id"], - modality_dict=self.modality_dict, - experiment_name=self.nnunet_config["experiment_name"], - client_name=self.client_name, - dataset_format=self.dataset_format, - patient_id_in_file_identifier=self.patient_id_in_file_identifier, - tracking_uri=self.tracking_uri, - mlflow_token=self.mlflow_token, - subfolder_suffix=self.subfolder_suffix, - trainer_class_name=nnunet_trainer_name, - modality_list=self.modality_list, - ) - - outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=data_list, meta={}) - return outgoing_dxo.to_shareable() - - def check_packages_installed(self): - packages = [ - "nvflare", - # {"package_name":'pymaia-learn',"import_name":"PyMAIA"}, - "torch", - "monai", - "numpy", - "nnunetv2", - ] - package_report = check_packages(packages) - - # host_config = check_host_config() - # package_report.update(host_config) - - outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=package_report, meta={}) - - return outgoing_dxo.to_shareable() - - def plan_and_preprocess(self): - if "nnunet_plans" not in self.nnunet_config: - nnunet_plans_name = "nnUNetPlans" - else: - nnunet_plans_name = self.nnunet_config["nnunet_plans"] - - if "nnunet_trainer" not in self.nnunet_config: - nnunet_trainer_name = "nnUNetTrainer" - else: - nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] - - nnunet_plans = plan_and_preprocess( - self.nnunet_root_folder, - self.nnunet_config["dataset_name_or_id"], - self.client_name, - self.nnunet_config["experiment_name"], - self.tracking_uri, - nnunet_plans_name=nnunet_plans_name, - trainer_class_name=nnunet_trainer_name, - ) - - outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=nnunet_plans, meta={}) - return outgoing_dxo.to_shareable() - - def preprocess(self): - if "nnunet_plans" not in self.nnunet_config: - nnunet_plans_name = "nnUNetPlans" - else: - nnunet_plans_name = self.nnunet_config["nnunet_plans"] - - if "nnunet_trainer" not in self.nnunet_config: - nnunet_trainer_name = "nnUNetTrainer" - else: - nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] - - nnunet_plans = preprocess( - self.nnunet_root_folder, - self.nnunet_config["dataset_name_or_id"], - nnunet_plans_file_path=Path(self.custom_app_dir).joinpath(f"{nnunet_plans_name}.json"), - trainer_class_name=nnunet_trainer_name, - ) - outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=nnunet_plans, meta={}) - return outgoing_dxo.to_shareable() - - def train(self): - if "nnunet_trainer" not in self.nnunet_config: - nnunet_trainer_name = "nnUNetTrainer" - else: - nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] - - if "nnunet_plans" not in self.nnunet_config: - nnunet_plans_name = "nnUNetPlans" - else: - nnunet_plans_name = self.nnunet_config["nnunet_plans"] - - validation_summary = train( - self.nnunet_root_folder, - trainer_class_name=nnunet_trainer_name, - fold=0, - experiment_name=self.nnunet_config["experiment_name"], - client_name=self.client_name, - tracking_uri=self.tracking_uri, - nnunet_plans_name=nnunet_plans_name, - dataset_name_or_id=self.nnunet_config["dataset_name_or_id"], - run_with_bundle=True if self.bundle_root is not None else False, - bundle_root=self.bundle_root, - ) - outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=validation_summary, meta={}) - return outgoing_dxo.to_shareable() - - def prepare_bundle(self): - if "nnunet_trainer" not in self.nnunet_config: - nnunet_trainer_name = "nnUNetTrainer" - else: - 
nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] - - if "nnunet_plans" not in self.nnunet_config: - nnunet_plans_name = "nnUNetPlans" - else: - nnunet_plans_name = self.nnunet_config["nnunet_plans"] - - bundle_config = { - "bundle_root": self.bundle_root, - "tracking_uri": self.tracking_uri, - "mlflow_experiment_name": "FedLearning-" + self.nnunet_config["experiment_name"], - "mlflow_run_name": self.client_name, - "nnunet_plans_identifier": nnunet_plans_name, - "nnunet_trainer_class_name": nnunet_trainer_name, - } - - prepare_bundle(bundle_config, self.train_extra_configs) - - return make_reply(ReturnCode.OK) diff --git a/monai/nvflare/nvflare_generate_job_configs.py b/monai/nvflare/nvflare_generate_job_configs.py deleted file mode 100644 index 130f47e309..0000000000 --- a/monai/nvflare/nvflare_generate_job_configs.py +++ /dev/null @@ -1,1085 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import subprocess -from pathlib import Path - -import yaml -from pyhocon import ConfigFactory -from pyhocon.converter import HOCONConverter - - -def prepare_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Prepare configuration files for nnUNet dataset preparation using NVFlare. - - Parameters - ---------- - clients : dict - Dictionary containing client-specific configurations. Each key is a client ID and the value is a dictionary - with the following keys: - - "data_dir": str, path to the client's data directory. - - "patient_id_in_file_identifier": str, identifier for patient ID in file. - - "modality_dict": dict, dictionary mapping modalities. - - "dataset_format": str, format of the dataset. - - "nnunet_root_folder": str, path to the nnUNet root folder. - - "client_name": str, name of the client. - - "subfolder_suffix": str, optional, suffix for subfolders. - experiment : dict - Dictionary containing experiment-specific configurations with the following keys: - - "dataset_name_or_id": str, name or ID of the dataset. - - "experiment_name": str, name of the experiment. - - "tracking_uri": str, URI for tracking. - - "mlflow_token": str, optional, token for MLflow. - root_dir : str - Root directory where the configuration files will be generated. - script_dir : str - Directory containing the scripts. - nvflare_exec : str - Path to the NVFlare executable. 
- - Returns - ------- - None - """ - task_name = "prepare" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = {"description": "Prepare nnUNet Dataset", "client_category": "Executor", "controller_type": "server"} - - meta = { - "name": f"{task_name}_nnUNet", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPrepareProcessor", "args": {}}, - {"id": "json_generator", "path": "monai.nvflare.json_generator.PrepareJsonGenerator", "args": {}}, - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 6000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "components": [], - "executors": [ - { - "tasks": [task_name], - "executor": { - "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", - "args": { - "data_dir": clients[client_id]["data_dir"], - "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], - "modality_dict": clients[client_id]["modality_dict"], - "dataset_format": clients[client_id]["dataset_format"], - "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], - "nnunet_config": { - "dataset_name_or_id": experiment["dataset_name_or_id"], - "experiment_name": experiment["experiment_name"], - }, - "client_name": clients[client_id]["client_name"], - "tracking_uri": experiment["tracking_uri"], - }, - }, - } - ], - } - - if "modality_list" in experiment: - client["executors"][0]["executor"]["args"]["modality_list"] = experiment["modality_list"] - - if "subfolder_suffix" in clients[client_id]: - client["executors"][0]["executor"]["args"]["subfolder_suffix"] = clients[client_id]["subfolder_suffix"] - if "mlflow_token" in experiment: - client["executors"][0]["executor"]["args"]["mlflow_token"] = experiment["mlflow_token"] - - if "nnunet_plans" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] - - if "nnunet_trainer" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - 
Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def check_client_packages_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Generate job configuration files for checking client packages in an NVFlare experiment. - - Parameters - ---------- - clients : dict - A dictionary where keys are client IDs and values are client details. - experiment : str - The name of the experiment. - root_dir : str - The root directory where the configuration files will be generated. - script_dir : str - The directory containing the necessary scripts for NVFlare. - nvflare_exec : str - The NVFlare executable path. - - Returns - ------- - None - """ - task_name = "check_client_packages" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = { - "description": "Check Python Packages and Report", - "client_category": "Executor", - "controller_type": "server", - } - - meta = { - "name": f"{task_name}", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - { - "id": "nnunet_processor", - "path": "monai.nvflare.response_processor.nnUNetPackageReportProcessor", - "args": {}, - }, - { - "id": "json_generator", - "path": "monai.nvflare.json_generator.nnUNetPackageReportJsonGenerator", - "args": {}, - }, - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 6000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "components": [], - "executors": [ - {"tasks": [task_name], "executor": {"path": "monai.nvflare.nnunet_executor.nnUNetExecutor", "args": {}}} - ], - } - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) 
- f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def plan_and_preprocess_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Generates and writes configuration files for the plan and preprocess task in the nnUNet experiment. - - Parameters - ---------- - clients : dict - A dictionary containing client-specific configurations. Each key is a client ID, and the value is - another dictionary with client-specific settings. - experiment : dict - A dictionary containing experiment-specific configurations such as dataset name, experiment name, - tracking URI, and optional nnUNet plans and trainer. - root_dir : str - The root directory where the configuration files will be generated. - script_dir : str - The directory containing the scripts to be used in the NVFlare job. - nvflare_exec : str - The path to the NVFlare executable. - - Returns - ------- - None - """ - task_name = "plan_and_preprocess" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = {"description": "Plan and Preprocess nnUNet", "client_category": "Executor", "controller_type": "server"} - - meta = { - "name": f"{task_name}_nnUNet", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPlanProcessor", "args": {}}, - {"id": "json_generator", "path": "monai.nvflare.json_generator.nnUNetPlansJsonGenerator", "args": {}}, - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 6000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "components": [], - "executors": [ - { - "tasks": [task_name], - "executor": { - "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", - "args": { - "data_dir": clients[client_id]["data_dir"], - "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], - "modality_dict": clients[client_id]["modality_dict"], - "dataset_format": clients[client_id]["dataset_format"], - "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], - "nnunet_config": { - "dataset_name_or_id": 
experiment["dataset_name_or_id"], - "experiment_name": experiment["experiment_name"], - }, - "client_name": clients[client_id]["client_name"], - "tracking_uri": experiment["tracking_uri"], - }, - }, - } - ], - } - - if "nnunet_plans" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] - - if "nnunet_trainer" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def preprocess_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Generate job configuration files for the preprocessing task in NVFlare. - - Parameters - ---------- - clients : dict - A dictionary containing client-specific configurations. Each key is a client ID, and the value is a dictionary - with the following keys: - - 'data_dir': str, path to the client's data directory. - - 'patient_id_in_file_identifier': str, identifier for patient ID in the file. - - 'modality_dict': dict, dictionary mapping modalities. - - 'dataset_format': str, format of the dataset. - - 'nnunet_root_folder': str, root folder for nnUNet. - - 'client_name': str, name of the client. - experiment : dict - A dictionary containing experiment-specific configurations with the following keys: - - 'dataset_name_or_id': str, name or ID of the dataset. - - 'experiment_name': str, name of the experiment. - - 'tracking_uri': str, URI for tracking. - - 'nnunet_plans' (optional): str, nnUNet plans. - - 'nnunet_trainer' (optional): str, nnUNet trainer. - root_dir : str - The root directory where the configuration files will be generated. - script_dir : str - The directory containing the scripts to be used in the job. - nvflare_exec : str - The NVFlare executable to be used for creating the job. 
- - Returns - ------- - None - """ - task_name = "preprocess" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = {"description": "Preprocess nnUNet", "client_category": "Executor", "controller_type": "server"} - - meta = { - "name": f"{task_name}_nnUNet", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPlanProcessor", "args": {}} - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 6000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "components": [], - "executors": [ - { - "tasks": [task_name], - "executor": { - "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", - "args": { - "data_dir": clients[client_id]["data_dir"], - "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], - "modality_dict": clients[client_id]["modality_dict"], - "dataset_format": clients[client_id]["dataset_format"], - "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], - "nnunet_config": { - "dataset_name_or_id": experiment["dataset_name_or_id"], - "experiment_name": experiment["experiment_name"], - }, - "client_name": clients[client_id]["client_name"], - "tracking_uri": experiment["tracking_uri"], - }, - }, - } - ], - } - - if "nnunet_plans" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] - - if "nnunet_trainer" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def train_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Generate 
training configuration files for nnUNet using NVFlare. - - Parameters - ---------- - clients : dict - Dictionary containing client-specific configurations. Each key is a client ID, and the value is a dictionary - with the following keys: - - 'data_dir': str, path to the client's data directory. - - 'patient_id_in_file_identifier': str, identifier for patient ID in file. - - 'modality_dict': dict, dictionary mapping modalities. - - 'dataset_format': str, format of the dataset. - - 'nnunet_root_folder': str, path to the nnUNet root folder. - - 'client_name': str, name of the client. - - 'bundle_root': str, optional, path to the bundle root directory. - experiment : dict - Dictionary containing experiment-specific configurations with the following keys: - - 'dataset_name_or_id': str, name or ID of the dataset. - - 'experiment_name': str, name of the experiment. - - 'tracking_uri': str, URI for tracking. - - 'nnunet_plans': str, optional, nnUNet plans. - - 'nnunet_trainer': str, optional, nnUNet trainer. - root_dir : str - Root directory where the configuration files will be generated. - script_dir : str - Directory containing the scripts to be used. - nvflare_exec : str - Path to the NVFlare executable. - - Returns - ------- - None - """ - task_name = "train" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = {"description": "Train nnUNet", "client_category": "Executor", "controller_type": "server"} - - meta = { - "name": f"{task_name}_nnUNet", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetTrainProcessor", "args": {}}, - {"id": "json_generator", "path": "monai.nvflare.json_generator.nnUNetValSummaryJsonGenerator", "args": {}}, - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 6000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "components": [], - "executors": [ - { - "tasks": [task_name], - "executor": { - "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", - "args": { - "data_dir": clients[client_id]["data_dir"], - "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], - "modality_dict": clients[client_id]["modality_dict"], - "dataset_format": 
clients[client_id]["dataset_format"], - "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], - "nnunet_config": { - "dataset_name_or_id": experiment["dataset_name_or_id"], - "experiment_name": experiment["experiment_name"], - }, - "client_name": clients[client_id]["client_name"], - "tracking_uri": experiment["tracking_uri"], - }, - }, - } - ], - } - - if "nnunet_plans" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] - - if "nnunet_trainer" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] - - if "bundle_root" in clients[client_id]: - client["executors"][0]["executor"]["args"]["bundle_root"] = clients[client_id]["bundle_root"] - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def prepare_bundle_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Prepare the configuration files for the nnUNet bundle and generate the job configurations for NVFlare. - - Parameters - ---------- - clients : dict - A dictionary containing client information. Keys are client IDs and values are dictionaries with client details. - experiment : dict - A dictionary containing experiment details such as 'experiment_name', 'tracking_uri', and optional - configurations like 'bundle_extra_config', 'nnunet_plans', and 'nnunet_trainer'. - root_dir : str - The root directory where the configuration files and job directories will be created. - script_dir : str - The directory containing the necessary scripts for NVFlare. - nvflare_exec : str - The path to the NVFlare executable. 
- - Returns - ------- - None - """ - task_name = "prepare_bundle" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = {"description": "Prepare nnUNet Bundle", "client_category": "Executor", "controller_type": "server"} - - meta = { - "name": f"{task_name}_nnUNet", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - { - "id": "nnunet_processor", - "path": "monai.nvflare.response_processor.nnUNetBundlePrepareProcessor", - "args": {}, - } - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 600000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "components": [], - "executors": [ - { - "tasks": [task_name], - "executor": { - "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", - "args": { - "nnunet_config": {"experiment_name": experiment["experiment_name"]}, - "client_name": clients[client_id]["client_name"], - "tracking_uri": experiment["tracking_uri"], - }, - }, - } - ], - } - - if "bundle_extra_config" in experiment: - client["executors"][0]["executor"]["args"]["train_extra_configs"] = experiment["bundle_extra_config"] - if "nnunet_plans" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] - - if "nnunet_trainer" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] - - if "bundle_root" in clients[client_id]: - client["executors"][0]["executor"]["args"]["bundle_root"] = clients[client_id]["bundle_root"] - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def train_fl_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Generate federated learning job configurations for NVFlare. 
- - Parameters - ---------- - clients : dict - Dictionary containing client names and their configurations. - experiment : dict - Dictionary containing experiment parameters such as number of rounds and local epochs. - root_dir : str - Root directory where the job configurations will be saved. - script_dir : str - Directory containing the necessary scripts for NVFlare. - nvflare_exec : str - Path to the NVFlare executable. - - Returns - ------- - None - """ - task_name = "train_fl_nnunet_bundle" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = { - "description": "Federated Learning with nnUNet-MONAI Bundle", - "client_category": "Executor", - "controller_type": "server", - } - - meta = { - "name": f"{task_name}", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": len(list(clients.keys())), - "mandatory_clients": list(clients.keys()), - } - - for client_name, client_config in clients.items(): - meta["deploy_map"][f"{task_name}-{client_name}"] = [client_name] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "min_clients": len(list(clients.keys())), - "num_rounds": experiment["num_rounds"], - "task_data_filters": [], - "task_result_filters": [], - "components": [ - { - "id": "persistor", - "path": "monai_nvflare.monai_bundle_persistor.MonaiBundlePersistor", - "args": { - "bundle_root": experiment["server_bundle_root"], - "config_train_filename": "configs/train.yaml", - "network_def_key": "network_def_fl", - }, - }, - {"id": "shareable_generator", "name": "FullModelShareableGenerator", "args": {}}, - { - "id": "aggregator", - "name": "InTimeAccumulateWeightedAggregator", - "args": {"expected_data_kind": "WEIGHT_DIFF"}, - }, - {"id": "model_selector", "name": "IntimeModelSelector", "args": {}}, - {"id": "model_locator", "name": "PTFileModelLocator", "args": {"pt_persistor_id": "persistor"}}, - {"id": "json_generator", "name": "ValidationJsonGenerator", "args": {}}, - ], - "workflows": [ - { - "id": "scatter_gather_ctl", - "name": "ScatterAndGather", - "args": { - "min_clients": "{min_clients}", - "num_rounds": "{num_rounds}", - "start_round": experiment["start_round"], - "wait_time_after_min_received": 10, - "aggregator_id": "aggregator", - "persistor_id": "persistor", - "shareable_generator_id": "shareable_generator", - "train_task_name": "train", - "train_timeout": 0, - }, - }, - { - "id": "cross_site_model_eval", - "name": "CrossSiteModelEval", - "args": { - "model_locator_id": "model_locator", - "submit_model_timeout": 600, - "validation_timeout": 6000, - "cleanup_models": True, - }, - }, - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_name, client_config in clients.items(): - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "executors": [ - { - "tasks": ["train", "submit_model", "validate"], - "executor": { - "id": "executor", - 
# "path": "monai_algo.ClientnnUNetAlgoExecutor", - "path": "monai_nvflare.client_algo_executor.ClientAlgoExecutor", - "args": {"client_algo_id": "client_algo", "key_metric": "Val_Dice"}, - }, - } - ], - "components": [ - { - "id": "client_algo", - # "path": "monai_algo.MonaiAlgonnUNet", - "path": "monai.fl.client.monai_algo.MonaiAlgo", - "args": { - "bundle_root": client_config["bundle_root"], - "config_train_filename": "configs/train.yaml", - "save_dict_key": "network_weights", - "local_epochs": experiment["local_epochs"], - "train_kwargs": {"nnunet_root_folder": client_config["nnunet_root_folder"]}, - }, - } - ], - } - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-{client_name}").mkdir(parents=True, exist_ok=True) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-{client_name}", "config_fed_client.conf"), "w" - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def generate_configs(client_files, experiment_file, script_dir, job_dir, nvflare_exec="nvflare"): - """ - Generate configuration files for NVFlare job. - - Parameters - ---------- - client_files : list of str - List of file paths to client configuration files. - experiment_file : str - File path to the experiment configuration file. - script_dir : str - Directory path where the scripts are located. - job_dir : str - Directory path where the job configurations will be saved. - nvflare_exec : str, optional - NVFlare executable command, by default "nvflare". - - Returns - ------- - None - """ - clients = {} - for client_id in client_files: - with open(client_id) as f: - client_name = Path(client_id).name - clients[client_name.split(".")[0]] = yaml.safe_load(f) - - with open(experiment_file) as f: - experiment = yaml.safe_load(f) - - check_client_packages_config(clients, experiment, job_dir, script_dir, nvflare_exec) - prepare_config(clients, experiment, job_dir, script_dir, nvflare_exec) - plan_and_preprocess_config(clients, experiment, job_dir, script_dir, nvflare_exec) - preprocess_config(clients, experiment, job_dir, script_dir, nvflare_exec) - train_config(clients, experiment, job_dir, script_dir, nvflare_exec) - prepare_bundle_config(clients, experiment, job_dir, script_dir, nvflare_exec) - train_fl_config(clients, experiment, job_dir, script_dir, nvflare_exec) diff --git a/monai/nvflare/nvflare_nnunet.py b/monai/nvflare/nvflare_nnunet.py deleted file mode 100644 index 72dc062ccd..0000000000 --- a/monai/nvflare/nvflare_nnunet.py +++ /dev/null @@ -1,695 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from __future__ import annotations - -import json -import logging -import multiprocessing -import os -import pathlib -import random -import re -import shutil -import subprocess -from importlib.metadata import version -from pathlib import Path - -import mlflow -import numpy as np -import pandas as pd -import psutil -import yaml - -import monai -from monai.apps.nnunet import nnUNetV2Runner -from monai.apps.nnunet.nnunet_bundle import convert_monai_bundle_to_nnunet -from monai.bundle import ConfigParser - - -def train( - nnunet_root_dir, - experiment_name, - client_name, - tracking_uri, - dataset_name_or_id, - trainer_class_name="nnUNetTrainer", - nnunet_plans_name="nnUNetPlans", - run_with_bundle=False, - fold=0, - bundle_root=None, - mlflow_token=None, -): - """ - - Train a nnUNet model and log metrics to MLflow. - - Parameters - ---------- - nnunet_root_dir : str - Root directory for nnUNet. - experiment_name : str - Name of the MLflow experiment. - client_name : str - Name of the client. - tracking_uri : str - URI for MLflow tracking server. - dataset_name_or_id : str - Name or ID of the dataset. - trainer_class_name : str, optional - Name of the nnUNet trainer class, by default "nnUNetTrainer". - nnunet_plans_name : str, optional - Name of the nnUNet plans, by default "nnUNetPlans". - run_with_bundle : bool, optional - Whether to run with MONAI bundle, by default False. - fold : int, optional - Fold number for cross-validation, by default 0. - bundle_root : str, optional - Root directory for MONAI bundle, by default None. - mlflow_token : str, optional - Token for MLflow authentication, by default None. - - Returns - ------- - dict - Dictionary containing validation summary metrics. - """ - data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") - runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) - - if not run_with_bundle: - runner.train_single_model(config="3d_fullres", fold=fold) - else: - os.environ["BUNDLE_ROOT"] = bundle_root - os.environ["PYTHONPATH"] = os.environ["PYTHONPATH"] + ":" + bundle_root - monai.bundle.run( - config_file=Path(bundle_root).joinpath("configs/train.yaml"), - bundle_root=bundle_root, - nnunet_trainer_class_name=trainer_class_name, - mlflow_experiment_name=experiment_name, - mlflow_run_name="run_" + client_name, - tracking_uri=tracking_uri, - fold_id=fold, - ) - nnunet_config = {"dataset_name_or_id": dataset_name_or_id, "nnunet_trainer": trainer_class_name} - convert_monai_bundle_to_nnunet(nnunet_config, bundle_root) - runner.train_single_model(config="3d_fullres", fold=fold, val="") - - if mlflow_token is not None: - os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token - if tracking_uri is not None: - mlflow.set_tracking_uri(tracking_uri) - - try: - mlflow.create_experiment(experiment_name) - except Exception as e: - print(e) - mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) - - filter = f""" - tags."client" = "{client_name}" - """ - - runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) - - validation_summary = os.path.join( - runner.nnunet_results, - runner.dataset_name, - f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", - f"fold_{fold}", - "validation", - "summary.json", - ) - - dataset_file = os.path.join( - runner.nnunet_results, - runner.dataset_name, - f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", - "dataset.json", - ) - - with open(dataset_file, 
"r") as f: - dataset_dict = json.load(f) - labels = dataset_dict["labels"] - labels = {str(v): k for k, v in labels.items()} - - with open(validation_summary, "r") as f: - validation_summary_dict = json.load(f) - - if len(runs) == 0: - with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): - for label in validation_summary_dict["mean"]: - for metric in validation_summary_dict["mean"][label]: - label_name = labels[label] - mlflow.log_metric(f"{label_name}_{metric}", float(validation_summary_dict["mean"][label][metric])) - - else: - with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): - for label in validation_summary_dict["mean"]: - for metric in validation_summary_dict["mean"][label]: - label_name = labels[label] - mlflow.log_metric(f"{label_name}_{metric}", float(validation_summary_dict["mean"][label][metric])) - - return validation_summary_dict - - -def preprocess(nnunet_root_dir, dataset_name_or_id, nnunet_plans_file_path=None, trainer_class_name="nnUNetTrainer"): - """ - Preprocess the dataset for nnUNet training. - - Parameters - ---------- - nnunet_root_dir : str - The root directory of the nnUNet project. - dataset_name_or_id : str or int - The name or ID of the dataset to preprocess. - nnunet_plans_file_path : Path, optional - The file path to the nnUNet plans file. If None, default plans will be used. Default is None. - trainer_class_name : str, optional - The name of the trainer class to use. Default is "nnUNetTrainer". - - Returns - ------- - dict - The nnUNet plans dictionary. - """ - - data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") - runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) - - nnunet_plans_name = nnunet_plans_file_path.name.split(".")[0] - from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name - - dataset_name = maybe_convert_to_dataset_name(int(dataset_name_or_id)) - - Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name).mkdir(parents=True, exist_ok=True) - - shutil.copy( - Path(nnunet_root_dir).joinpath("nnUNet_raw_data_base", dataset_name, "dataset.json"), - Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name, "dataset.json"), - ) - if nnunet_plans_file_path is not None: - with open(nnunet_plans_file_path, "r") as f: - nnunet_plans = json.load(f) - nnunet_plans["original_dataset_name"] = nnunet_plans["dataset_name"] - nnunet_plans["dataset_name"] = dataset_name - json.dump( - nnunet_plans, - open( - Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name, f"{nnunet_plans_name}.json"), - "w", - ), - indent=4, - ) - - runner.extract_fingerprints(npfp=2, verify_dataset_integrity=True) - runner.preprocess(c=["3d_fullres"], n_proc=[2], overwrite_plans_name=nnunet_plans_name) - - return nnunet_plans - - -def plan_and_preprocess( - nnunet_root_dir, - dataset_name_or_id, - client_name, - experiment_name, - tracking_uri, - mlflow_token=None, - nnunet_plans_name="nnUNetPlans", - trainer_class_name="nnUNetTrainer", -): - """ - Plan and preprocess the dataset using nnUNetV2Runner and log the plans to MLflow. - - Parameters - ---------- - nnunet_root_dir : str - The root directory of nnUNet. - dataset_name_or_id : str or int - The name or ID of the dataset to be processed. - client_name : str - The name of the client. - experiment_name : str - The name of the MLflow experiment. - tracking_uri : str - The URI of the MLflow tracking server. 
- mlflow_token : str, optional - The token for MLflow authentication (default is None). - nnunet_plans_name : str, optional - The name of the nnUNet plans (default is "nnUNetPlans"). - trainer_class_name : str, optional - The name of the nnUNet trainer class (default is "nnUNetTrainer"). - - Returns - ------- - dict - The nnUNet plans as a dictionary. - """ - - data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") - - runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) - - runner.plan_and_process( - npfp=2, verify_dataset_integrity=True, c=["3d_fullres"], n_proc=[2], overwrite_plans_name=nnunet_plans_name - ) - - preprocessed_folder = runner.nnunet_preprocessed - - from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name - - dataset_name = maybe_convert_to_dataset_name(int(dataset_name_or_id)) - - with open(Path(preprocessed_folder).joinpath(f"{dataset_name}", nnunet_plans_name + ".json"), "r") as f: - nnunet_plans = json.load(f) - - if mlflow_token is not None: - os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token - if tracking_uri is not None: - mlflow.set_tracking_uri(tracking_uri) - - try: - mlflow.create_experiment(experiment_name) - except Exception as e: - print(e) - mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) - - filter = f""" - tags."client" = "{client_name}" - """ - - runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) - - if len(runs) == 0: - with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): - mlflow.log_dict(nnunet_plans, nnunet_plans_name + ".json") - - else: - with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): - mlflow.log_dict(nnunet_plans, nnunet_plans_name + ".json") - - return nnunet_plans - - -def prepare_data_folder( - data_dir, - nnunet_root_dir, - dataset_name_or_id, - modality_dict, - experiment_name, - client_name, - dataset_format, - modality_list = None, - tracking_uri=None, - mlflow_token=None, - subfolder_suffix=None, - patient_id_in_file_identifier=True, - trainer_class_name="nnUNetTrainer", -): - """ - Prepare the data folder for nnUNet training and log the data to MLflow. - - Parameters - ---------- - data_dir : str - Directory containing the dataset. - nnunet_root_dir : str - Root directory for nnUNet. - dataset_name_or_id : str - Name or ID of the dataset. - modality_dict : dict - Dictionary mapping modality IDs to file suffixes. - experiment_name : str - Name of the MLflow experiment. - client_name : str - Name of the client. - dataset_format : str - Format of the dataset. Supported formats are "subfolders", "decathlon", and "nnunet". - tracking_uri : str, optional - URI for MLflow tracking server. - modality_list : list, optional - List of modalities. Default is None. - mlflow_token : str, optional - Token for MLflow authentication. - subfolder_suffix : str, optional - Suffix for subfolder names. - patient_id_in_file_identifier : bool, optional - Whether patient ID is included in file identifier. Default is True. - trainer_class_name : str, optional - Name of the nnUNet trainer class. Default is "nnUNetTrainer". - - Returns - ------- - dict - Dictionary containing the training and testing data lists. 
- """ - if dataset_format == "subfolders": - if subfolder_suffix is not None: - data_list = { - "training": [ - { - modality_id: ( - str( - pathlib.Path(f.name).joinpath( - f.name[: -len(subfolder_suffix)] + modality_dict[modality_id] - ) - ) - if patient_id_in_file_identifier - else str(pathlib.Path(f.name).joinpath(modality_dict[modality_id])) - ) - for modality_id in modality_dict - } - for f in os.scandir(data_dir) - if f.is_dir() - ], - "testing": [], - } - else: - data_list = { - "training": [ - { - modality_id: ( - str(pathlib.Path(f.name).joinpath(f.name + modality_dict[modality_id])) - if patient_id_in_file_identifier - else str(pathlib.Path(f.name).joinpath(modality_dict[modality_id])) - ) - for modality_id in modality_dict - } - for f in os.scandir(data_dir) - if f.is_dir() - ], - "testing": [], - } - elif dataset_format == "decathlon" or dataset_format == "nnunet": - cases = [] - - for f in os.scandir(Path(data_dir).joinpath("imagesTr")): - if f.is_file(): - for modality_suffix in list(modality_dict.values()): - if f.name.endswith(modality_suffix) and modality_suffix != ".nii.gz": - cases.append(f.name[: -len(modality_suffix)]) - if len(np.unique(list(modality_dict.values()))) == 1 and ".nii.gz" in list(modality_dict.values()): - cases.append(f.name[: -len(".nii.gz")]) - cases = np.unique(cases) - data_list = { - "training": [ - { - modality_id: str(Path("imagesTr").joinpath(case + modality_dict[modality_id])) - for modality_id in modality_dict - if modality_id != "label" - } - for case in cases - ], - "testing": [], - } - for idx, case in enumerate(data_list["training"]): - modality_id = list(modality_dict.keys())[0] - case_id = Path(case[modality_id]).name[: -len(modality_dict[modality_id])] - data_list["training"][idx]["label"] = str(Path("labelsTr").joinpath(case_id + modality_dict["label"])) - else: - raise ValueError("Dataset format not supported") - - for idx, train_case in enumerate(data_list["training"]): - for modality_id in modality_dict: - data_list["training"][idx][modality_id + "_is_file"] = ( - Path(data_dir).joinpath(data_list["training"][idx][modality_id]).is_file() - ) - if "image" not in data_list["training"][idx] and modality_id != "label": - data_list["training"][idx]["image"] = data_list["training"][idx][modality_id] - data_list["training"][idx]["fold"] = 0 - - random.seed(42) - random.shuffle(data_list["training"]) - - data_list["testing"] = [data_list["training"][0]] - - num_folds = 5 - fold_size = len(data_list["training"]) // num_folds - for i in range(num_folds): - for j in range(fold_size): - data_list["training"][i * fold_size + j]["fold"] = i - - datalist_file = Path(data_dir).joinpath(f"{experiment_name}_folds.json") - with open(datalist_file, "w", encoding="utf-8") as f: - json.dump(data_list, f, ensure_ascii=False, indent=4) - - os.makedirs(nnunet_root_dir, exist_ok=True) - - if modality_list is None: - modality_list = [k for k in modality_dict.keys() if k != "label"] - - data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") - data_src = { - "modality": modality_list, - "dataset_name_or_id": dataset_name_or_id, - "datalist": str(datalist_file), - "dataroot": str(data_dir), - } - - ConfigParser.export_config_file(data_src, data_src_cfg) - - if dataset_format != "nnunet": - runner = nnUNetV2Runner( - input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir - ) - runner.convert_dataset() - else: - ... 
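To make the datalist construction above concrete, a small illustration of a `modality_dict` and the kind of training entry `prepare_data_folder` produced for `dataset_format="subfolders"`; all case and file names are hypothetical.

```python
# Illustrative sketch of the inputs/outputs handled above; values are hypothetical.
modality_dict = {"image": "_image.nii.gz", "label": "_label.nii.gz"}

# With dataset_format="subfolders" and patient_id_in_file_identifier=True,
# a case folder "case_001/" is expected to contain "case_001_image.nii.gz"
# and "case_001_label.nii.gz", yielding a datalist entry roughly like:
entry = {
    "image": "case_001/case_001_image.nii.gz",
    "label": "case_001/case_001_label.nii.gz",
    "image_is_file": True,   # per-modality existence flags added by the helper
    "label_is_file": True,
    "fold": 0,               # folds 0-4 assigned in blocks after shuffling with seed 42
}
```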
- - if mlflow_token is not None: - os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token - if tracking_uri is not None: - mlflow.set_tracking_uri(tracking_uri) - - try: - mlflow.create_experiment(experiment_name) - mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) - except Exception as e: - print(e) - mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) - - filter = f""" - tags."client" = "{client_name}" - """ - - runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) - - try: - if len(runs) == 0: - with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): - mlflow.log_table(pd.DataFrame.from_records(data_list["training"]), f"{client_name}_train.json") - else: - with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): - mlflow.log_table(pd.DataFrame.from_records(data_list["training"]), f"{client_name}_train.json") - except (BrokenPipeError, ConnectionError) as e: - logging.error(f"Failed to log data to MLflow: {e}") - - return data_list - - -def check_packages(packages): - """ - Check if the specified packages are installed and return a report. - - Parameters - ---------- - packages : list - A list of package names (str) or dictionaries with keys "import_name" and "package_name". - - Returns - ------- - dict - A dictionary where the keys are package names and the values are strings indicating whether - the package is installed and its version if applicable. - - Examples - -------- - >>> check_packages(["numpy", "nonexistent_package"]) - {'numpy': 'numpy 1.21.0 is installed.', 'nonexistent_package': 'nonexistent_package is not installed.'} - >>> check_packages([{"import_name": "torch", "package_name": "torch"}]) - {'torch': 'torch 1.9.0 is installed.'} - """ - report = {} - for package in packages: - try: - if isinstance(package, dict): - __import__(package["import_name"]) - package_version = version(package["package_name"]) - name = package["package_name"] - print(f"{name} {package_version} is installed.") - report[name] = f"{name} {package_version} is installed." - else: - - __import__(package) - package_version = version(package) - print(f"{package} {package_version} is installed.") - report[package] = f"{package} {package_version} is installed." - - except ImportError: - print(f"{package} is not installed.") - report[package] = f"{package} is not installed." - - return report - - -def check_host_config(): - """ - Collects and returns the host configuration details including GPU, CPU, and memory information. 
- - Returns - ------- - dict - A dictionary containing the following keys and their corresponding values: - - Config values from `monai.config.deviceconfig.get_config_values()` - - Optional config values from `monai.config.deviceconfig.get_optional_config_values()` - - GPU information including number of GPUs, CUDA version, cuDNN version, and GPU names and memory - - CPU core count - - Total memory in GB - - Memory usage percentage - """ - params_dict = {} - config_values = monai.config.deviceconfig.get_config_values() - for k in config_values: - params_dict[re.sub("[()]", " ", str(k))] = config_values[k] - optional_config_values = monai.config.deviceconfig.get_optional_config_values() - - for k in optional_config_values: - params_dict[re.sub("[()]", " ", str(k))] = optional_config_values[k] - - gpu_info = monai.config.deviceconfig.get_gpu_info() - allowed_keys = ["Num GPUs", "Has Cuda", "CUDA Version", "cuDNN enabled", "cuDNN Version"] - for i in range(gpu_info["Num GPUs"]): - allowed_keys.append(f"GPU {i} Name") - allowed_keys.append(f"GPU {i} Total memory GB ") - - for k in gpu_info: - if re.sub("[()]", " ", str(k)) in allowed_keys: - params_dict[re.sub("[()]", " ", str(k))] = str(gpu_info[k]) - - with open("nvidia-smi.log", "w") as f_e: - subprocess.run("nvidia-smi", stderr=f_e, stdout=f_e) - - params_dict["CPU_Cores"] = multiprocessing.cpu_count() - - vm = psutil.virtual_memory() - - params_dict["Total Memory"] = vm.total / (1024 * 1024 * 1024) - params_dict["Memory Used %"] = vm.percent - - return params_dict - - -def prepare_bundle(bundle_config, train_extra_configs=None): - """ - Prepare the bundle configuration for training and evaluation. - - Parameters - ---------- - bundle_config : dict - Dictionary containing the bundle configuration. Expected keys are: - - "bundle_root": str, root directory of the bundle. - - "tracking_uri": str, URI for tracking. - - "mlflow_experiment_name": str, name of the MLflow experiment. - - "mlflow_run_name": str, name of the MLflow run. - - "nnunet_plans_identifier": str, optional, identifier for nnUNet plans. - - "nnunet_trainer_class_name": str, optional, class name for nnUNet trainer. - train_extra_configs : dict, optional - Additional configurations for training. If provided, expected keys are: - - "resume_epoch": int, epoch to resume training from. - - Any other key-value pairs to be added to the training configuration. 
- - Returns - ------- - None - """ - - with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml")) as f: - train_config = yaml.safe_load(f) - train_config["bundle_root"] = bundle_config["bundle_root"] - train_config["tracking_uri"] = bundle_config["tracking_uri"] - train_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"] - train_config["mlflow_run_name"] = bundle_config["mlflow_run_name"] - - train_config["data_src_cfg"] = "$@nnunet_root_folder+'/data_src_cfg.yaml'" - train_config["runner"] = { - "_target_": "nnUNetV2Runner", - "input_config": "$@data_src_cfg", - "trainer_class_name": "@nnunet_trainer_class_name", - "work_dir": "@nnunet_root_folder", - } - - train_config["network"] = "$@nnunet_trainer.network._orig_mod" - - train_handlers = train_config["train_handlers"]["handlers"] - - for idx, handler in enumerate(train_handlers): - if handler["_target_"] == "ValidationHandler": - train_handlers.pop(idx) - break - - train_config["train_handlers"]["handlers"] = train_handlers - - if train_extra_configs is not None and "resume_epoch" in train_extra_configs: - resume_epoch = train_extra_configs["resume_epoch"] - train_config["initialize"] = [ - "$monai.utils.set_determinism(seed=123)", - "$@runner.dataset_name_or_id", - f"$src.trainer.reload_checkpoint(@train#trainer, {resume_epoch}, @iterations, @ckpt_dir, @lr_scheduler)", - ] - else: - train_config["initialize"] = ["$monai.utils.set_determinism(seed=123)", "$@runner.dataset_name_or_id"] - - if "Val_Dice" in train_config["val_key_metric"]: - train_config["val_key_metric"] = {"Val_Dice_Local": train_config["val_key_metric"]["Val_Dice"]} - - if "Val_Dice_per_class" in train_config["val_additional_metrics"]: - train_config["val_additional_metrics"] = { - "Val_Dice_per_class_Local": train_config["val_additional_metrics"]["Val_Dice_per_class"] - } - if "nnunet_plans_identifier" in bundle_config: - train_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"] - - if "nnunet_trainer_class_name" in bundle_config: - train_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"] - - if train_extra_configs is not None: - for key in train_extra_configs: - train_config[key] = train_extra_configs[key] - - with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.json"), "w") as f: - json.dump(train_config, f) - - with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml"), "w") as f: - yaml.dump(train_config, f) - - if not Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml").exists(): - shutil.copy( - Path(bundle_config["bundle_root"]).joinpath("nnUNet", "evaluator", "evaluator.yaml"), - Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"), - ) - - with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml")) as f: - evaluate_config = yaml.safe_load(f) - evaluate_config["bundle_root"] = bundle_config["bundle_root"] - - evaluate_config["tracking_uri"] = bundle_config["tracking_uri"] - evaluate_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"] - evaluate_config["mlflow_run_name"] = bundle_config["mlflow_run_name"] - - if "nnunet_plans_identifier" in bundle_config: - evaluate_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"] - if "nnunet_trainer_class_name" in bundle_config: - evaluate_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"] - - with 
open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.json"), "w") as f: - json.dump(evaluate_config, f) - - with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"), "w") as f: - yaml.dump(evaluate_config, f) diff --git a/monai/nvflare/response_processor.py b/monai/nvflare/response_processor.py deleted file mode 100644 index a02d307220..0000000000 --- a/monai/nvflare/response_processor.py +++ /dev/null @@ -1,342 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -from nvflare.apis.client import Client -from nvflare.apis.dxo import DataKind, from_shareable -from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable -from nvflare.app_common.abstract.response_processor import ResponseProcessor - - -class nnUNetPrepareProcessor(ResponseProcessor): - """ - A processor class for preparing nnUNet data in a federated learning context. - - Methods - ------- - __init__(): - Initializes the nnUNetPrepareProcessor with an empty data dictionary. - create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable: - Creates and returns a Shareable object for the given task name. - process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - Processes the response from a client. Validates the response and updates the data dictionary if valid. - final_process(fl_ctx: FLContext) -> bool: - Finalizes the processing by setting the client data dictionary in the federated learning context. - """ - - def __init__(self): - ResponseProcessor.__init__(self) - self.data_dict = {} - - def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: - return Shareable() - - def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - if not isinstance(response, Shareable): - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", - ) - return False - - try: - dxo = from_shareable(response) - - except Exception: - self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") - return False - - if dxo.data_kind != DataKind.COLLECTION: - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " - f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", - ) - return False - - data_dict = dxo.data - - if not data_dict: - self.log_error(fl_ctx, f"No dataset_dict found from client {client.name}") - return False - - self.data_dict[client.name] = data_dict - - return True - - def final_process(self, fl_ctx: FLContext) -> bool: - if not self.data_dict: - self.log_error(fl_ctx, "no data_prepare_dict from clients") - return False - - # must set sticky to True so other controllers can get it! 
- fl_ctx.set_prop("client_data_dict", self.data_dict, private=True, sticky=True) - return True - - -class nnUNetPackageReportProcessor(ResponseProcessor): - """ - A processor for handling nnUNet package reports in a federated learning context. - - Attributes - ---------- - package_report : dict - A dictionary to store package reports from clients. - - Methods - ------- - create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable - Creates task data for a given task name and federated learning context. - process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool - Processes the response from a client for a given task name and federated learning context. - final_process(fl_ctx: FLContext) -> bool - Final processing step to handle the collected package reports. - """ - - def __init__(self): - ResponseProcessor.__init__(self) - self.package_report = {} - - def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: - return Shareable() - - def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - if not isinstance(response, Shareable): - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", - ) - return False - - try: - dxo = from_shareable(response) - - except Exception: - self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") - return False - - if dxo.data_kind != DataKind.COLLECTION: - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " - f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", - ) - return False - - package_report = dxo.data - - if not package_report: - self.log_error(fl_ctx, f"No package_report found from client {client.name}") - return False - - self.package_report[client.name] = package_report - return True - - def final_process(self, fl_ctx: FLContext) -> bool: - if not self.package_report: - self.log_error(fl_ctx, "no plan_dict from client") - return False - - # must set sticky to True so other controllers can get it! - fl_ctx.set_prop("package_report", self.package_report, private=True, sticky=True) - return True - - -class nnUNetPlanProcessor(ResponseProcessor): - """ - nnUNetPlanProcessor is a class that processes responses from clients in a federated learning context. - It inherits from the ResponseProcessor class and is responsible for handling and validating the - responses, extracting the necessary data, and storing it for further use. - - Attributes - ---------- - plan_dict : dict - A dictionary to store the plan data received from clients. - - Methods - ------- - create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable - Creates and returns a Shareable object for the given task name. - process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool - Processes the response from a client, validates it, and stores the plan data if valid. - final_process(fl_ctx: FLContext) -> bool - Finalizes the processing by setting the plan data in the federated learning context. 
- """ - - def __init__(self): - ResponseProcessor.__init__(self) - self.plan_dict = {} - - def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: - return Shareable() - - def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - if not isinstance(response, Shareable): - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", - ) - return False - - try: - dxo = from_shareable(response) - - except Exception: - self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") - return False - - if dxo.data_kind != DataKind.COLLECTION: - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " - f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", - ) - return False - - plan_dict = dxo.data - - if not plan_dict: - self.log_error(fl_ctx, f"No plan_dict found from client {client.name}") - return False - - self.plan_dict[client.name] = plan_dict - - return True - - def final_process(self, fl_ctx: FLContext) -> bool: - if not self.plan_dict: - self.log_error(fl_ctx, "no plan_dict from client") - return False - - # must set sticky to True so other controllers can get it! - fl_ctx.set_prop("nnunet_plans", self.plan_dict, private=True, sticky=True) - return True - - -class nnUNetTrainProcessor(ResponseProcessor): - """ - A processor class for handling training responses in the nnUNet framework. - - Attributes - ---------- - val_summary_dict : dict - A dictionary to store validation summaries from clients. - Methods - ------- - create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable - Creates task data for a given task name and FLContext. - process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool - Processes the response from a client for a given task name and FLContext. - final_process(fl_ctx: FLContext) -> bool - Final processing step to handle the collected validation summaries. - """ - - def __init__(self): - ResponseProcessor.__init__(self) - self.val_summary_dict = {} - - def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: - return Shareable() - - def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - if not isinstance(response, Shareable): - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", - ) - return False - - try: - dxo = from_shareable(response) - - except Exception: - self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") - return False - - if dxo.data_kind != DataKind.COLLECTION: - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " - f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", - ) - return False - - val_summary_dict = dxo.data - - if not val_summary_dict: - self.log_error(fl_ctx, f"No val_summary_dict found from client {client.name}") - return False - - self.val_summary_dict[client.name] = val_summary_dict - - return True - - def final_process(self, fl_ctx: FLContext) -> bool: - if not self.val_summary_dict: - self.log_error(fl_ctx, "no val_summary_dict from client") - return False - - # must set sticky to True so other controllers can get it! 
- fl_ctx.set_prop("val_summary_dict", self.val_summary_dict, private=True, sticky=True) - return True - - -class nnUNetBundlePrepareProcessor(ResponseProcessor): - """ - A processor class for preparing nnUNet bundles in a federated learning context. - - Methods - ------- - __init__(): - Initializes the nnUNetBundlePrepareProcessor instance. - create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable: - Creates task data for a given task name and federated learning context. - process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - Processes the response from a client and validates it. - final_process(fl_ctx: FLContext) -> bool: - Final processing step after all client responses have been processed. - """ - - def __init__(self): - ResponseProcessor.__init__(self) - - def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: - return Shareable() - - def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - if not isinstance(response, Shareable): - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", - ) - return False - - try: - dxo = from_shareable(response) - - except Exception: - self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") - return False - - if dxo.data_kind != DataKind.COLLECTION: - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " - f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", - ) - return False - - return True - - def final_process(self, fl_ctx: FLContext) -> bool: - - return True diff --git a/requirements-dev.txt b/requirements-dev.txt index a31b83a59e..46708026b8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -61,6 +61,4 @@ huggingface_hub pyamg>=5.0.0 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 onnx_graphsurgeon -polygraphy -pyhocon -odict \ No newline at end of file +polygraphy \ No newline at end of file From 60185d1b3de2a2181ca8641794e80f86e5d5820d Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 11:54:59 +0000 Subject: [PATCH 40/67] Add new functions to nnunet_bundle for converting between MONAI and nnU-Net formats --- docs/source/apps.rst | 2 + monai/apps/nnunet/nnunet_bundle.py | 61 ++++++++++++++---------------- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index e27e30c0bf..e55c6c17c6 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -287,3 +287,5 @@ FastMRIReader .. autofunction:: monai.apps.nnunet.get_nnunet_trainer .. autofunction:: monai.apps.nnunet.get_nnunet_monai_predictor .. autofunction:: monai.apps.nnunet.convert_nnunet_to_monai_bundle +.. autofunction:: monai.apps.nnunet.convert_monai_bundle_to_nnunet +.. 
autofunction:: monai.apps.nnunet.get_network_from_nnunet_plans \ No newline at end of file diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 1581e325f1..4ee1f94f9a 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -13,6 +13,7 @@ import os import shutil from pathlib import Path +from typing import Optional, Union import numpy as np import torch @@ -21,11 +22,17 @@ from monai.data.meta_tensor import MetaTensor from monai.utils import optional_import -from typing import Union, Optional join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") -__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "convert_monai_bundle_to_nnunet","ModelnnUNetWrapper"] +__all__ = [ + "get_nnunet_trainer", + "get_nnunet_monai_predictor", + "get_network_from_nnunet_plans", + "convert_nnunet_to_monai_bundle", + "convert_monai_bundle_to_nnunet", + "ModelnnUNetWrapper", +] def get_nnunet_trainer( @@ -107,7 +114,6 @@ def get_nnunet_trainer( ) raise e - from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint nnunet_trainer = get_trainer_from_args( @@ -178,7 +184,7 @@ def __init__(self, predictor: object, model_folder: str, model_name: str = "mode plans_manager = PlansManager(plans) parameters = [] - + checkpoint = torch.load( join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") ) @@ -190,9 +196,7 @@ def __init__(self, predictor: object, model_folder: str, model_name: str = "mode else None ) if Path(model_training_output_dir).joinpath(model_name).is_file(): - monai_checkpoint = torch.load( - join(model_training_output_dir, model_name), map_location=torch.device("cpu") - ) + monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu")) if "network_weights" in monai_checkpoint.keys(): parameters.append(monai_checkpoint["network_weights"]) else: @@ -230,10 +234,7 @@ def __init__(self, predictor: object, model_folder: str, model_name: str = "mode predictor.trainer_name = trainer_name predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes predictor.label_manager = plans_manager.get_label_manager(dataset_json) - if ( - ("nnUNet_compile" in os.environ.keys()) - and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")) - ): + if ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")): print("Using torch.compile") # End Block self.network_weights = self.predictor.network @@ -265,7 +266,11 @@ def forward(self, x: MetaTensor) -> MetaTensor: if "pixdim" in x.meta: properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()} elif "affine" in x.meta: - spacing = [abs(x.meta['affine'][0][0].item()), abs(x.meta['affine'][1][1].item()), abs(x.meta['affine'][2][2].item())] + spacing = [ + abs(x.meta["affine"][0][0].item()), + abs(x.meta["affine"][1][1].item()), + abs(x.meta["affine"][2][2].item()), + ] properties_or_list_of_properties = {"spacing": spacing} else: properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]} @@ -348,9 +353,7 @@ def get_nnunet_monai_predictor(model_folder: str, model_name: str = "model.pt") return wrapper -def convert_nnunet_to_monai_bundle( - nnunet_config: dict, bundle_root_folder: str, fold: int = 0 -) -> None: +def 
convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None: """ Convert nnUNet model checkpoints and configuration to MONAI bundle format. @@ -421,14 +424,14 @@ def convert_nnunet_to_monai_bundle( def get_network_from_nnunet_plans( - plans_file: str, - dataset_file: str, - configuration: str, - model_ckpt: Optional[str] = None, - model_key_in_ckpt: str = "model" + plans_file: str, + dataset_file: str, + configuration: str, + model_ckpt: Optional[str] = None, + model_key_in_ckpt: str = "model", ) -> torch.nn.Module: """ - Load and initialize a neural network based on nnUNet plans and configuration. + Load and initialize a nnUNet network based on nnUNet plans and configuration. Parameters ---------- @@ -481,11 +484,7 @@ def get_network_from_nnunet_plans( return network -def convert_monai_bundle_to_nnunet( - nnunet_config: dict, - bundle_root_folder: str, - fold: int = 0 -) -> None: +def convert_monai_bundle_to_nnunet(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None: """ Convert a MONAI bundle to nnU-Net format. @@ -520,11 +519,7 @@ def convert_monai_bundle_to_nnunet( from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name def subfiles( - folder: str, - join: bool = True, - prefix: Optional[str] = None, - suffix: Optional[str] = None, - sort: bool = True + folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True ) -> list[str]: if join: l = os.path.join # noqa: E741 @@ -562,7 +557,9 @@ def subfiles( epochs.sort() final_epoch: int = epochs[-1] - monai_last_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt") + monai_last_checkpoint: dict = torch.load( + f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt" + ) best_checkpoints: list[str] = subfiles( Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), From b0ecb2c4dff1649e78e39aa2cef79c27557fc5a8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 30 Mar 2025 11:58:45 +0000 Subject: [PATCH 41/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/apps.rst | 2 +- monai/apps/nnunet/nnunet_bundle.py | 4 ++-- requirements-dev.txt | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index e55c6c17c6..3239dc5351 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -288,4 +288,4 @@ FastMRIReader .. autofunction:: monai.apps.nnunet.get_nnunet_monai_predictor .. autofunction:: monai.apps.nnunet.convert_nnunet_to_monai_bundle .. autofunction:: monai.apps.nnunet.convert_monai_bundle_to_nnunet -.. autofunction:: monai.apps.nnunet.get_network_from_nnunet_plans \ No newline at end of file +.. 
autofunction:: monai.apps.nnunet.get_network_from_nnunet_plans diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 4ee1f94f9a..051908e52e 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -522,9 +522,9 @@ def subfiles( folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True ) -> list[str]: if join: - l = os.path.join # noqa: E741 + l = os.path.join else: - l = lambda x, y: y # noqa: E741, E731 + l = lambda x, y: y # noqa: E731 res = [ l(folder, i.name) for i in Path(folder).iterdir() diff --git a/requirements-dev.txt b/requirements-dev.txt index 46708026b8..c9730ee651 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -61,4 +61,4 @@ huggingface_hub pyamg>=5.0.0 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 onnx_graphsurgeon -polygraphy \ No newline at end of file +polygraphy From 678334e2e8536b6d02a20037533009cb83a2d994 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 12:01:21 +0000 Subject: [PATCH 42/67] DCO Remediation Commit for Simone Bendazzoli I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 1a30a0bb6f6557c0b9fcb165bd712edc1adfbb0a I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: ca851cdbc22bae81b32d99090e9dfa2b81668f1c I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: fee1bb06f20183c318dc310a2a95b82c3d9d4573 I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 1972504ea62afd9060f899e4743d3b745a0f3643 I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 5c633f21ffadd2fe224613c890e9f7331c6eab58 I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 052ef648f4e22a77a358a794cd22a53aef760345 I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 1c41164fe73a2612b70caaca7ce97a8ccad30f94 I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 60185d1b3de2a2181ca8641794e80f86e5d5820d Signed-off-by: Simone Bendazzoli --- monai/apps/nnunet/nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 4ee1f94f9a..13d89a249f 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -203,7 +203,7 @@ def __init__(self, predictor: object, model_folder: str, model_name: str = "mode parameters.append(monai_checkpoint) configuration_manager = plans_manager.get_configuration(configuration_name) - # restore network + import nnunetv2 from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels From 6178082edda4e745ded06528e464741b5a32e755 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 12:48:32 +0000 Subject: [PATCH 43/67] Add ModelnnUNetWrapper import to nnunet bundle --- monai/apps/nnunet/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/nnunet/__init__.py b/monai/apps/nnunet/__init__.py index cdf96e0ce2..780ee7a861 100644 --- a/monai/apps/nnunet/__init__.py +++ b/monai/apps/nnunet/__init__.py @@ -12,12 +12,12 @@ from __future__ import annotations from .nnunet_bundle import ( + ModelnnUNetWrapper, convert_monai_bundle_to_nnunet, convert_nnunet_to_monai_bundle, get_network_from_nnunet_plans, get_nnunet_monai_predictor, get_nnunet_trainer, - 
ModelnnUNetWrapper ) from .nnunetv2_runner import nnUNetV2Runner from .utils import NNUNETMode, analyze_data, create_new_data_copy, create_new_dataset_json From 78a7d146ab9eebb0644c2bbfac64cb1ea2c11964 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 12:51:28 +0000 Subject: [PATCH 44/67] DCO Remediation Commit for Simone Bendazzoli I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 6178082edda4e745ded06528e464741b5a32e755 Signed-off-by: Simone Bendazzoli --- monai/apps/nnunet/nnunet_bundle.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index f3476fedc7..4ee1f94f9a 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -203,7 +203,7 @@ def __init__(self, predictor: object, model_folder: str, model_name: str = "mode parameters.append(monai_checkpoint) configuration_manager = plans_manager.get_configuration(configuration_name) - + # restore network import nnunetv2 from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels @@ -522,9 +522,9 @@ def subfiles( folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True ) -> list[str]: if join: - l = os.path.join + l = os.path.join # noqa: E741 else: - l = lambda x, y: y # noqa: E731 + l = lambda x, y: y # noqa: E741, E731 res = [ l(folder, i.name) for i in Path(folder).iterdir() From 8d132f8e74452fe6d64898c1bf01c542a97fffab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 30 Mar 2025 12:51:53 +0000 Subject: [PATCH 45/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/nnunet/nnunet_bundle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 4ee1f94f9a..051908e52e 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -522,9 +522,9 @@ def subfiles( folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True ) -> list[str]: if join: - l = os.path.join # noqa: E741 + l = os.path.join else: - l = lambda x, y: y # noqa: E741, E731 + l = lambda x, y: y # noqa: E731 res = [ l(folder, i.name) for i in Path(folder).iterdir() From 88a5e5a5e384cb0e1481f5c91472c0a443638163 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 13:35:27 +0000 Subject: [PATCH 46/67] Refactor nnUNet integration: update type hints and improve parameter definitions --- monai/apps/nnunet/nnunet_bundle.py | 31 +++++++++++++----------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 4ee1f94f9a..896974328c 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -13,7 +13,7 @@ import os import shutil from pathlib import Path -from typing import Optional, Union +from typing import Optional, Union, Any import numpy as np import torch @@ -24,6 +24,7 @@ join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") +nnUNetTrainer, _ = 
optional_import("nnunetv2.training.nnUNetTrainer", name="nnUNetTrainer") __all__ = [ "get_nnunet_trainer", @@ -47,7 +48,7 @@ def get_nnunet_trainer( disable_checkpointing: bool = False, device: str = "cuda", pretrained_model: Optional[str] = None, -) -> object: +) -> Union[nnUNetTrainer, Any]: # type: ignore """ Get the nnUNet trainer instance based on the provided configuration. The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network, @@ -150,9 +151,9 @@ class ModelnnUNetWrapper(torch.nn.Module): Parameters ---------- - predictor : object + predictor : nnUNetPredictor The nnUNet predictor object used for inference. - model_folder : str + model_folder : Union[str, Path] The folder path where the model and related files are stored. model_name : str, optional The name of the model file, by default "model.pt". @@ -169,8 +170,8 @@ class ModelnnUNetWrapper(torch.nn.Module): This class integrates nnUNet model with MONAI framework by loading necessary configurations, restoring network architecture, and setting up the predictor for inference. """ - - def __init__(self, predictor: object, model_folder: str, model_name: str = "model.pt"): + from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor + def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], model_name: str = "model.pt"): super().__init__() self.predictor = predictor @@ -299,7 +300,7 @@ def forward(self, x: MetaTensor) -> MetaTensor: return MetaTensor(out_tensor, meta=x.meta) -def get_nnunet_monai_predictor(model_folder: str, model_name: str = "model.pt") -> ModelnnUNetWrapper: +def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str = "model.pt") -> ModelnnUNetWrapper: """ Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`. The model folder should contain the following files, created during training: @@ -326,7 +327,7 @@ def get_nnunet_monai_predictor(model_folder: str, model_name: str = "model.pt") Parameters ---------- - model_folder : str + model_folder : Union[str, Path] The folder where the model is stored. model_name : str, optional The name of the model file, by default "model.pt". @@ -429,7 +430,7 @@ def get_network_from_nnunet_plans( configuration: str, model_ckpt: Optional[str] = None, model_key_in_ckpt: str = "model", -) -> torch.nn.Module: +) -> Union[torch.nn.Module, Any]: """ Load and initialize a nnUNet network based on nnUNet plans and configuration. 
@@ -519,14 +520,10 @@ def convert_monai_bundle_to_nnunet(nnunet_config: dict, bundle_root_folder: str, from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name def subfiles( - folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True + folder: Union[str, Path], prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True ) -> list[str]: - if join: - l = os.path.join # noqa: E741 - else: - l = lambda x, y: y # noqa: E741, E731 res = [ - l(folder, i.name) + i.name for i in Path(folder).iterdir() if i.is_file() and (prefix is None or i.name.startswith(prefix)) @@ -549,8 +546,7 @@ def subfiles( nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth") latest_checkpoints: list[str] = subfiles( - Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True, join=False - ) + Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True) epochs: list[int] = [] for latest_checkpoint in latest_checkpoints: epochs.append(int(latest_checkpoint[len("checkpoint_epoch=") : -len(".pt")])) @@ -565,7 +561,6 @@ def subfiles( Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_key_metric", sort=True, - join=False, ) key_metrics: list[str] = [] for best_checkpoint in best_checkpoints: From 49d0897580894b4bc736f7b6b97508dd71f4992c Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 13:37:34 +0000 Subject: [PATCH 47/67] Refactor nnUNet bundle: clean up import order and improve code formatting --- monai/apps/nnunet/nnunet_bundle.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 896974328c..f4bd7a4f46 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -13,7 +13,7 @@ import os import shutil from pathlib import Path -from typing import Optional, Union, Any +from typing import Any, Optional, Union import numpy as np import torch @@ -48,7 +48,7 @@ def get_nnunet_trainer( disable_checkpointing: bool = False, device: str = "cuda", pretrained_model: Optional[str] = None, -) -> Union[nnUNetTrainer, Any]: # type: ignore +) -> Union[nnUNetTrainer, Any]: # type: ignore """ Get the nnUNet trainer instance based on the provided configuration. The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network, @@ -170,7 +170,9 @@ class ModelnnUNetWrapper(torch.nn.Module): This class integrates nnUNet model with MONAI framework by loading necessary configurations, restoring network architecture, and setting up the predictor for inference. 
""" + from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor + def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], model_name: str = "model.pt"): super().__init__() self.predictor = predictor @@ -546,7 +548,8 @@ def subfiles( nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth") latest_checkpoints: list[str] = subfiles( - Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True) + Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True + ) epochs: list[int] = [] for latest_checkpoint in latest_checkpoints: epochs.append(int(latest_checkpoint[len("checkpoint_epoch=") : -len(".pt")])) @@ -558,9 +561,7 @@ def subfiles( ) best_checkpoints: list[str] = subfiles( - Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), - prefix="checkpoint_key_metric", - sort=True, + Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_key_metric", sort=True ) key_metrics: list[str] = [] for best_checkpoint in best_checkpoints: From 18d5a4c51b6ef894d8a489007ce3f4cf2ca87e59 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 13:42:52 +0000 Subject: [PATCH 48/67] DCO Remediation Commit for Simone Bendazzoli I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 88a5e5a5e384cb0e1481f5c91472c0a443638163 I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 49d0897580894b4bc736f7b6b97508dd71f4992c Signed-off-by: Simone Bendazzoli --- monai/apps/nnunet/nnunet_bundle.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index f4bd7a4f46..aa0e4cd510 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -206,7 +206,6 @@ def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], m parameters.append(monai_checkpoint) configuration_manager = plans_manager.get_configuration(configuration_name) - # restore network import nnunetv2 from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels From 4ca028a13718783f3f880f6c77bf57d3f5594626 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 13:50:53 +0000 Subject: [PATCH 49/67] Enhance nnUNet bundle: add nnUNetPredictor import and update type hints in ModelnnUNetWrapper --- monai/apps/nnunet/nnunet_bundle.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index aa0e4cd510..350ef1b41c 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -25,6 +25,7 @@ join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") nnUNetTrainer, _ = optional_import("nnunetv2.training.nnUNetTrainer", name="nnUNetTrainer") +nnUNetPredictor, _ = optional_import("nnunetv2.inference.predict_from_raw_data", name="nnUNetPredictor") __all__ = [ "get_nnunet_trainer", @@ -171,9 +172,7 @@ class ModelnnUNetWrapper(torch.nn.Module): restoring network architecture, and setting up the predictor for inference. 
""" - from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor - - def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], model_name: str = "model.pt"): + def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], model_name: str = "model.pt"): # type: ignore super().__init__() self.predictor = predictor @@ -228,18 +227,18 @@ def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], m enable_deep_supervision=False, ) - predictor.plans_manager = plans_manager - predictor.configuration_manager = configuration_manager - predictor.list_of_parameters = parameters - predictor.network = network - predictor.dataset_json = dataset_json - predictor.trainer_name = trainer_name - predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes - predictor.label_manager = plans_manager.get_label_manager(dataset_json) + predictor.plans_manager = plans_manager # type: ignore + predictor.configuration_manager = configuration_manager # type: ignore + predictor.list_of_parameters = parameters # type: ignore + predictor.network = network # type: ignore + predictor.dataset_json = dataset_json # type: ignore + predictor.trainer_name = trainer_name # type: ignore + predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes # type: ignore + predictor.label_manager = plans_manager.get_label_manager(dataset_json) # type: ignore if ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")): print("Using torch.compile") # End Block - self.network_weights = self.predictor.network + self.network_weights = self.predictor.network # type: ignore def forward(self, x: MetaTensor) -> MetaTensor: """ @@ -282,7 +281,7 @@ def forward(self, x: MetaTensor) -> MetaTensor: image_or_list_of_images = x.cpu().numpy()[0, :] # input_files should be a list of file paths, one per modality - prediction_output = self.predictor.predict_from_list_of_npy_arrays( + prediction_output = self.predictor.predict_from_list_of_npy_arrays( # type: ignore image_or_list_of_images, None, properties_or_list_of_properties, From 050651c6367c64686e21ac08acf4f7edf6cb6909 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 13:52:05 +0000 Subject: [PATCH 50/67] DCO Remediation Commit for Simone Bendazzoli I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 4ca028a13718783f3f880f6c77bf57d3f5594626 Signed-off-by: Simone Bendazzoli --- monai/apps/nnunet/nnunet_bundle.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 350ef1b41c..d49e5d014b 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -244,8 +244,6 @@ def forward(self, x: MetaTensor) -> MetaTensor: """ Forward pass for the nnUNet model. - :no-index: - Args: x (MetaTensor): Input tensor. If the input is a tuple, it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch. 
From b881cd3b8f5c1a18659097f2c44270bd05a7b176 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 14:08:21 +0000 Subject: [PATCH 51/67] Refactor nnUNet bundle: remove unused nnUNetTrainer and nnUNetPredictor imports, update type hint for predictor in ModelnnUNetWrapper --- monai/apps/nnunet/nnunet_bundle.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index d49e5d014b..366f90b844 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -24,8 +24,6 @@ join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") -nnUNetTrainer, _ = optional_import("nnunetv2.training.nnUNetTrainer", name="nnUNetTrainer") -nnUNetPredictor, _ = optional_import("nnunetv2.inference.predict_from_raw_data", name="nnUNetPredictor") __all__ = [ "get_nnunet_trainer", @@ -49,7 +47,7 @@ def get_nnunet_trainer( disable_checkpointing: bool = False, device: str = "cuda", pretrained_model: Optional[str] = None, -) -> Union[nnUNetTrainer, Any]: # type: ignore +) -> Any: # type: ignore """ Get the nnUNet trainer instance based on the provided configuration. The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network, @@ -172,7 +170,7 @@ class ModelnnUNetWrapper(torch.nn.Module): restoring network architecture, and setting up the predictor for inference. """ - def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], model_name: str = "model.pt"): # type: ignore + def __init__(self, predictor: object, model_folder: Union[str, Path], model_name: str = "model.pt"): # type: ignore super().__init__() self.predictor = predictor From 88a28d2195da98a6bb7bca80e4dadb6294f2a163 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 14:09:07 +0000 Subject: [PATCH 52/67] DCO Remediation Commit for Simone Bendazzoli I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: b881cd3b8f5c1a18659097f2c44270bd05a7b176 Signed-off-by: Simone Bendazzoli --- monai/apps/nnunet/nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 366f90b844..e9bb5227ed 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -235,7 +235,7 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name predictor.label_manager = plans_manager.get_label_manager(dataset_json) # type: ignore if ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")): print("Using torch.compile") - # End Block + # End Block self.network_weights = self.predictor.network # type: ignore def forward(self, x: MetaTensor) -> MetaTensor: From 782f1fd99e1fcc2fdbb7c24c0c4bec03fea3d5d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 30 Mar 2025 14:09:33 +0000 Subject: [PATCH 53/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/nnunet/nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index e9bb5227ed..366f90b844 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ 
b/monai/apps/nnunet/nnunet_bundle.py @@ -235,7 +235,7 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name predictor.label_manager = plans_manager.get_label_manager(dataset_json) # type: ignore if ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")): print("Using torch.compile") - # End Block + # End Block self.network_weights = self.predictor.network # type: ignore def forward(self, x: MetaTensor) -> MetaTensor: From 8e510e1dc2d1ff99e65a8a0ce6040ea885942e1c Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 17:47:23 +0000 Subject: [PATCH 54/67] Update docstring in get_nnunet_trainer to include link for supported trainer classes --- monai/apps/nnunet/nnunet_bundle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index e9bb5227ed..96ec8d84ff 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -82,7 +82,8 @@ def get_nnunet_trainer( fold : Union[int, str] The fold number or 'all' for cross-validation. trainer_class_name : str, optional - The class name of the trainer to be used. Default is 'nnUNetTrainer'. + The class name of the trainer to be used. Default is 'nnUNetTrainer'. + For a complete list of supported trainers, see https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants plans_identifier : str, optional Identifier for the plans to be used. Default is 'nnUNetPlans'. use_compressed_data : bool, optional From 88545579200b01513c9cc03feb0b9b3ddd333500 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:47:55 +0000 Subject: [PATCH 55/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/nnunet/nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 39f03fdd09..37fa708ce0 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -82,7 +82,7 @@ def get_nnunet_trainer( fold : Union[int, str] The fold number or 'all' for cross-validation. trainer_class_name : str, optional - The class name of the trainer to be used. Default is 'nnUNetTrainer'. + The class name of the trainer to be used. Default is 'nnUNetTrainer'. For a complete list of supported trainers, see https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants plans_identifier : str, optional Identifier for the plans to be used. Default is 'nnUNetPlans'. From 7d86a733d23ec892d29898b032722bd0be5eff15 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 17:48:39 +0000 Subject: [PATCH 56/67] DCO Remediation Commit for Simone Bendazzoli I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 8e510e1dc2d1ff99e65a8a0ce6040ea885942e1c Signed-off-by: Simone Bendazzoli --- monai/apps/nnunet/nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 39f03fdd09..40373be685 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -83,7 +83,7 @@ def get_nnunet_trainer( The fold number or 'all' for cross-validation. 
trainer_class_name : str, optional The class name of the trainer to be used. Default is 'nnUNetTrainer'. - For a complete list of supported trainers, see https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants + For a complete list of supported trainers: https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants plans_identifier : str, optional Identifier for the plans to be used. Default is 'nnUNetPlans'. use_compressed_data : bool, optional From 5422368a7d2230b1cccdb4b470bd92b1075b923e Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 17:59:27 +0000 Subject: [PATCH 57/67] Update docstring in get_nnunet_trainer for better readability of supported trainer link --- monai/apps/nnunet/nnunet_bundle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 584733da61..be3498b372 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -83,7 +83,8 @@ def get_nnunet_trainer( The fold number or 'all' for cross-validation. trainer_class_name : str, optional The class name of the trainer to be used. Default is 'nnUNetTrainer'. - For a complete list of supported trainers: https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants + For a complete list of supported trainers: + https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants plans_identifier : str, optional Identifier for the plans to be used. Default is 'nnUNetPlans'. use_compressed_data : bool, optional From aff9cbef75bb01a6a72a3e83a531ef1913fc40b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:59:53 +0000 Subject: [PATCH 58/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/nnunet/nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index be3498b372..28cfebf530 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -83,7 +83,7 @@ def get_nnunet_trainer( The fold number or 'all' for cross-validation. trainer_class_name : str, optional The class name of the trainer to be used. Default is 'nnUNetTrainer'. - For a complete list of supported trainers: + For a complete list of supported trainers: https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants plans_identifier : str, optional Identifier for the plans to be used. Default is 'nnUNetPlans'. From 5527ac8bbefc60a93804e1e376cff5aaedb8c150 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Sun, 30 Mar 2025 18:00:28 +0000 Subject: [PATCH 59/67] DCO Remediation Commit for Simone Bendazzoli I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 5422368a7d2230b1cccdb4b470bd92b1075b923e Signed-off-by: Simone Bendazzoli --- monai/apps/nnunet/nnunet_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index be3498b372..d3b506d5df 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -83,7 +83,7 @@ def get_nnunet_trainer( The fold number or 'all' for cross-validation. 
trainer_class_name : str, optional The class name of the trainer to be used. Default is 'nnUNetTrainer'. - For a complete list of supported trainers: + For a complete list of supported trainers, check: https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants plans_identifier : str, optional Identifier for the plans to be used. Default is 'nnUNetPlans'. From 7d60fd74e52867c34178e4c530a01b6d2b0b9c07 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Mon, 31 Mar 2025 12:57:32 +0000 Subject: [PATCH 60/67] nvflare support --- monai/nvflare/__init__.py | 10 + monai/nvflare/json_generator.py | 179 +++ monai/nvflare/nnunet_executor.py | 334 +++++ monai/nvflare/nvflare_generate_job_configs.py | 1085 +++++++++++++++++ monai/nvflare/nvflare_nnunet.py | 695 +++++++++++ monai/nvflare/response_processor.py | 342 ++++++ 6 files changed, 2645 insertions(+) create mode 100644 monai/nvflare/__init__.py create mode 100644 monai/nvflare/json_generator.py create mode 100644 monai/nvflare/nnunet_executor.py create mode 100644 monai/nvflare/nvflare_generate_job_configs.py create mode 100644 monai/nvflare/nvflare_nnunet.py create mode 100644 monai/nvflare/response_processor.py diff --git a/monai/nvflare/__init__.py b/monai/nvflare/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/nvflare/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/nvflare/json_generator.py b/monai/nvflare/json_generator.py new file mode 100644 index 0000000000..9326a35837 --- /dev/null +++ b/monai/nvflare/json_generator.py @@ -0,0 +1,179 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import json +import os.path + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_context import FLContext +from nvflare.widgets.widget import Widget + + +class PrepareJsonGenerator(Widget): + """ + A widget class to prepare and generate a JSON file containing data preparation configurations. + + Parameters + ---------- + results_dir : str, optional + The directory where the results will be stored (default is "prepare"). + json_file_name : str, optional + The name of the JSON file to be generated (default is "data_dict.json"). + + Methods + ------- + handle_event(event_type: str, fl_ctx: FLContext) + Handles events during the federated learning process. 
Clears the data preparation configuration + at the start of a run and saves the configuration to a JSON file at the end of a run. + """ + + def __init__(self, results_dir="prepare", json_file_name="data_dict.json"): + super(PrepareJsonGenerator, self).__init__() + + self._results_dir = results_dir + self._data_prepare_config = {} + self._json_file_name = json_file_name + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self._data_prepare_config.clear() + elif event_type == EventType.END_RUN: + self._data_prepare_config = fl_ctx.get_prop("client_data_dict", None) + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + data_prepare_res_dir = os.path.join(run_dir, self._results_dir) + if not os.path.exists(data_prepare_res_dir): + os.makedirs(data_prepare_res_dir) + + res_file_path = os.path.join(data_prepare_res_dir, self._json_file_name) + with open(res_file_path, "w") as f: + json.dump(self._data_prepare_config, f) + + +class nnUNetPackageReportJsonGenerator(Widget): + """ + A class to generate JSON reports for nnUNet package. + + Parameters + ---------- + results_dir : str, optional + Directory where the report will be saved (default is "package_report"). + json_file_name : str, optional + Name of the JSON file to save the report (default is "package_report.json"). + + Methods + ------- + handle_event(event_type: str, fl_ctx: FLContext) + Handles events to clear the report at the start of a run and save the report at the end of a run. + """ + + def __init__(self, results_dir="package_report", json_file_name="package_report.json"): + super(nnUNetPackageReportJsonGenerator, self).__init__() + + self._results_dir = results_dir + self._report = {} + self._json_file_name = json_file_name + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self._report.clear() + elif event_type == EventType.END_RUN: + datasets = fl_ctx.get_prop("package_report", None) + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + cross_val_res_dir = os.path.join(run_dir, self._results_dir) + if not os.path.exists(cross_val_res_dir): + os.makedirs(cross_val_res_dir) + + res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) + with open(res_file_path, "w") as f: + json.dump(datasets, f) + + +class nnUNetPlansJsonGenerator(Widget): + """ + A class to generate JSON files for nnUNet plans. + + Parameters + ---------- + results_dir : str, optional + Directory where the preprocessing results will be stored (default is "nnUNet_preprocessing"). + json_file_name : str, optional + Name of the JSON file to be generated (default is "nnUNetPlans.json"). + + Methods + ------- + handle_event(event_type: str, fl_ctx: FLContext) + Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves + the plans to a JSON file at the end of a run. 
+ """ + + def __init__(self, results_dir="nnUNet_preprocessing", json_file_name="nnUNetPlans.json"): + + super(nnUNetPlansJsonGenerator, self).__init__() + + self._results_dir = results_dir + self._nnUNetPlans = {} + self._json_file_name = json_file_name + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self._nnUNetPlans.clear() + elif event_type == EventType.END_RUN: + datasets = fl_ctx.get_prop("nnunet_plans", None) + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + cross_val_res_dir = os.path.join(run_dir, self._results_dir) + if not os.path.exists(cross_val_res_dir): + os.makedirs(cross_val_res_dir) + + res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) + with open(res_file_path, "w") as f: + json.dump(datasets, f) + + +class nnUNetValSummaryJsonGenerator(Widget): + """ + A widget to generate a JSON summary for nnUNet validation results. + + Parameters + ---------- + results_dir : str, optional + Directory where the nnUNet training results are stored (default is "nnUNet_train"). + json_file_name : str, optional + Name of the JSON file to save the validation summary (default is "val_summary.json"). + + Methods + ------- + handle_event(event_type: str, fl_ctx: FLContext) + Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves + the validation summary to a JSON file at the end of a run. + """ + + def __init__(self, results_dir="nnUNet_train", json_file_name="val_summary.json"): + + super(nnUNetValSummaryJsonGenerator, self).__init__() + + self._results_dir = results_dir + self._nnUNetPlans = {} + self._json_file_name = json_file_name + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self._nnUNetPlans.clear() + elif event_type == EventType.END_RUN: + datasets = fl_ctx.get_prop("val_summary_dict", None) + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + cross_val_res_dir = os.path.join(run_dir, self._results_dir) + if not os.path.exists(cross_val_res_dir): + os.makedirs(cross_val_res_dir) + + res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) + with open(res_file_path, "w") as f: + json.dump(datasets, f) diff --git a/monai/nvflare/nnunet_executor.py b/monai/nvflare/nnunet_executor.py new file mode 100644 index 0000000000..c00d2245aa --- /dev/null +++ b/monai/nvflare/nnunet_executor.py @@ -0,0 +1,334 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import subprocess +import sys +from pathlib import Path + +from nvflare.apis.dxo import DXO, DataKind +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal + +from monai.nvflare.nvflare_nnunet import ( # check_host_config, + check_packages, + plan_and_preprocess, + prepare_bundle, + prepare_data_folder, + preprocess, + train, +) + + +class nnUNetExecutor(Executor): + """ + nnUNetExecutor is a class that handles the execution of various tasks related to nnUNet training and preprocessing + within the NVFlare framework. + + Parameters + ---------- + data_dir : str, optional + Directory where the data is stored. + modality_dict : dict, optional + Dictionary containing modality information. + prepare_task_name : str, optional + Name of the task for preparing the dataset. + check_client_packages_task_name : str, optional + Name of the task for checking client packages. + plan_and_preprocess_task_name : str, optional + Name of the task for planning and preprocessing. + preprocess_task_name : str, optional + Name of the task for preprocessing. + training_task_name : str, optional + Name of the task for training. + prepare_bundle_name : str, optional + Name of the task for preparing the bundle. + subfolder_suffix : str, optional + Suffix for subfolders. + dataset_format : str, optional + Format of the dataset, default is "subfolders". + patient_id_in_file_identifier : bool, optional + Whether patient ID is in file identifier, default is True. + nnunet_config : dict, optional + Configuration dictionary for nnUNet. + nnunet_root_folder : str, optional + Root folder for nnUNet. + client_name : str, optional + Name of the client. + tracking_uri : str, optional + URI for tracking. + mlflow_token : str, optional + Token for MLflow. + bundle_root : str, optional + Root directory for the bundle. + train_extra_configs : dict, optional + Extra configurations for training. + exclude_vars : list, optional + List of variables to exclude. + modality_list : list, optional + List of modalities. + + Methods + ------- + handle_event(event_type: str, fl_ctx: FLContext) + Handles events triggered during the federated learning process. + initialize(fl_ctx: FLContext) + Initializes the executor with the given federated learning context. + execute(task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable + Executes the specified task. + prepare_dataset() -> Shareable + Prepares the dataset for training. + check_packages_installed() -> Shareable + Checks if the required packages are installed. + plan_and_preprocess() -> Shareable + Plans and preprocesses the dataset. + preprocess() -> Shareable + Preprocesses the dataset. + train() -> Shareable + Trains the model. + prepare_bundle() -> Shareable + Prepares the bundle for deployment. 
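    A hedged sketch of how this executor might be instantiated on a client site, using only the
    constructor arguments documented above (every value is a placeholder, not taken from this patch):

    ```python
    # Illustrative only; all argument values below are hypothetical placeholders.
    from monai.nvflare.nnunet_executor import nnUNetExecutor

    executor = nnUNetExecutor(
        data_dir="/data/site-1",                          # client-local data folder (assumed layout)
        nnunet_root_folder="/workspace/nnunet_root",      # assumed nnUNet working directory
        modality_dict={"image": "_T2.nii.gz", "label": "_seg.nii.gz"},  # assumed filename-suffix mapping
        dataset_format="subfolders",
        nnunet_config={"dataset_name_or_id": "009", "experiment_name": "prostate_fl"},
        client_name="site-1",
        tracking_uri="http://mlflow.example.org:5000",    # hypothetical MLflow tracking server
        bundle_root="/workspace/bundles/prostate_fl",     # hypothetical MONAI bundle root
    )
    ```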
+ """ + + def __init__( + self, + data_dir=None, + modality_dict=None, + prepare_task_name="prepare", + check_client_packages_task_name="check_client_packages", + plan_and_preprocess_task_name="plan_and_preprocess", + preprocess_task_name="preprocess", + training_task_name="train", + prepare_bundle_name="prepare_bundle", + subfolder_suffix=None, + dataset_format="subfolders", + patient_id_in_file_identifier=True, + nnunet_config=None, + nnunet_root_folder=None, + client_name=None, + tracking_uri=None, + mlflow_token=None, + bundle_root=None, + modality_list=None, + train_extra_configs=None, + exclude_vars=None, + ): + super().__init__() + + self.exclude_vars = exclude_vars + self.prepare_task_name = prepare_task_name + self.data_dir = data_dir + self.subfolder_suffix = subfolder_suffix + self.patient_id_in_file_identifier = patient_id_in_file_identifier + self.dataset_format = dataset_format + self.modality_dict = modality_dict + self.nnunet_config = nnunet_config + self.nnunet_root_folder = nnunet_root_folder + self.client_name = client_name + self.tracking_uri = tracking_uri + self.mlflow_token = mlflow_token + self.check_client_packages_task_name = check_client_packages_task_name + self.plan_and_preprocess_task_name = plan_and_preprocess_task_name + self.preprocess_task_name = preprocess_task_name + self.training_task_name = training_task_name + self.prepare_bundle_name = prepare_bundle_name + self.bundle_root = bundle_root + self.train_extra_configs = train_extra_configs + self.modality_list = modality_list + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self.initialize(fl_ctx) + + def initialize(self, fl_ctx: FLContext): + self.run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + self.root_dir = fl_ctx.get_engine().get_workspace().root_dir + self.custom_app_dir = fl_ctx.get_engine().get_workspace().get_app_custom_dir(fl_ctx.get_job_id()) + + with open("init_logfile_out.log", "w") as f_o: + with open("init_logfile_err.log", "w") as f_e: + subprocess.call( + [ + sys.executable, + "-m", + "pip", + "install", + "--user", + "-r", + str(Path(self.custom_app_dir).joinpath("requirements.txt")), + ], + stdout=f_o, + stderr=f_e, + ) + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + self.run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) + self.root_dir = fl_ctx.get_engine().get_workspace().root_dir + self.custom_app_dir = fl_ctx.get_engine().get_workspace().get_app_custom_dir(fl_ctx.get_job_id()) + try: + if task_name == self.prepare_task_name: + return self.prepare_dataset() + elif task_name == self.check_client_packages_task_name: + return self.check_packages_installed() + elif task_name == self.plan_and_preprocess_task_name: + return self.plan_and_preprocess() + elif task_name == self.preprocess_task_name: + return self.preprocess() + elif task_name == self.training_task_name: + return self.train() + elif task_name == self.prepare_bundle_name: + return self.prepare_bundle() + else: + return make_reply(ReturnCode.TASK_UNKNOWN) + except Exception as e: + self.log_exception(fl_ctx, f"Exception in simple trainer: {e}.") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + def prepare_dataset(self) -> Shareable: + if "nnunet_trainer" not in self.nnunet_config: + nnunet_trainer_name = "nnUNetTrainer" + else: + nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] + + data_list = prepare_data_folder( + 
data_dir=self.data_dir, + nnunet_root_dir=self.nnunet_root_folder, + dataset_name_or_id=self.nnunet_config["dataset_name_or_id"], + modality_dict=self.modality_dict, + experiment_name=self.nnunet_config["experiment_name"], + client_name=self.client_name, + dataset_format=self.dataset_format, + patient_id_in_file_identifier=self.patient_id_in_file_identifier, + tracking_uri=self.tracking_uri, + mlflow_token=self.mlflow_token, + subfolder_suffix=self.subfolder_suffix, + trainer_class_name=nnunet_trainer_name, + modality_list=self.modality_list, + ) + + outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=data_list, meta={}) + return outgoing_dxo.to_shareable() + + def check_packages_installed(self): + packages = [ + "nvflare", + # {"package_name":'pymaia-learn',"import_name":"PyMAIA"}, + "torch", + "monai", + "numpy", + "nnunetv2", + ] + package_report = check_packages(packages) + + # host_config = check_host_config() + # package_report.update(host_config) + + outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=package_report, meta={}) + + return outgoing_dxo.to_shareable() + + def plan_and_preprocess(self): + if "nnunet_plans" not in self.nnunet_config: + nnunet_plans_name = "nnUNetPlans" + else: + nnunet_plans_name = self.nnunet_config["nnunet_plans"] + + if "nnunet_trainer" not in self.nnunet_config: + nnunet_trainer_name = "nnUNetTrainer" + else: + nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] + + nnunet_plans = plan_and_preprocess( + self.nnunet_root_folder, + self.nnunet_config["dataset_name_or_id"], + self.client_name, + self.nnunet_config["experiment_name"], + self.tracking_uri, + nnunet_plans_name=nnunet_plans_name, + trainer_class_name=nnunet_trainer_name, + ) + + outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=nnunet_plans, meta={}) + return outgoing_dxo.to_shareable() + + def preprocess(self): + if "nnunet_plans" not in self.nnunet_config: + nnunet_plans_name = "nnUNetPlans" + else: + nnunet_plans_name = self.nnunet_config["nnunet_plans"] + + if "nnunet_trainer" not in self.nnunet_config: + nnunet_trainer_name = "nnUNetTrainer" + else: + nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] + + nnunet_plans = preprocess( + self.nnunet_root_folder, + self.nnunet_config["dataset_name_or_id"], + nnunet_plans_file_path=Path(self.custom_app_dir).joinpath(f"{nnunet_plans_name}.json"), + trainer_class_name=nnunet_trainer_name, + ) + outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=nnunet_plans, meta={}) + return outgoing_dxo.to_shareable() + + def train(self): + if "nnunet_trainer" not in self.nnunet_config: + nnunet_trainer_name = "nnUNetTrainer" + else: + nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] + + if "nnunet_plans" not in self.nnunet_config: + nnunet_plans_name = "nnUNetPlans" + else: + nnunet_plans_name = self.nnunet_config["nnunet_plans"] + + validation_summary = train( + self.nnunet_root_folder, + trainer_class_name=nnunet_trainer_name, + fold=0, + experiment_name=self.nnunet_config["experiment_name"], + client_name=self.client_name, + tracking_uri=self.tracking_uri, + nnunet_plans_name=nnunet_plans_name, + dataset_name_or_id=self.nnunet_config["dataset_name_or_id"], + run_with_bundle=True if self.bundle_root is not None else False, + bundle_root=self.bundle_root, + ) + outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=validation_summary, meta={}) + return outgoing_dxo.to_shareable() + + def prepare_bundle(self): + if "nnunet_trainer" not in self.nnunet_config: + nnunet_trainer_name = "nnUNetTrainer" + else: + 
nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] + + if "nnunet_plans" not in self.nnunet_config: + nnunet_plans_name = "nnUNetPlans" + else: + nnunet_plans_name = self.nnunet_config["nnunet_plans"] + + bundle_config = { + "bundle_root": self.bundle_root, + "tracking_uri": self.tracking_uri, + "mlflow_experiment_name": "FedLearning-" + self.nnunet_config["experiment_name"], + "mlflow_run_name": self.client_name, + "nnunet_plans_identifier": nnunet_plans_name, + "nnunet_trainer_class_name": nnunet_trainer_name, + } + + prepare_bundle(bundle_config, self.train_extra_configs) + + return make_reply(ReturnCode.OK) diff --git a/monai/nvflare/nvflare_generate_job_configs.py b/monai/nvflare/nvflare_generate_job_configs.py new file mode 100644 index 0000000000..130f47e309 --- /dev/null +++ b/monai/nvflare/nvflare_generate_job_configs.py @@ -0,0 +1,1085 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import subprocess +from pathlib import Path + +import yaml +from pyhocon import ConfigFactory +from pyhocon.converter import HOCONConverter + + +def prepare_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Prepare configuration files for nnUNet dataset preparation using NVFlare. + + Parameters + ---------- + clients : dict + Dictionary containing client-specific configurations. Each key is a client ID and the value is a dictionary + with the following keys: + - "data_dir": str, path to the client's data directory. + - "patient_id_in_file_identifier": str, identifier for patient ID in file. + - "modality_dict": dict, dictionary mapping modalities. + - "dataset_format": str, format of the dataset. + - "nnunet_root_folder": str, path to the nnUNet root folder. + - "client_name": str, name of the client. + - "subfolder_suffix": str, optional, suffix for subfolders. + experiment : dict + Dictionary containing experiment-specific configurations with the following keys: + - "dataset_name_or_id": str, name or ID of the dataset. + - "experiment_name": str, name of the experiment. + - "tracking_uri": str, URI for tracking. + - "mlflow_token": str, optional, token for MLflow. + root_dir : str + Root directory where the configuration files will be generated. + script_dir : str + Directory containing the scripts. + nvflare_exec : str + Path to the NVFlare executable. 
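+
+    For illustration, a single-client ``clients`` mapping and an ``experiment`` dictionary could
+    look like the following (all values are placeholders):
+
+    ```python
+    clients = {
+        "site-1": {  # placeholder client configuration, for illustration only
+            "data_dir": "/data/site-1",
+            "patient_id_in_file_identifier": True,
+            "modality_dict": {"image": ".nii.gz", "label": "_seg.nii.gz"},
+            "dataset_format": "subfolders",
+            "nnunet_root_folder": "/workspace/nnunet",
+            "client_name": "site-1",
+        }
+    }
+    experiment = {
+        "dataset_name_or_id": "009",
+        "experiment_name": "prostate_fl",
+        "tracking_uri": "http://127.0.0.1:5000",
+    }
+    ```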
+ + Returns + ------- + None + """ + task_name = "prepare" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = {"description": "Prepare nnUNet Dataset", "client_category": "Executor", "controller_type": "server"} + + meta = { + "name": f"{task_name}_nnUNet", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPrepareProcessor", "args": {}}, + {"id": "json_generator", "path": "monai.nvflare.json_generator.PrepareJsonGenerator", "args": {}}, + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 6000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + { + "tasks": [task_name], + "executor": { + "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", + "args": { + "data_dir": clients[client_id]["data_dir"], + "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], + "modality_dict": clients[client_id]["modality_dict"], + "dataset_format": clients[client_id]["dataset_format"], + "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], + "nnunet_config": { + "dataset_name_or_id": experiment["dataset_name_or_id"], + "experiment_name": experiment["experiment_name"], + }, + "client_name": clients[client_id]["client_name"], + "tracking_uri": experiment["tracking_uri"], + }, + }, + } + ], + } + + if "modality_list" in experiment: + client["executors"][0]["executor"]["args"]["modality_list"] = experiment["modality_list"] + + if "subfolder_suffix" in clients[client_id]: + client["executors"][0]["executor"]["args"]["subfolder_suffix"] = clients[client_id]["subfolder_suffix"] + if "mlflow_token" in experiment: + client["executors"][0]["executor"]["args"]["mlflow_token"] = experiment["mlflow_token"] + + if "nnunet_plans" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] + + if "nnunet_trainer" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + 
Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def check_client_packages_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Generate job configuration files for checking client packages in an NVFlare experiment. + + Parameters + ---------- + clients : dict + A dictionary where keys are client IDs and values are client details. + experiment : str + The name of the experiment. + root_dir : str + The root directory where the configuration files will be generated. + script_dir : str + The directory containing the necessary scripts for NVFlare. + nvflare_exec : str + The NVFlare executable path. + + Returns + ------- + None + """ + task_name = "check_client_packages" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = { + "description": "Check Python Packages and Report", + "client_category": "Executor", + "controller_type": "server", + } + + meta = { + "name": f"{task_name}", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "nnunet_processor", + "path": "monai.nvflare.response_processor.nnUNetPackageReportProcessor", + "args": {}, + }, + { + "id": "json_generator", + "path": "monai.nvflare.json_generator.nnUNetPackageReportJsonGenerator", + "args": {}, + }, + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 6000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + {"tasks": [task_name], "executor": {"path": "monai.nvflare.nnunet_executor.nnUNetExecutor", "args": {}}} + ], + } + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) 
+ f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def plan_and_preprocess_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Generates and writes configuration files for the plan and preprocess task in the nnUNet experiment. + + Parameters + ---------- + clients : dict + A dictionary containing client-specific configurations. Each key is a client ID, and the value is + another dictionary with client-specific settings. + experiment : dict + A dictionary containing experiment-specific configurations such as dataset name, experiment name, + tracking URI, and optional nnUNet plans and trainer. + root_dir : str + The root directory where the configuration files will be generated. + script_dir : str + The directory containing the scripts to be used in the NVFlare job. + nvflare_exec : str + The path to the NVFlare executable. + + Returns + ------- + None + """ + task_name = "plan_and_preprocess" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = {"description": "Plan and Preprocess nnUNet", "client_category": "Executor", "controller_type": "server"} + + meta = { + "name": f"{task_name}_nnUNet", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPlanProcessor", "args": {}}, + {"id": "json_generator", "path": "monai.nvflare.json_generator.nnUNetPlansJsonGenerator", "args": {}}, + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 6000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + { + "tasks": [task_name], + "executor": { + "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", + "args": { + "data_dir": clients[client_id]["data_dir"], + "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], + "modality_dict": clients[client_id]["modality_dict"], + "dataset_format": clients[client_id]["dataset_format"], + "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], + "nnunet_config": { + "dataset_name_or_id": 
experiment["dataset_name_or_id"], + "experiment_name": experiment["experiment_name"], + }, + "client_name": clients[client_id]["client_name"], + "tracking_uri": experiment["tracking_uri"], + }, + }, + } + ], + } + + if "nnunet_plans" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] + + if "nnunet_trainer" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def preprocess_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Generate job configuration files for the preprocessing task in NVFlare. + + Parameters + ---------- + clients : dict + A dictionary containing client-specific configurations. Each key is a client ID, and the value is a dictionary + with the following keys: + - 'data_dir': str, path to the client's data directory. + - 'patient_id_in_file_identifier': str, identifier for patient ID in the file. + - 'modality_dict': dict, dictionary mapping modalities. + - 'dataset_format': str, format of the dataset. + - 'nnunet_root_folder': str, root folder for nnUNet. + - 'client_name': str, name of the client. + experiment : dict + A dictionary containing experiment-specific configurations with the following keys: + - 'dataset_name_or_id': str, name or ID of the dataset. + - 'experiment_name': str, name of the experiment. + - 'tracking_uri': str, URI for tracking. + - 'nnunet_plans' (optional): str, nnUNet plans. + - 'nnunet_trainer' (optional): str, nnUNet trainer. + root_dir : str + The root directory where the configuration files will be generated. + script_dir : str + The directory containing the scripts to be used in the job. + nvflare_exec : str + The NVFlare executable to be used for creating the job. 
+ + Returns + ------- + None + """ + task_name = "preprocess" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = {"description": "Preprocess nnUNet", "client_category": "Executor", "controller_type": "server"} + + meta = { + "name": f"{task_name}_nnUNet", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPlanProcessor", "args": {}} + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 6000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + { + "tasks": [task_name], + "executor": { + "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", + "args": { + "data_dir": clients[client_id]["data_dir"], + "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], + "modality_dict": clients[client_id]["modality_dict"], + "dataset_format": clients[client_id]["dataset_format"], + "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], + "nnunet_config": { + "dataset_name_or_id": experiment["dataset_name_or_id"], + "experiment_name": experiment["experiment_name"], + }, + "client_name": clients[client_id]["client_name"], + "tracking_uri": experiment["tracking_uri"], + }, + }, + } + ], + } + + if "nnunet_plans" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] + + if "nnunet_trainer" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def train_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Generate 
training configuration files for nnUNet using NVFlare. + + Parameters + ---------- + clients : dict + Dictionary containing client-specific configurations. Each key is a client ID, and the value is a dictionary + with the following keys: + - 'data_dir': str, path to the client's data directory. + - 'patient_id_in_file_identifier': str, identifier for patient ID in file. + - 'modality_dict': dict, dictionary mapping modalities. + - 'dataset_format': str, format of the dataset. + - 'nnunet_root_folder': str, path to the nnUNet root folder. + - 'client_name': str, name of the client. + - 'bundle_root': str, optional, path to the bundle root directory. + experiment : dict + Dictionary containing experiment-specific configurations with the following keys: + - 'dataset_name_or_id': str, name or ID of the dataset. + - 'experiment_name': str, name of the experiment. + - 'tracking_uri': str, URI for tracking. + - 'nnunet_plans': str, optional, nnUNet plans. + - 'nnunet_trainer': str, optional, nnUNet trainer. + root_dir : str + Root directory where the configuration files will be generated. + script_dir : str + Directory containing the scripts to be used. + nvflare_exec : str + Path to the NVFlare executable. + + Returns + ------- + None + """ + task_name = "train" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = {"description": "Train nnUNet", "client_category": "Executor", "controller_type": "server"} + + meta = { + "name": f"{task_name}_nnUNet", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetTrainProcessor", "args": {}}, + {"id": "json_generator", "path": "monai.nvflare.json_generator.nnUNetValSummaryJsonGenerator", "args": {}}, + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 6000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + { + "tasks": [task_name], + "executor": { + "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", + "args": { + "data_dir": clients[client_id]["data_dir"], + "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], + "modality_dict": clients[client_id]["modality_dict"], + "dataset_format": 
clients[client_id]["dataset_format"], + "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], + "nnunet_config": { + "dataset_name_or_id": experiment["dataset_name_or_id"], + "experiment_name": experiment["experiment_name"], + }, + "client_name": clients[client_id]["client_name"], + "tracking_uri": experiment["tracking_uri"], + }, + }, + } + ], + } + + if "nnunet_plans" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] + + if "nnunet_trainer" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] + + if "bundle_root" in clients[client_id]: + client["executors"][0]["executor"]["args"]["bundle_root"] = clients[client_id]["bundle_root"] + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def prepare_bundle_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Prepare the configuration files for the nnUNet bundle and generate the job configurations for NVFlare. + + Parameters + ---------- + clients : dict + A dictionary containing client information. Keys are client IDs and values are dictionaries with client details. + experiment : dict + A dictionary containing experiment details such as 'experiment_name', 'tracking_uri', and optional + configurations like 'bundle_extra_config', 'nnunet_plans', and 'nnunet_trainer'. + root_dir : str + The root directory where the configuration files and job directories will be created. + script_dir : str + The directory containing the necessary scripts for NVFlare. + nvflare_exec : str + The path to the NVFlare executable. 
+ + Returns + ------- + None + """ + task_name = "prepare_bundle" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = {"description": "Prepare nnUNet Bundle", "client_category": "Executor", "controller_type": "server"} + + meta = { + "name": f"{task_name}_nnUNet", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": 1, + "mandatory_clients": list(clients.keys()), + } + for client_id in clients: + meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "server": {"heart_beat_timeout": 600}, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "nnunet_processor", + "path": "monai.nvflare.response_processor.nnUNetBundlePrepareProcessor", + "args": {}, + } + ], + "workflows": [ + { + "id": "broadcast_and_process", + "name": "BroadcastAndProcess", + "args": { + "processor": "nnunet_processor", + "min_responses_required": 0, + "wait_time_after_min_received": 10, + "task_name": task_name, + "timeout": 600000, + }, + } + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_id in clients: + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "components": [], + "executors": [ + { + "tasks": [task_name], + "executor": { + "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", + "args": { + "nnunet_config": {"experiment_name": experiment["experiment_name"]}, + "client_name": clients[client_id]["client_name"], + "tracking_uri": experiment["tracking_uri"], + }, + }, + } + ], + } + + if "bundle_extra_config" in experiment: + client["executors"][0]["executor"]["args"]["train_extra_configs"] = experiment["bundle_extra_config"] + if "nnunet_plans" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] + + if "nnunet_trainer" in experiment: + client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] + + if "bundle_root" in clients[client_id]: + client["executors"][0]["executor"]["args"]["bundle_root"] = clients[client_id]["bundle_root"] + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( + parents=True, exist_ok=True + ) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), + "w", + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def train_fl_config(clients, experiment, root_dir, script_dir, nvflare_exec): + """ + Generate federated learning job configurations for NVFlare. 
+ + Parameters + ---------- + clients : dict + Dictionary containing client names and their configurations. + experiment : dict + Dictionary containing experiment parameters such as number of rounds and local epochs. + root_dir : str + Root directory where the job configurations will be saved. + script_dir : str + Directory containing the necessary scripts for NVFlare. + nvflare_exec : str + Path to the NVFlare executable. + + Returns + ------- + None + """ + task_name = "train_fl_nnunet_bundle" + Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) + + info = { + "description": "Federated Learning with nnUNet-MONAI Bundle", + "client_category": "Executor", + "controller_type": "server", + } + + meta = { + "name": f"{task_name}", + "resource_spec": {}, + "deploy_map": {f"{task_name}-server": ["server"]}, + "min_clients": len(list(clients.keys())), + "mandatory_clients": list(clients.keys()), + } + + for client_name, client_config in clients.items(): + meta["deploy_map"][f"{task_name}-{client_name}"] = [client_name] + + with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) + f.write("\n}") + + with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) + f.write("\n}") + + server = { + "format_version": 2, + "min_clients": len(list(clients.keys())), + "num_rounds": experiment["num_rounds"], + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "persistor", + "path": "monai_nvflare.monai_bundle_persistor.MonaiBundlePersistor", + "args": { + "bundle_root": experiment["server_bundle_root"], + "config_train_filename": "configs/train.yaml", + "network_def_key": "network_def_fl", + }, + }, + {"id": "shareable_generator", "name": "FullModelShareableGenerator", "args": {}}, + { + "id": "aggregator", + "name": "InTimeAccumulateWeightedAggregator", + "args": {"expected_data_kind": "WEIGHT_DIFF"}, + }, + {"id": "model_selector", "name": "IntimeModelSelector", "args": {}}, + {"id": "model_locator", "name": "PTFileModelLocator", "args": {"pt_persistor_id": "persistor"}}, + {"id": "json_generator", "name": "ValidationJsonGenerator", "args": {}}, + ], + "workflows": [ + { + "id": "scatter_gather_ctl", + "name": "ScatterAndGather", + "args": { + "min_clients": "{min_clients}", + "num_rounds": "{num_rounds}", + "start_round": experiment["start_round"], + "wait_time_after_min_received": 10, + "aggregator_id": "aggregator", + "persistor_id": "persistor", + "shareable_generator_id": "shareable_generator", + "train_task_name": "train", + "train_timeout": 0, + }, + }, + { + "id": "cross_site_model_eval", + "name": "CrossSiteModelEval", + "args": { + "model_locator_id": "model_locator", + "submit_model_timeout": 600, + "validation_timeout": 6000, + "cleanup_models": True, + }, + }, + ], + } + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) + with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) + f.write("\n}") + + for client_name, client_config in clients.items(): + client = { + "format_version": 2, + "task_result_filters": [], + "task_data_filters": [], + "executors": [ + { + "tasks": ["train", "submit_model", "validate"], + "executor": { + "id": "executor", + 
# "path": "monai_algo.ClientnnUNetAlgoExecutor", + "path": "monai_nvflare.client_algo_executor.ClientAlgoExecutor", + "args": {"client_algo_id": "client_algo", "key_metric": "Val_Dice"}, + }, + } + ], + "components": [ + { + "id": "client_algo", + # "path": "monai_algo.MonaiAlgonnUNet", + "path": "monai.fl.client.monai_algo.MonaiAlgo", + "args": { + "bundle_root": client_config["bundle_root"], + "config_train_filename": "configs/train.yaml", + "save_dict_key": "network_weights", + "local_epochs": experiment["local_epochs"], + "train_kwargs": {"nnunet_root_folder": client_config["nnunet_root_folder"]}, + }, + } + ], + } + + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-{client_name}").mkdir(parents=True, exist_ok=True) + with open( + Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-{client_name}", "config_fed_client.conf"), "w" + ) as f: + f.write("{\n") + f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) + f.write("\n}") + + subprocess.run( + [ + nvflare_exec, + "job", + "create", + "-j", + Path(root_dir).joinpath("jobs", task_name), + "-w", + Path(root_dir).joinpath(task_name), + "-sd", + script_dir, + "--force", + ] + ) + + +def generate_configs(client_files, experiment_file, script_dir, job_dir, nvflare_exec="nvflare"): + """ + Generate configuration files for NVFlare job. + + Parameters + ---------- + client_files : list of str + List of file paths to client configuration files. + experiment_file : str + File path to the experiment configuration file. + script_dir : str + Directory path where the scripts are located. + job_dir : str + Directory path where the job configurations will be saved. + nvflare_exec : str, optional + NVFlare executable command, by default "nvflare". + + Returns + ------- + None + """ + clients = {} + for client_id in client_files: + with open(client_id) as f: + client_name = Path(client_id).name + clients[client_name.split(".")[0]] = yaml.safe_load(f) + + with open(experiment_file) as f: + experiment = yaml.safe_load(f) + + check_client_packages_config(clients, experiment, job_dir, script_dir, nvflare_exec) + prepare_config(clients, experiment, job_dir, script_dir, nvflare_exec) + plan_and_preprocess_config(clients, experiment, job_dir, script_dir, nvflare_exec) + preprocess_config(clients, experiment, job_dir, script_dir, nvflare_exec) + train_config(clients, experiment, job_dir, script_dir, nvflare_exec) + prepare_bundle_config(clients, experiment, job_dir, script_dir, nvflare_exec) + train_fl_config(clients, experiment, job_dir, script_dir, nvflare_exec) diff --git a/monai/nvflare/nvflare_nnunet.py b/monai/nvflare/nvflare_nnunet.py new file mode 100644 index 0000000000..72dc062ccd --- /dev/null +++ b/monai/nvflare/nvflare_nnunet.py @@ -0,0 +1,695 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import json +import logging +import multiprocessing +import os +import pathlib +import random +import re +import shutil +import subprocess +from importlib.metadata import version +from pathlib import Path + +import mlflow +import numpy as np +import pandas as pd +import psutil +import yaml + +import monai +from monai.apps.nnunet import nnUNetV2Runner +from monai.apps.nnunet.nnunet_bundle import convert_monai_bundle_to_nnunet +from monai.bundle import ConfigParser + + +def train( + nnunet_root_dir, + experiment_name, + client_name, + tracking_uri, + dataset_name_or_id, + trainer_class_name="nnUNetTrainer", + nnunet_plans_name="nnUNetPlans", + run_with_bundle=False, + fold=0, + bundle_root=None, + mlflow_token=None, +): + """ + + Train a nnUNet model and log metrics to MLflow. + + Parameters + ---------- + nnunet_root_dir : str + Root directory for nnUNet. + experiment_name : str + Name of the MLflow experiment. + client_name : str + Name of the client. + tracking_uri : str + URI for MLflow tracking server. + dataset_name_or_id : str + Name or ID of the dataset. + trainer_class_name : str, optional + Name of the nnUNet trainer class, by default "nnUNetTrainer". + nnunet_plans_name : str, optional + Name of the nnUNet plans, by default "nnUNetPlans". + run_with_bundle : bool, optional + Whether to run with MONAI bundle, by default False. + fold : int, optional + Fold number for cross-validation, by default 0. + bundle_root : str, optional + Root directory for MONAI bundle, by default None. + mlflow_token : str, optional + Token for MLflow authentication, by default None. + + Returns + ------- + dict + Dictionary containing validation summary metrics. + """ + data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") + runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) + + if not run_with_bundle: + runner.train_single_model(config="3d_fullres", fold=fold) + else: + os.environ["BUNDLE_ROOT"] = bundle_root + os.environ["PYTHONPATH"] = os.environ["PYTHONPATH"] + ":" + bundle_root + monai.bundle.run( + config_file=Path(bundle_root).joinpath("configs/train.yaml"), + bundle_root=bundle_root, + nnunet_trainer_class_name=trainer_class_name, + mlflow_experiment_name=experiment_name, + mlflow_run_name="run_" + client_name, + tracking_uri=tracking_uri, + fold_id=fold, + ) + nnunet_config = {"dataset_name_or_id": dataset_name_or_id, "nnunet_trainer": trainer_class_name} + convert_monai_bundle_to_nnunet(nnunet_config, bundle_root) + runner.train_single_model(config="3d_fullres", fold=fold, val="") + + if mlflow_token is not None: + os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token + if tracking_uri is not None: + mlflow.set_tracking_uri(tracking_uri) + + try: + mlflow.create_experiment(experiment_name) + except Exception as e: + print(e) + mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) + + filter = f""" + tags."client" = "{client_name}" + """ + + runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) + + validation_summary = os.path.join( + runner.nnunet_results, + runner.dataset_name, + f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", + f"fold_{fold}", + "validation", + "summary.json", + ) + + dataset_file = os.path.join( + runner.nnunet_results, + runner.dataset_name, + f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", + "dataset.json", + ) + + with open(dataset_file, 
"r") as f: + dataset_dict = json.load(f) + labels = dataset_dict["labels"] + labels = {str(v): k for k, v in labels.items()} + + with open(validation_summary, "r") as f: + validation_summary_dict = json.load(f) + + if len(runs) == 0: + with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): + for label in validation_summary_dict["mean"]: + for metric in validation_summary_dict["mean"][label]: + label_name = labels[label] + mlflow.log_metric(f"{label_name}_{metric}", float(validation_summary_dict["mean"][label][metric])) + + else: + with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): + for label in validation_summary_dict["mean"]: + for metric in validation_summary_dict["mean"][label]: + label_name = labels[label] + mlflow.log_metric(f"{label_name}_{metric}", float(validation_summary_dict["mean"][label][metric])) + + return validation_summary_dict + + +def preprocess(nnunet_root_dir, dataset_name_or_id, nnunet_plans_file_path=None, trainer_class_name="nnUNetTrainer"): + """ + Preprocess the dataset for nnUNet training. + + Parameters + ---------- + nnunet_root_dir : str + The root directory of the nnUNet project. + dataset_name_or_id : str or int + The name or ID of the dataset to preprocess. + nnunet_plans_file_path : Path, optional + The file path to the nnUNet plans file. If None, default plans will be used. Default is None. + trainer_class_name : str, optional + The name of the trainer class to use. Default is "nnUNetTrainer". + + Returns + ------- + dict + The nnUNet plans dictionary. + """ + + data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") + runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) + + nnunet_plans_name = nnunet_plans_file_path.name.split(".")[0] + from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + + dataset_name = maybe_convert_to_dataset_name(int(dataset_name_or_id)) + + Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name).mkdir(parents=True, exist_ok=True) + + shutil.copy( + Path(nnunet_root_dir).joinpath("nnUNet_raw_data_base", dataset_name, "dataset.json"), + Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name, "dataset.json"), + ) + if nnunet_plans_file_path is not None: + with open(nnunet_plans_file_path, "r") as f: + nnunet_plans = json.load(f) + nnunet_plans["original_dataset_name"] = nnunet_plans["dataset_name"] + nnunet_plans["dataset_name"] = dataset_name + json.dump( + nnunet_plans, + open( + Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name, f"{nnunet_plans_name}.json"), + "w", + ), + indent=4, + ) + + runner.extract_fingerprints(npfp=2, verify_dataset_integrity=True) + runner.preprocess(c=["3d_fullres"], n_proc=[2], overwrite_plans_name=nnunet_plans_name) + + return nnunet_plans + + +def plan_and_preprocess( + nnunet_root_dir, + dataset_name_or_id, + client_name, + experiment_name, + tracking_uri, + mlflow_token=None, + nnunet_plans_name="nnUNetPlans", + trainer_class_name="nnUNetTrainer", +): + """ + Plan and preprocess the dataset using nnUNetV2Runner and log the plans to MLflow. + + Parameters + ---------- + nnunet_root_dir : str + The root directory of nnUNet. + dataset_name_or_id : str or int + The name or ID of the dataset to be processed. + client_name : str + The name of the client. + experiment_name : str + The name of the MLflow experiment. + tracking_uri : str + The URI of the MLflow tracking server. 
+ mlflow_token : str, optional + The token for MLflow authentication (default is None). + nnunet_plans_name : str, optional + The name of the nnUNet plans (default is "nnUNetPlans"). + trainer_class_name : str, optional + The name of the nnUNet trainer class (default is "nnUNetTrainer"). + + Returns + ------- + dict + The nnUNet plans as a dictionary. + """ + + data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") + + runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) + + runner.plan_and_process( + npfp=2, verify_dataset_integrity=True, c=["3d_fullres"], n_proc=[2], overwrite_plans_name=nnunet_plans_name + ) + + preprocessed_folder = runner.nnunet_preprocessed + + from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + + dataset_name = maybe_convert_to_dataset_name(int(dataset_name_or_id)) + + with open(Path(preprocessed_folder).joinpath(f"{dataset_name}", nnunet_plans_name + ".json"), "r") as f: + nnunet_plans = json.load(f) + + if mlflow_token is not None: + os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token + if tracking_uri is not None: + mlflow.set_tracking_uri(tracking_uri) + + try: + mlflow.create_experiment(experiment_name) + except Exception as e: + print(e) + mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) + + filter = f""" + tags."client" = "{client_name}" + """ + + runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) + + if len(runs) == 0: + with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): + mlflow.log_dict(nnunet_plans, nnunet_plans_name + ".json") + + else: + with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): + mlflow.log_dict(nnunet_plans, nnunet_plans_name + ".json") + + return nnunet_plans + + +def prepare_data_folder( + data_dir, + nnunet_root_dir, + dataset_name_or_id, + modality_dict, + experiment_name, + client_name, + dataset_format, + modality_list = None, + tracking_uri=None, + mlflow_token=None, + subfolder_suffix=None, + patient_id_in_file_identifier=True, + trainer_class_name="nnUNetTrainer", +): + """ + Prepare the data folder for nnUNet training and log the data to MLflow. + + Parameters + ---------- + data_dir : str + Directory containing the dataset. + nnunet_root_dir : str + Root directory for nnUNet. + dataset_name_or_id : str + Name or ID of the dataset. + modality_dict : dict + Dictionary mapping modality IDs to file suffixes. + experiment_name : str + Name of the MLflow experiment. + client_name : str + Name of the client. + dataset_format : str + Format of the dataset. Supported formats are "subfolders", "decathlon", and "nnunet". + tracking_uri : str, optional + URI for MLflow tracking server. + modality_list : list, optional + List of modalities. Default is None. + mlflow_token : str, optional + Token for MLflow authentication. + subfolder_suffix : str, optional + Suffix for subfolder names. + patient_id_in_file_identifier : bool, optional + Whether patient ID is included in file identifier. Default is True. + trainer_class_name : str, optional + Name of the nnUNet trainer class. Default is "nnUNetTrainer". + + Returns + ------- + dict + Dictionary containing the training and testing data lists. 
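+
+    For ``dataset_format="subfolders"`` with ``patient_id_in_file_identifier=True``, each case is
+    expected in its own subfolder of ``data_dir``, and file names are built from the case folder
+    name followed by the corresponding ``modality_dict`` suffix. A minimal sketch with placeholder
+    names:
+
+    ```python
+    data_list = prepare_data_folder(
+        data_dir="/data/site-1",  # placeholder layout: /data/site-1/case_001/case_001_image.nii.gz, ...
+        nnunet_root_dir="/workspace/nnunet",
+        dataset_name_or_id="009",
+        modality_dict={"image": "_image.nii.gz", "label": "_label.nii.gz"},
+        experiment_name="prostate_fl",
+        client_name="site-1",
+        dataset_format="subfolders",
+    )
+    ```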
+ """ + if dataset_format == "subfolders": + if subfolder_suffix is not None: + data_list = { + "training": [ + { + modality_id: ( + str( + pathlib.Path(f.name).joinpath( + f.name[: -len(subfolder_suffix)] + modality_dict[modality_id] + ) + ) + if patient_id_in_file_identifier + else str(pathlib.Path(f.name).joinpath(modality_dict[modality_id])) + ) + for modality_id in modality_dict + } + for f in os.scandir(data_dir) + if f.is_dir() + ], + "testing": [], + } + else: + data_list = { + "training": [ + { + modality_id: ( + str(pathlib.Path(f.name).joinpath(f.name + modality_dict[modality_id])) + if patient_id_in_file_identifier + else str(pathlib.Path(f.name).joinpath(modality_dict[modality_id])) + ) + for modality_id in modality_dict + } + for f in os.scandir(data_dir) + if f.is_dir() + ], + "testing": [], + } + elif dataset_format == "decathlon" or dataset_format == "nnunet": + cases = [] + + for f in os.scandir(Path(data_dir).joinpath("imagesTr")): + if f.is_file(): + for modality_suffix in list(modality_dict.values()): + if f.name.endswith(modality_suffix) and modality_suffix != ".nii.gz": + cases.append(f.name[: -len(modality_suffix)]) + if len(np.unique(list(modality_dict.values()))) == 1 and ".nii.gz" in list(modality_dict.values()): + cases.append(f.name[: -len(".nii.gz")]) + cases = np.unique(cases) + data_list = { + "training": [ + { + modality_id: str(Path("imagesTr").joinpath(case + modality_dict[modality_id])) + for modality_id in modality_dict + if modality_id != "label" + } + for case in cases + ], + "testing": [], + } + for idx, case in enumerate(data_list["training"]): + modality_id = list(modality_dict.keys())[0] + case_id = Path(case[modality_id]).name[: -len(modality_dict[modality_id])] + data_list["training"][idx]["label"] = str(Path("labelsTr").joinpath(case_id + modality_dict["label"])) + else: + raise ValueError("Dataset format not supported") + + for idx, train_case in enumerate(data_list["training"]): + for modality_id in modality_dict: + data_list["training"][idx][modality_id + "_is_file"] = ( + Path(data_dir).joinpath(data_list["training"][idx][modality_id]).is_file() + ) + if "image" not in data_list["training"][idx] and modality_id != "label": + data_list["training"][idx]["image"] = data_list["training"][idx][modality_id] + data_list["training"][idx]["fold"] = 0 + + random.seed(42) + random.shuffle(data_list["training"]) + + data_list["testing"] = [data_list["training"][0]] + + num_folds = 5 + fold_size = len(data_list["training"]) // num_folds + for i in range(num_folds): + for j in range(fold_size): + data_list["training"][i * fold_size + j]["fold"] = i + + datalist_file = Path(data_dir).joinpath(f"{experiment_name}_folds.json") + with open(datalist_file, "w", encoding="utf-8") as f: + json.dump(data_list, f, ensure_ascii=False, indent=4) + + os.makedirs(nnunet_root_dir, exist_ok=True) + + if modality_list is None: + modality_list = [k for k in modality_dict.keys() if k != "label"] + + data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") + data_src = { + "modality": modality_list, + "dataset_name_or_id": dataset_name_or_id, + "datalist": str(datalist_file), + "dataroot": str(data_dir), + } + + ConfigParser.export_config_file(data_src, data_src_cfg) + + if dataset_format != "nnunet": + runner = nnUNetV2Runner( + input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir + ) + runner.convert_dataset() + else: + ... 
+ + if mlflow_token is not None: + os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token + if tracking_uri is not None: + mlflow.set_tracking_uri(tracking_uri) + + try: + mlflow.create_experiment(experiment_name) + mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) + except Exception as e: + print(e) + mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) + + filter = f""" + tags."client" = "{client_name}" + """ + + runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) + + try: + if len(runs) == 0: + with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): + mlflow.log_table(pd.DataFrame.from_records(data_list["training"]), f"{client_name}_train.json") + else: + with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): + mlflow.log_table(pd.DataFrame.from_records(data_list["training"]), f"{client_name}_train.json") + except (BrokenPipeError, ConnectionError) as e: + logging.error(f"Failed to log data to MLflow: {e}") + + return data_list + + +def check_packages(packages): + """ + Check if the specified packages are installed and return a report. + + Parameters + ---------- + packages : list + A list of package names (str) or dictionaries with keys "import_name" and "package_name". + + Returns + ------- + dict + A dictionary where the keys are package names and the values are strings indicating whether + the package is installed and its version if applicable. + + Examples + -------- + >>> check_packages(["numpy", "nonexistent_package"]) + {'numpy': 'numpy 1.21.0 is installed.', 'nonexistent_package': 'nonexistent_package is not installed.'} + >>> check_packages([{"import_name": "torch", "package_name": "torch"}]) + {'torch': 'torch 1.9.0 is installed.'} + """ + report = {} + for package in packages: + try: + if isinstance(package, dict): + __import__(package["import_name"]) + package_version = version(package["package_name"]) + name = package["package_name"] + print(f"{name} {package_version} is installed.") + report[name] = f"{name} {package_version} is installed." + else: + + __import__(package) + package_version = version(package) + print(f"{package} {package_version} is installed.") + report[package] = f"{package} {package_version} is installed." + + except ImportError: + print(f"{package} is not installed.") + report[package] = f"{package} is not installed." + + return report + + +def check_host_config(): + """ + Collects and returns the host configuration details including GPU, CPU, and memory information. 
+ + Returns + ------- + dict + A dictionary containing the following keys and their corresponding values: + - Config values from `monai.config.deviceconfig.get_config_values()` + - Optional config values from `monai.config.deviceconfig.get_optional_config_values()` + - GPU information including number of GPUs, CUDA version, cuDNN version, and GPU names and memory + - CPU core count + - Total memory in GB + - Memory usage percentage + """ + params_dict = {} + config_values = monai.config.deviceconfig.get_config_values() + for k in config_values: + params_dict[re.sub("[()]", " ", str(k))] = config_values[k] + optional_config_values = monai.config.deviceconfig.get_optional_config_values() + + for k in optional_config_values: + params_dict[re.sub("[()]", " ", str(k))] = optional_config_values[k] + + gpu_info = monai.config.deviceconfig.get_gpu_info() + allowed_keys = ["Num GPUs", "Has Cuda", "CUDA Version", "cuDNN enabled", "cuDNN Version"] + for i in range(gpu_info["Num GPUs"]): + allowed_keys.append(f"GPU {i} Name") + allowed_keys.append(f"GPU {i} Total memory GB ") + + for k in gpu_info: + if re.sub("[()]", " ", str(k)) in allowed_keys: + params_dict[re.sub("[()]", " ", str(k))] = str(gpu_info[k]) + + with open("nvidia-smi.log", "w") as f_e: + subprocess.run("nvidia-smi", stderr=f_e, stdout=f_e) + + params_dict["CPU_Cores"] = multiprocessing.cpu_count() + + vm = psutil.virtual_memory() + + params_dict["Total Memory"] = vm.total / (1024 * 1024 * 1024) + params_dict["Memory Used %"] = vm.percent + + return params_dict + + +def prepare_bundle(bundle_config, train_extra_configs=None): + """ + Prepare the bundle configuration for training and evaluation. + + Parameters + ---------- + bundle_config : dict + Dictionary containing the bundle configuration. Expected keys are: + - "bundle_root": str, root directory of the bundle. + - "tracking_uri": str, URI for tracking. + - "mlflow_experiment_name": str, name of the MLflow experiment. + - "mlflow_run_name": str, name of the MLflow run. + - "nnunet_plans_identifier": str, optional, identifier for nnUNet plans. + - "nnunet_trainer_class_name": str, optional, class name for nnUNet trainer. + train_extra_configs : dict, optional + Additional configurations for training. If provided, expected keys are: + - "resume_epoch": int, epoch to resume training from. + - Any other key-value pairs to be added to the training configuration. 
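+
+    A sketch of a typical ``bundle_config``, with placeholder values:
+
+    ```python
+    bundle_config = {
+        "bundle_root": "/workspace/bundles/nnunet_bundle",  # placeholder values, for illustration only
+        "tracking_uri": "http://127.0.0.1:5000",
+        "mlflow_experiment_name": "FedLearning-prostate_fl",
+        "mlflow_run_name": "site-1",
+        "nnunet_plans_identifier": "nnUNetPlans",
+        "nnunet_trainer_class_name": "nnUNetTrainer",
+    }
+    prepare_bundle(bundle_config, train_extra_configs={"resume_epoch": 100})
+    ```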
+ + Returns + ------- + None + """ + + with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml")) as f: + train_config = yaml.safe_load(f) + train_config["bundle_root"] = bundle_config["bundle_root"] + train_config["tracking_uri"] = bundle_config["tracking_uri"] + train_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"] + train_config["mlflow_run_name"] = bundle_config["mlflow_run_name"] + + train_config["data_src_cfg"] = "$@nnunet_root_folder+'/data_src_cfg.yaml'" + train_config["runner"] = { + "_target_": "nnUNetV2Runner", + "input_config": "$@data_src_cfg", + "trainer_class_name": "@nnunet_trainer_class_name", + "work_dir": "@nnunet_root_folder", + } + + train_config["network"] = "$@nnunet_trainer.network._orig_mod" + + train_handlers = train_config["train_handlers"]["handlers"] + + for idx, handler in enumerate(train_handlers): + if handler["_target_"] == "ValidationHandler": + train_handlers.pop(idx) + break + + train_config["train_handlers"]["handlers"] = train_handlers + + if train_extra_configs is not None and "resume_epoch" in train_extra_configs: + resume_epoch = train_extra_configs["resume_epoch"] + train_config["initialize"] = [ + "$monai.utils.set_determinism(seed=123)", + "$@runner.dataset_name_or_id", + f"$src.trainer.reload_checkpoint(@train#trainer, {resume_epoch}, @iterations, @ckpt_dir, @lr_scheduler)", + ] + else: + train_config["initialize"] = ["$monai.utils.set_determinism(seed=123)", "$@runner.dataset_name_or_id"] + + if "Val_Dice" in train_config["val_key_metric"]: + train_config["val_key_metric"] = {"Val_Dice_Local": train_config["val_key_metric"]["Val_Dice"]} + + if "Val_Dice_per_class" in train_config["val_additional_metrics"]: + train_config["val_additional_metrics"] = { + "Val_Dice_per_class_Local": train_config["val_additional_metrics"]["Val_Dice_per_class"] + } + if "nnunet_plans_identifier" in bundle_config: + train_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"] + + if "nnunet_trainer_class_name" in bundle_config: + train_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"] + + if train_extra_configs is not None: + for key in train_extra_configs: + train_config[key] = train_extra_configs[key] + + with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.json"), "w") as f: + json.dump(train_config, f) + + with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml"), "w") as f: + yaml.dump(train_config, f) + + if not Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml").exists(): + shutil.copy( + Path(bundle_config["bundle_root"]).joinpath("nnUNet", "evaluator", "evaluator.yaml"), + Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"), + ) + + with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml")) as f: + evaluate_config = yaml.safe_load(f) + evaluate_config["bundle_root"] = bundle_config["bundle_root"] + + evaluate_config["tracking_uri"] = bundle_config["tracking_uri"] + evaluate_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"] + evaluate_config["mlflow_run_name"] = bundle_config["mlflow_run_name"] + + if "nnunet_plans_identifier" in bundle_config: + evaluate_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"] + if "nnunet_trainer_class_name" in bundle_config: + evaluate_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"] + + with 
open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.json"), "w") as f: + json.dump(evaluate_config, f) + + with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"), "w") as f: + yaml.dump(evaluate_config, f) diff --git a/monai/nvflare/response_processor.py b/monai/nvflare/response_processor.py new file mode 100644 index 0000000000..a02d307220 --- /dev/null +++ b/monai/nvflare/response_processor.py @@ -0,0 +1,342 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from nvflare.apis.client import Client +from nvflare.apis.dxo import DataKind, from_shareable +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.app_common.abstract.response_processor import ResponseProcessor + + +class nnUNetPrepareProcessor(ResponseProcessor): + """ + A processor class for preparing nnUNet data in a federated learning context. + + Methods + ------- + __init__(): + Initializes the nnUNetPrepareProcessor with an empty data dictionary. + create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable: + Creates and returns a Shareable object for the given task name. + process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + Processes the response from a client. Validates the response and updates the data dictionary if valid. + final_process(fl_ctx: FLContext) -> bool: + Finalizes the processing by setting the client data dictionary in the federated learning context. + """ + + def __init__(self): + ResponseProcessor.__init__(self) + self.data_dict = {} + + def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: + return Shareable() + + def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + if not isinstance(response, Shareable): + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", + ) + return False + + try: + dxo = from_shareable(response) + + except Exception: + self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") + return False + + if dxo.data_kind != DataKind.COLLECTION: + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " + f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", + ) + return False + + data_dict = dxo.data + + if not data_dict: + self.log_error(fl_ctx, f"No dataset_dict found from client {client.name}") + return False + + self.data_dict[client.name] = data_dict + + return True + + def final_process(self, fl_ctx: FLContext) -> bool: + if not self.data_dict: + self.log_error(fl_ctx, "no data_prepare_dict from clients") + return False + + # must set sticky to True so other controllers can get it! 
+ fl_ctx.set_prop("client_data_dict", self.data_dict, private=True, sticky=True) + return True + + +class nnUNetPackageReportProcessor(ResponseProcessor): + """ + A processor for handling nnUNet package reports in a federated learning context. + + Attributes + ---------- + package_report : dict + A dictionary to store package reports from clients. + + Methods + ------- + create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable + Creates task data for a given task name and federated learning context. + process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool + Processes the response from a client for a given task name and federated learning context. + final_process(fl_ctx: FLContext) -> bool + Final processing step to handle the collected package reports. + """ + + def __init__(self): + ResponseProcessor.__init__(self) + self.package_report = {} + + def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: + return Shareable() + + def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + if not isinstance(response, Shareable): + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", + ) + return False + + try: + dxo = from_shareable(response) + + except Exception: + self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") + return False + + if dxo.data_kind != DataKind.COLLECTION: + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " + f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", + ) + return False + + package_report = dxo.data + + if not package_report: + self.log_error(fl_ctx, f"No package_report found from client {client.name}") + return False + + self.package_report[client.name] = package_report + return True + + def final_process(self, fl_ctx: FLContext) -> bool: + if not self.package_report: + self.log_error(fl_ctx, "no plan_dict from client") + return False + + # must set sticky to True so other controllers can get it! + fl_ctx.set_prop("package_report", self.package_report, private=True, sticky=True) + return True + + +class nnUNetPlanProcessor(ResponseProcessor): + """ + nnUNetPlanProcessor is a class that processes responses from clients in a federated learning context. + It inherits from the ResponseProcessor class and is responsible for handling and validating the + responses, extracting the necessary data, and storing it for further use. + + Attributes + ---------- + plan_dict : dict + A dictionary to store the plan data received from clients. + + Methods + ------- + create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable + Creates and returns a Shareable object for the given task name. + process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool + Processes the response from a client, validates it, and stores the plan data if valid. + final_process(fl_ctx: FLContext) -> bool + Finalizes the processing by setting the plan data in the federated learning context. 
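All of the response processors in this file share the same contract; a minimal round-trip sketch (the payload dict below is a made-up example):

```python
from nvflare.apis.dxo import DXO, DataKind, from_shareable

# executor/client side: pack an arbitrary dict into a COLLECTION DXO
outgoing = DXO(data_kind=DataKind.COLLECTION, data={"nnUNetPlans": {"batch_size": 2}}, meta={})
shareable = outgoing.to_shareable()

# server/processor side: unpack and validate, mirroring process_client_response()
dxo = from_shareable(shareable)
assert dxo.data_kind == DataKind.COLLECTION
print(dxo.data)  # {'nnUNetPlans': {'batch_size': 2}}
```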
+ """ + + def __init__(self): + ResponseProcessor.__init__(self) + self.plan_dict = {} + + def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: + return Shareable() + + def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + if not isinstance(response, Shareable): + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", + ) + return False + + try: + dxo = from_shareable(response) + + except Exception: + self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") + return False + + if dxo.data_kind != DataKind.COLLECTION: + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " + f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", + ) + return False + + plan_dict = dxo.data + + if not plan_dict: + self.log_error(fl_ctx, f"No plan_dict found from client {client.name}") + return False + + self.plan_dict[client.name] = plan_dict + + return True + + def final_process(self, fl_ctx: FLContext) -> bool: + if not self.plan_dict: + self.log_error(fl_ctx, "no plan_dict from client") + return False + + # must set sticky to True so other controllers can get it! + fl_ctx.set_prop("nnunet_plans", self.plan_dict, private=True, sticky=True) + return True + + +class nnUNetTrainProcessor(ResponseProcessor): + """ + A processor class for handling training responses in the nnUNet framework. + + Attributes + ---------- + val_summary_dict : dict + A dictionary to store validation summaries from clients. + Methods + ------- + create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable + Creates task data for a given task name and FLContext. + process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool + Processes the response from a client for a given task name and FLContext. + final_process(fl_ctx: FLContext) -> bool + Final processing step to handle the collected validation summaries. + """ + + def __init__(self): + ResponseProcessor.__init__(self) + self.val_summary_dict = {} + + def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: + return Shareable() + + def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + if not isinstance(response, Shareable): + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", + ) + return False + + try: + dxo = from_shareable(response) + + except Exception: + self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") + return False + + if dxo.data_kind != DataKind.COLLECTION: + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " + f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", + ) + return False + + val_summary_dict = dxo.data + + if not val_summary_dict: + self.log_error(fl_ctx, f"No val_summary_dict found from client {client.name}") + return False + + self.val_summary_dict[client.name] = val_summary_dict + + return True + + def final_process(self, fl_ctx: FLContext) -> bool: + if not self.val_summary_dict: + self.log_error(fl_ctx, "no val_summary_dict from client") + return False + + # must set sticky to True so other controllers can get it! 
+ fl_ctx.set_prop("val_summary_dict", self.val_summary_dict, private=True, sticky=True) + return True + + +class nnUNetBundlePrepareProcessor(ResponseProcessor): + """ + A processor class for preparing nnUNet bundles in a federated learning context. + + Methods + ------- + __init__(): + Initializes the nnUNetBundlePrepareProcessor instance. + create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable: + Creates task data for a given task name and federated learning context. + process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + Processes the response from a client and validates it. + final_process(fl_ctx: FLContext) -> bool: + Final processing step after all client responses have been processed. + """ + + def __init__(self): + ResponseProcessor.__init__(self) + + def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: + return Shareable() + + def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: + if not isinstance(response, Shareable): + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", + ) + return False + + try: + dxo = from_shareable(response) + + except Exception: + self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") + return False + + if dxo.data_kind != DataKind.COLLECTION: + self.log_error( + fl_ctx, + f"bad response from client {client.name}: " + f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", + ) + return False + + return True + + def final_process(self, fl_ctx: FLContext) -> bool: + + return True From 3b13218c349f123e6b44a5c02310d496afc216db Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Tue, 1 Apr 2025 13:53:35 +0200 Subject: [PATCH 61/67] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ad394ce807..77221491e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -torch>=2.3.0; sys_platform != 'win32' +torch>=2.3.0a0; sys_platform != 'win32' torch>=2.4.1; sys_platform == 'win32' numpy>=1.24,<3.0 From fbf6105cf884438fdd8f9059809b98bcfd48680e Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Wed, 2 Apr 2025 06:30:20 +0000 Subject: [PATCH 62/67] Add nnunet_root_folder parameter to train function --- monai/nvflare/nvflare_nnunet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/nvflare/nvflare_nnunet.py b/monai/nvflare/nvflare_nnunet.py index 72dc062ccd..3325d26ec8 100644 --- a/monai/nvflare/nvflare_nnunet.py +++ b/monai/nvflare/nvflare_nnunet.py @@ -97,6 +97,7 @@ def train( mlflow_run_name="run_" + client_name, tracking_uri=tracking_uri, fold_id=fold, + nnunet_root_folder=nnunet_root_dir, ) nnunet_config = {"dataset_name_or_id": dataset_name_or_id, "nnunet_trainer": trainer_class_name} convert_monai_bundle_to_nnunet(nnunet_config, bundle_root) From d1035ca2a4ed9bf17badf56e5fcc682209a3b9bb Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Wed, 2 Apr 2025 14:36:44 +0000 Subject: [PATCH 63/67] ``` No code changes detected. 
``` --- monai/apps/nnunet/nnunet_bundle.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index bdc6bedc6f..13040f0306 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -124,7 +124,6 @@ def get_nnunet_trainer( fold, trainer_class_name, plans_identifier, - use_compressed_data, device=torch.device(device), ) if disable_checkpointing: From 47798af8627d66daf5ebd962764aff81d723da87 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Wed, 2 Apr 2025 14:41:58 +0000 Subject: [PATCH 64/67] Remove conditional print statement for torch.compile in nnUNetWrapper --- monai/apps/nnunet/nnunet_bundle.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 13040f0306..14ed2636c5 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -234,9 +234,7 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name predictor.trainer_name = trainer_name # type: ignore predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes # type: ignore predictor.label_manager = plans_manager.get_label_manager(dataset_json) # type: ignore - if ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")): - print("Using torch.compile") - # End Block + self.network_weights = self.predictor.network # type: ignore def forward(self, x: MetaTensor) -> MetaTensor: From 0578b22e5a49fdd97997529e1270fc14312c7b96 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Wed, 2 Apr 2025 14:45:04 +0000 Subject: [PATCH 65/67] Remove unused nvflare module files --- monai/nvflare/__init__.py | 10 - monai/nvflare/json_generator.py | 179 --- monai/nvflare/nnunet_executor.py | 334 ----- monai/nvflare/nvflare_generate_job_configs.py | 1085 ----------------- monai/nvflare/nvflare_nnunet.py | 696 ----------- monai/nvflare/response_processor.py | 342 ------ 6 files changed, 2646 deletions(-) delete mode 100644 monai/nvflare/__init__.py delete mode 100644 monai/nvflare/json_generator.py delete mode 100644 monai/nvflare/nnunet_executor.py delete mode 100644 monai/nvflare/nvflare_generate_job_configs.py delete mode 100644 monai/nvflare/nvflare_nnunet.py delete mode 100644 monai/nvflare/response_processor.py diff --git a/monai/nvflare/__init__.py b/monai/nvflare/__init__.py deleted file mode 100644 index 1e97f89407..0000000000 --- a/monai/nvflare/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/monai/nvflare/json_generator.py b/monai/nvflare/json_generator.py deleted file mode 100644 index 9326a35837..0000000000 --- a/monai/nvflare/json_generator.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -import json -import os.path - -from nvflare.apis.event_type import EventType -from nvflare.apis.fl_context import FLContext -from nvflare.widgets.widget import Widget - - -class PrepareJsonGenerator(Widget): - """ - A widget class to prepare and generate a JSON file containing data preparation configurations. - - Parameters - ---------- - results_dir : str, optional - The directory where the results will be stored (default is "prepare"). - json_file_name : str, optional - The name of the JSON file to be generated (default is "data_dict.json"). - - Methods - ------- - handle_event(event_type: str, fl_ctx: FLContext) - Handles events during the federated learning process. Clears the data preparation configuration - at the start of a run and saves the configuration to a JSON file at the end of a run. - """ - - def __init__(self, results_dir="prepare", json_file_name="data_dict.json"): - super(PrepareJsonGenerator, self).__init__() - - self._results_dir = results_dir - self._data_prepare_config = {} - self._json_file_name = json_file_name - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self._data_prepare_config.clear() - elif event_type == EventType.END_RUN: - self._data_prepare_config = fl_ctx.get_prop("client_data_dict", None) - run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - data_prepare_res_dir = os.path.join(run_dir, self._results_dir) - if not os.path.exists(data_prepare_res_dir): - os.makedirs(data_prepare_res_dir) - - res_file_path = os.path.join(data_prepare_res_dir, self._json_file_name) - with open(res_file_path, "w") as f: - json.dump(self._data_prepare_config, f) - - -class nnUNetPackageReportJsonGenerator(Widget): - """ - A class to generate JSON reports for nnUNet package. - - Parameters - ---------- - results_dir : str, optional - Directory where the report will be saved (default is "package_report"). - json_file_name : str, optional - Name of the JSON file to save the report (default is "package_report.json"). - - Methods - ------- - handle_event(event_type: str, fl_ctx: FLContext) - Handles events to clear the report at the start of a run and save the report at the end of a run. 
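The generators in this file are the reading side of the sticky properties set by the response processors in `response_processor.py`; a hedged sketch of that handoff, condensed into one hypothetical helper (the function name is illustrative):

```python
import json
import os

from nvflare.apis.fl_context import FLContext


def dump_sticky_prop(fl_ctx: FLContext, prop_name: str, results_dir: str, file_name: str) -> None:
    """Read a sticky property set by a processor and write it into the job's run directory."""
    data = fl_ctx.get_prop(prop_name, None)  # e.g. "package_report", "nnunet_plans", "val_summary_dict"
    run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
    out_dir = os.path.join(run_dir, results_dir)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, file_name), "w") as f:
        json.dump(data, f)
```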
- """ - - def __init__(self, results_dir="package_report", json_file_name="package_report.json"): - super(nnUNetPackageReportJsonGenerator, self).__init__() - - self._results_dir = results_dir - self._report = {} - self._json_file_name = json_file_name - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self._report.clear() - elif event_type == EventType.END_RUN: - datasets = fl_ctx.get_prop("package_report", None) - run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - cross_val_res_dir = os.path.join(run_dir, self._results_dir) - if not os.path.exists(cross_val_res_dir): - os.makedirs(cross_val_res_dir) - - res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) - with open(res_file_path, "w") as f: - json.dump(datasets, f) - - -class nnUNetPlansJsonGenerator(Widget): - """ - A class to generate JSON files for nnUNet plans. - - Parameters - ---------- - results_dir : str, optional - Directory where the preprocessing results will be stored (default is "nnUNet_preprocessing"). - json_file_name : str, optional - Name of the JSON file to be generated (default is "nnUNetPlans.json"). - - Methods - ------- - handle_event(event_type: str, fl_ctx: FLContext) - Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves - the plans to a JSON file at the end of a run. - """ - - def __init__(self, results_dir="nnUNet_preprocessing", json_file_name="nnUNetPlans.json"): - - super(nnUNetPlansJsonGenerator, self).__init__() - - self._results_dir = results_dir - self._nnUNetPlans = {} - self._json_file_name = json_file_name - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self._nnUNetPlans.clear() - elif event_type == EventType.END_RUN: - datasets = fl_ctx.get_prop("nnunet_plans", None) - run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - cross_val_res_dir = os.path.join(run_dir, self._results_dir) - if not os.path.exists(cross_val_res_dir): - os.makedirs(cross_val_res_dir) - - res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) - with open(res_file_path, "w") as f: - json.dump(datasets, f) - - -class nnUNetValSummaryJsonGenerator(Widget): - """ - A widget to generate a JSON summary for nnUNet validation results. - - Parameters - ---------- - results_dir : str, optional - Directory where the nnUNet training results are stored (default is "nnUNet_train"). - json_file_name : str, optional - Name of the JSON file to save the validation summary (default is "val_summary.json"). - - Methods - ------- - handle_event(event_type: str, fl_ctx: FLContext) - Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves - the validation summary to a JSON file at the end of a run. 
- """ - - def __init__(self, results_dir="nnUNet_train", json_file_name="val_summary.json"): - - super(nnUNetValSummaryJsonGenerator, self).__init__() - - self._results_dir = results_dir - self._nnUNetPlans = {} - self._json_file_name = json_file_name - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self._nnUNetPlans.clear() - elif event_type == EventType.END_RUN: - datasets = fl_ctx.get_prop("val_summary_dict", None) - run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - cross_val_res_dir = os.path.join(run_dir, self._results_dir) - if not os.path.exists(cross_val_res_dir): - os.makedirs(cross_val_res_dir) - - res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) - with open(res_file_path, "w") as f: - json.dump(datasets, f) diff --git a/monai/nvflare/nnunet_executor.py b/monai/nvflare/nnunet_executor.py deleted file mode 100644 index c00d2245aa..0000000000 --- a/monai/nvflare/nnunet_executor.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -import subprocess -import sys -from pathlib import Path - -from nvflare.apis.dxo import DXO, DataKind -from nvflare.apis.event_type import EventType -from nvflare.apis.executor import Executor -from nvflare.apis.fl_constant import ReturnCode -from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable, make_reply -from nvflare.apis.signal import Signal - -from monai.nvflare.nvflare_nnunet import ( # check_host_config, - check_packages, - plan_and_preprocess, - prepare_bundle, - prepare_data_folder, - preprocess, - train, -) - - -class nnUNetExecutor(Executor): - """ - nnUNetExecutor is a class that handles the execution of various tasks related to nnUNet training and preprocessing - within the NVFlare framework. - - Parameters - ---------- - data_dir : str, optional - Directory where the data is stored. - modality_dict : dict, optional - Dictionary containing modality information. - prepare_task_name : str, optional - Name of the task for preparing the dataset. - check_client_packages_task_name : str, optional - Name of the task for checking client packages. - plan_and_preprocess_task_name : str, optional - Name of the task for planning and preprocessing. - preprocess_task_name : str, optional - Name of the task for preprocessing. - training_task_name : str, optional - Name of the task for training. - prepare_bundle_name : str, optional - Name of the task for preparing the bundle. - subfolder_suffix : str, optional - Suffix for subfolders. - dataset_format : str, optional - Format of the dataset, default is "subfolders". - patient_id_in_file_identifier : bool, optional - Whether patient ID is in file identifier, default is True. - nnunet_config : dict, optional - Configuration dictionary for nnUNet. - nnunet_root_folder : str, optional - Root folder for nnUNet. - client_name : str, optional - Name of the client. 
- tracking_uri : str, optional - URI for tracking. - mlflow_token : str, optional - Token for MLflow. - bundle_root : str, optional - Root directory for the bundle. - train_extra_configs : dict, optional - Extra configurations for training. - exclude_vars : list, optional - List of variables to exclude. - modality_list : list, optional - List of modalities. - - Methods - ------- - handle_event(event_type: str, fl_ctx: FLContext) - Handles events triggered during the federated learning process. - initialize(fl_ctx: FLContext) - Initializes the executor with the given federated learning context. - execute(task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable - Executes the specified task. - prepare_dataset() -> Shareable - Prepares the dataset for training. - check_packages_installed() -> Shareable - Checks if the required packages are installed. - plan_and_preprocess() -> Shareable - Plans and preprocesses the dataset. - preprocess() -> Shareable - Preprocesses the dataset. - train() -> Shareable - Trains the model. - prepare_bundle() -> Shareable - Prepares the bundle for deployment. - """ - - def __init__( - self, - data_dir=None, - modality_dict=None, - prepare_task_name="prepare", - check_client_packages_task_name="check_client_packages", - plan_and_preprocess_task_name="plan_and_preprocess", - preprocess_task_name="preprocess", - training_task_name="train", - prepare_bundle_name="prepare_bundle", - subfolder_suffix=None, - dataset_format="subfolders", - patient_id_in_file_identifier=True, - nnunet_config=None, - nnunet_root_folder=None, - client_name=None, - tracking_uri=None, - mlflow_token=None, - bundle_root=None, - modality_list=None, - train_extra_configs=None, - exclude_vars=None, - ): - super().__init__() - - self.exclude_vars = exclude_vars - self.prepare_task_name = prepare_task_name - self.data_dir = data_dir - self.subfolder_suffix = subfolder_suffix - self.patient_id_in_file_identifier = patient_id_in_file_identifier - self.dataset_format = dataset_format - self.modality_dict = modality_dict - self.nnunet_config = nnunet_config - self.nnunet_root_folder = nnunet_root_folder - self.client_name = client_name - self.tracking_uri = tracking_uri - self.mlflow_token = mlflow_token - self.check_client_packages_task_name = check_client_packages_task_name - self.plan_and_preprocess_task_name = plan_and_preprocess_task_name - self.preprocess_task_name = preprocess_task_name - self.training_task_name = training_task_name - self.prepare_bundle_name = prepare_bundle_name - self.bundle_root = bundle_root - self.train_extra_configs = train_extra_configs - self.modality_list = modality_list - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self.initialize(fl_ctx) - - def initialize(self, fl_ctx: FLContext): - self.run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - self.root_dir = fl_ctx.get_engine().get_workspace().root_dir - self.custom_app_dir = fl_ctx.get_engine().get_workspace().get_app_custom_dir(fl_ctx.get_job_id()) - - with open("init_logfile_out.log", "w") as f_o: - with open("init_logfile_err.log", "w") as f_e: - subprocess.call( - [ - sys.executable, - "-m", - "pip", - "install", - "--user", - "-r", - str(Path(self.custom_app_dir).joinpath("requirements.txt")), - ], - stdout=f_o, - stderr=f_e, - ) - - def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: - self.run_dir = 
fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) - self.root_dir = fl_ctx.get_engine().get_workspace().root_dir - self.custom_app_dir = fl_ctx.get_engine().get_workspace().get_app_custom_dir(fl_ctx.get_job_id()) - try: - if task_name == self.prepare_task_name: - return self.prepare_dataset() - elif task_name == self.check_client_packages_task_name: - return self.check_packages_installed() - elif task_name == self.plan_and_preprocess_task_name: - return self.plan_and_preprocess() - elif task_name == self.preprocess_task_name: - return self.preprocess() - elif task_name == self.training_task_name: - return self.train() - elif task_name == self.prepare_bundle_name: - return self.prepare_bundle() - else: - return make_reply(ReturnCode.TASK_UNKNOWN) - except Exception as e: - self.log_exception(fl_ctx, f"Exception in simple trainer: {e}.") - return make_reply(ReturnCode.EXECUTION_EXCEPTION) - - def prepare_dataset(self) -> Shareable: - if "nnunet_trainer" not in self.nnunet_config: - nnunet_trainer_name = "nnUNetTrainer" - else: - nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] - - data_list = prepare_data_folder( - data_dir=self.data_dir, - nnunet_root_dir=self.nnunet_root_folder, - dataset_name_or_id=self.nnunet_config["dataset_name_or_id"], - modality_dict=self.modality_dict, - experiment_name=self.nnunet_config["experiment_name"], - client_name=self.client_name, - dataset_format=self.dataset_format, - patient_id_in_file_identifier=self.patient_id_in_file_identifier, - tracking_uri=self.tracking_uri, - mlflow_token=self.mlflow_token, - subfolder_suffix=self.subfolder_suffix, - trainer_class_name=nnunet_trainer_name, - modality_list=self.modality_list, - ) - - outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=data_list, meta={}) - return outgoing_dxo.to_shareable() - - def check_packages_installed(self): - packages = [ - "nvflare", - # {"package_name":'pymaia-learn',"import_name":"PyMAIA"}, - "torch", - "monai", - "numpy", - "nnunetv2", - ] - package_report = check_packages(packages) - - # host_config = check_host_config() - # package_report.update(host_config) - - outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=package_report, meta={}) - - return outgoing_dxo.to_shareable() - - def plan_and_preprocess(self): - if "nnunet_plans" not in self.nnunet_config: - nnunet_plans_name = "nnUNetPlans" - else: - nnunet_plans_name = self.nnunet_config["nnunet_plans"] - - if "nnunet_trainer" not in self.nnunet_config: - nnunet_trainer_name = "nnUNetTrainer" - else: - nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] - - nnunet_plans = plan_and_preprocess( - self.nnunet_root_folder, - self.nnunet_config["dataset_name_or_id"], - self.client_name, - self.nnunet_config["experiment_name"], - self.tracking_uri, - nnunet_plans_name=nnunet_plans_name, - trainer_class_name=nnunet_trainer_name, - ) - - outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=nnunet_plans, meta={}) - return outgoing_dxo.to_shareable() - - def preprocess(self): - if "nnunet_plans" not in self.nnunet_config: - nnunet_plans_name = "nnUNetPlans" - else: - nnunet_plans_name = self.nnunet_config["nnunet_plans"] - - if "nnunet_trainer" not in self.nnunet_config: - nnunet_trainer_name = "nnUNetTrainer" - else: - nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] - - nnunet_plans = preprocess( - self.nnunet_root_folder, - self.nnunet_config["dataset_name_or_id"], - nnunet_plans_file_path=Path(self.custom_app_dir).joinpath(f"{nnunet_plans_name}.json"), - 
trainer_class_name=nnunet_trainer_name, - ) - outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=nnunet_plans, meta={}) - return outgoing_dxo.to_shareable() - - def train(self): - if "nnunet_trainer" not in self.nnunet_config: - nnunet_trainer_name = "nnUNetTrainer" - else: - nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] - - if "nnunet_plans" not in self.nnunet_config: - nnunet_plans_name = "nnUNetPlans" - else: - nnunet_plans_name = self.nnunet_config["nnunet_plans"] - - validation_summary = train( - self.nnunet_root_folder, - trainer_class_name=nnunet_trainer_name, - fold=0, - experiment_name=self.nnunet_config["experiment_name"], - client_name=self.client_name, - tracking_uri=self.tracking_uri, - nnunet_plans_name=nnunet_plans_name, - dataset_name_or_id=self.nnunet_config["dataset_name_or_id"], - run_with_bundle=True if self.bundle_root is not None else False, - bundle_root=self.bundle_root, - ) - outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=validation_summary, meta={}) - return outgoing_dxo.to_shareable() - - def prepare_bundle(self): - if "nnunet_trainer" not in self.nnunet_config: - nnunet_trainer_name = "nnUNetTrainer" - else: - nnunet_trainer_name = self.nnunet_config["nnunet_trainer"] - - if "nnunet_plans" not in self.nnunet_config: - nnunet_plans_name = "nnUNetPlans" - else: - nnunet_plans_name = self.nnunet_config["nnunet_plans"] - - bundle_config = { - "bundle_root": self.bundle_root, - "tracking_uri": self.tracking_uri, - "mlflow_experiment_name": "FedLearning-" + self.nnunet_config["experiment_name"], - "mlflow_run_name": self.client_name, - "nnunet_plans_identifier": nnunet_plans_name, - "nnunet_trainer_class_name": nnunet_trainer_name, - } - - prepare_bundle(bundle_config, self.train_extra_configs) - - return make_reply(ReturnCode.OK) diff --git a/monai/nvflare/nvflare_generate_job_configs.py b/monai/nvflare/nvflare_generate_job_configs.py deleted file mode 100644 index 130f47e309..0000000000 --- a/monai/nvflare/nvflare_generate_job_configs.py +++ /dev/null @@ -1,1085 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import subprocess -from pathlib import Path - -import yaml -from pyhocon import ConfigFactory -from pyhocon.converter import HOCONConverter - - -def prepare_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Prepare configuration files for nnUNet dataset preparation using NVFlare. - - Parameters - ---------- - clients : dict - Dictionary containing client-specific configurations. Each key is a client ID and the value is a dictionary - with the following keys: - - "data_dir": str, path to the client's data directory. - - "patient_id_in_file_identifier": str, identifier for patient ID in file. - - "modality_dict": dict, dictionary mapping modalities. - - "dataset_format": str, format of the dataset. - - "nnunet_root_folder": str, path to the nnUNet root folder. - - "client_name": str, name of the client. 
- - "subfolder_suffix": str, optional, suffix for subfolders. - experiment : dict - Dictionary containing experiment-specific configurations with the following keys: - - "dataset_name_or_id": str, name or ID of the dataset. - - "experiment_name": str, name of the experiment. - - "tracking_uri": str, URI for tracking. - - "mlflow_token": str, optional, token for MLflow. - root_dir : str - Root directory where the configuration files will be generated. - script_dir : str - Directory containing the scripts. - nvflare_exec : str - Path to the NVFlare executable. - - Returns - ------- - None - """ - task_name = "prepare" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = {"description": "Prepare nnUNet Dataset", "client_category": "Executor", "controller_type": "server"} - - meta = { - "name": f"{task_name}_nnUNet", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPrepareProcessor", "args": {}}, - {"id": "json_generator", "path": "monai.nvflare.json_generator.PrepareJsonGenerator", "args": {}}, - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 6000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "components": [], - "executors": [ - { - "tasks": [task_name], - "executor": { - "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", - "args": { - "data_dir": clients[client_id]["data_dir"], - "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], - "modality_dict": clients[client_id]["modality_dict"], - "dataset_format": clients[client_id]["dataset_format"], - "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], - "nnunet_config": { - "dataset_name_or_id": experiment["dataset_name_or_id"], - "experiment_name": experiment["experiment_name"], - }, - "client_name": clients[client_id]["client_name"], - "tracking_uri": experiment["tracking_uri"], - }, - }, - } - ], - } - - if "modality_list" in experiment: - client["executors"][0]["executor"]["args"]["modality_list"] = experiment["modality_list"] - - if "subfolder_suffix" in clients[client_id]: - client["executors"][0]["executor"]["args"]["subfolder_suffix"] = clients[client_id]["subfolder_suffix"] - if 
"mlflow_token" in experiment: - client["executors"][0]["executor"]["args"]["mlflow_token"] = experiment["mlflow_token"] - - if "nnunet_plans" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] - - if "nnunet_trainer" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def check_client_packages_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Generate job configuration files for checking client packages in an NVFlare experiment. - - Parameters - ---------- - clients : dict - A dictionary where keys are client IDs and values are client details. - experiment : str - The name of the experiment. - root_dir : str - The root directory where the configuration files will be generated. - script_dir : str - The directory containing the necessary scripts for NVFlare. - nvflare_exec : str - The NVFlare executable path. - - Returns - ------- - None - """ - task_name = "check_client_packages" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = { - "description": "Check Python Packages and Report", - "client_category": "Executor", - "controller_type": "server", - } - - meta = { - "name": f"{task_name}", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - { - "id": "nnunet_processor", - "path": "monai.nvflare.response_processor.nnUNetPackageReportProcessor", - "args": {}, - }, - { - "id": "json_generator", - "path": "monai.nvflare.json_generator.nnUNetPackageReportJsonGenerator", - "args": {}, - }, - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 6000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - 
"task_result_filters": [], - "task_data_filters": [], - "components": [], - "executors": [ - {"tasks": [task_name], "executor": {"path": "monai.nvflare.nnunet_executor.nnUNetExecutor", "args": {}}} - ], - } - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def plan_and_preprocess_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Generates and writes configuration files for the plan and preprocess task in the nnUNet experiment. - - Parameters - ---------- - clients : dict - A dictionary containing client-specific configurations. Each key is a client ID, and the value is - another dictionary with client-specific settings. - experiment : dict - A dictionary containing experiment-specific configurations such as dataset name, experiment name, - tracking URI, and optional nnUNet plans and trainer. - root_dir : str - The root directory where the configuration files will be generated. - script_dir : str - The directory containing the scripts to be used in the NVFlare job. - nvflare_exec : str - The path to the NVFlare executable. - - Returns - ------- - None - """ - task_name = "plan_and_preprocess" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = {"description": "Plan and Preprocess nnUNet", "client_category": "Executor", "controller_type": "server"} - - meta = { - "name": f"{task_name}_nnUNet", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPlanProcessor", "args": {}}, - {"id": "json_generator", "path": "monai.nvflare.json_generator.nnUNetPlansJsonGenerator", "args": {}}, - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 6000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - "task_result_filters": [], - 
"task_data_filters": [], - "components": [], - "executors": [ - { - "tasks": [task_name], - "executor": { - "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", - "args": { - "data_dir": clients[client_id]["data_dir"], - "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], - "modality_dict": clients[client_id]["modality_dict"], - "dataset_format": clients[client_id]["dataset_format"], - "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], - "nnunet_config": { - "dataset_name_or_id": experiment["dataset_name_or_id"], - "experiment_name": experiment["experiment_name"], - }, - "client_name": clients[client_id]["client_name"], - "tracking_uri": experiment["tracking_uri"], - }, - }, - } - ], - } - - if "nnunet_plans" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] - - if "nnunet_trainer" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def preprocess_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Generate job configuration files for the preprocessing task in NVFlare. - - Parameters - ---------- - clients : dict - A dictionary containing client-specific configurations. Each key is a client ID, and the value is a dictionary - with the following keys: - - 'data_dir': str, path to the client's data directory. - - 'patient_id_in_file_identifier': str, identifier for patient ID in the file. - - 'modality_dict': dict, dictionary mapping modalities. - - 'dataset_format': str, format of the dataset. - - 'nnunet_root_folder': str, root folder for nnUNet. - - 'client_name': str, name of the client. - experiment : dict - A dictionary containing experiment-specific configurations with the following keys: - - 'dataset_name_or_id': str, name or ID of the dataset. - - 'experiment_name': str, name of the experiment. - - 'tracking_uri': str, URI for tracking. - - 'nnunet_plans' (optional): str, nnUNet plans. - - 'nnunet_trainer' (optional): str, nnUNet trainer. - root_dir : str - The root directory where the configuration files will be generated. - script_dir : str - The directory containing the scripts to be used in the job. - nvflare_exec : str - The NVFlare executable to be used for creating the job. 
- - Returns - ------- - None - """ - task_name = "preprocess" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = {"description": "Preprocess nnUNet", "client_category": "Executor", "controller_type": "server"} - - meta = { - "name": f"{task_name}_nnUNet", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetPlanProcessor", "args": {}} - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 6000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "components": [], - "executors": [ - { - "tasks": [task_name], - "executor": { - "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", - "args": { - "data_dir": clients[client_id]["data_dir"], - "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], - "modality_dict": clients[client_id]["modality_dict"], - "dataset_format": clients[client_id]["dataset_format"], - "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], - "nnunet_config": { - "dataset_name_or_id": experiment["dataset_name_or_id"], - "experiment_name": experiment["experiment_name"], - }, - "client_name": clients[client_id]["client_name"], - "tracking_uri": experiment["tracking_uri"], - }, - }, - } - ], - } - - if "nnunet_plans" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] - - if "nnunet_trainer" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def train_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Generate 
training configuration files for nnUNet using NVFlare. - - Parameters - ---------- - clients : dict - Dictionary containing client-specific configurations. Each key is a client ID, and the value is a dictionary - with the following keys: - - 'data_dir': str, path to the client's data directory. - - 'patient_id_in_file_identifier': str, identifier for patient ID in file. - - 'modality_dict': dict, dictionary mapping modalities. - - 'dataset_format': str, format of the dataset. - - 'nnunet_root_folder': str, path to the nnUNet root folder. - - 'client_name': str, name of the client. - - 'bundle_root': str, optional, path to the bundle root directory. - experiment : dict - Dictionary containing experiment-specific configurations with the following keys: - - 'dataset_name_or_id': str, name or ID of the dataset. - - 'experiment_name': str, name of the experiment. - - 'tracking_uri': str, URI for tracking. - - 'nnunet_plans': str, optional, nnUNet plans. - - 'nnunet_trainer': str, optional, nnUNet trainer. - root_dir : str - Root directory where the configuration files will be generated. - script_dir : str - Directory containing the scripts to be used. - nvflare_exec : str - Path to the NVFlare executable. - - Returns - ------- - None - """ - task_name = "train" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = {"description": "Train nnUNet", "client_category": "Executor", "controller_type": "server"} - - meta = { - "name": f"{task_name}_nnUNet", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - {"id": "nnunet_processor", "path": "monai.nvflare.response_processor.nnUNetTrainProcessor", "args": {}}, - {"id": "json_generator", "path": "monai.nvflare.json_generator.nnUNetValSummaryJsonGenerator", "args": {}}, - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 6000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "components": [], - "executors": [ - { - "tasks": [task_name], - "executor": { - "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", - "args": { - "data_dir": clients[client_id]["data_dir"], - "patient_id_in_file_identifier": clients[client_id]["patient_id_in_file_identifier"], - "modality_dict": clients[client_id]["modality_dict"], - "dataset_format": 
clients[client_id]["dataset_format"], - "nnunet_root_folder": clients[client_id]["nnunet_root_folder"], - "nnunet_config": { - "dataset_name_or_id": experiment["dataset_name_or_id"], - "experiment_name": experiment["experiment_name"], - }, - "client_name": clients[client_id]["client_name"], - "tracking_uri": experiment["tracking_uri"], - }, - }, - } - ], - } - - if "nnunet_plans" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] - - if "nnunet_trainer" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] - - if "bundle_root" in clients[client_id]: - client["executors"][0]["executor"]["args"]["bundle_root"] = clients[client_id]["bundle_root"] - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def prepare_bundle_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Prepare the configuration files for the nnUNet bundle and generate the job configurations for NVFlare. - - Parameters - ---------- - clients : dict - A dictionary containing client information. Keys are client IDs and values are dictionaries with client details. - experiment : dict - A dictionary containing experiment details such as 'experiment_name', 'tracking_uri', and optional - configurations like 'bundle_extra_config', 'nnunet_plans', and 'nnunet_trainer'. - root_dir : str - The root directory where the configuration files and job directories will be created. - script_dir : str - The directory containing the necessary scripts for NVFlare. - nvflare_exec : str - The path to the NVFlare executable. 
- - Returns - ------- - None - """ - task_name = "prepare_bundle" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = {"description": "Prepare nnUNet Bundle", "client_category": "Executor", "controller_type": "server"} - - meta = { - "name": f"{task_name}_nnUNet", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": 1, - "mandatory_clients": list(clients.keys()), - } - for client_id in clients: - meta["deploy_map"][f"{task_name}-client-{client_id}"] = [client_id] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "server": {"heart_beat_timeout": 600}, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - { - "id": "nnunet_processor", - "path": "monai.nvflare.response_processor.nnUNetBundlePrepareProcessor", - "args": {}, - } - ], - "workflows": [ - { - "id": "broadcast_and_process", - "name": "BroadcastAndProcess", - "args": { - "processor": "nnunet_processor", - "min_responses_required": 0, - "wait_time_after_min_received": 10, - "task_name": task_name, - "timeout": 600000, - }, - } - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_id in clients: - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "components": [], - "executors": [ - { - "tasks": [task_name], - "executor": { - "path": "monai.nvflare.nnunet_executor.nnUNetExecutor", - "args": { - "nnunet_config": {"experiment_name": experiment["experiment_name"]}, - "client_name": clients[client_id]["client_name"], - "tracking_uri": experiment["tracking_uri"], - }, - }, - } - ], - } - - if "bundle_extra_config" in experiment: - client["executors"][0]["executor"]["args"]["train_extra_configs"] = experiment["bundle_extra_config"] - if "nnunet_plans" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_plans"] = experiment["nnunet_plans"] - - if "nnunet_trainer" in experiment: - client["executors"][0]["executor"]["args"]["nnunet_config"]["nnunet_trainer"] = experiment["nnunet_trainer"] - - if "bundle_root" in clients[client_id]: - client["executors"][0]["executor"]["args"]["bundle_root"] = clients[client_id]["bundle_root"] - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}").mkdir( - parents=True, exist_ok=True - ) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-client-{client_id}", "config_fed_client.conf"), - "w", - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def train_fl_config(clients, experiment, root_dir, script_dir, nvflare_exec): - """ - Generate federated learning job configurations for NVFlare. 
- - Parameters - ---------- - clients : dict - Dictionary containing client names and their configurations. - experiment : dict - Dictionary containing experiment parameters such as number of rounds and local epochs. - root_dir : str - Root directory where the job configurations will be saved. - script_dir : str - Directory containing the necessary scripts for NVFlare. - nvflare_exec : str - Path to the NVFlare executable. - - Returns - ------- - None - """ - task_name = "train_fl_nnunet_bundle" - Path(root_dir).joinpath(task_name).mkdir(parents=True, exist_ok=True) - - info = { - "description": "Federated Learning with nnUNet-MONAI Bundle", - "client_category": "Executor", - "controller_type": "server", - } - - meta = { - "name": f"{task_name}", - "resource_spec": {}, - "deploy_map": {f"{task_name}-server": ["server"]}, - "min_clients": len(list(clients.keys())), - "mandatory_clients": list(clients.keys()), - } - - for client_name, client_config in clients.items(): - meta["deploy_map"][f"{task_name}-{client_name}"] = [client_name] - - with open(Path(root_dir).joinpath(task_name).joinpath("info.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(info))) - f.write("\n}") - - with open(Path(root_dir).joinpath(task_name).joinpath("meta.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(meta))) - f.write("\n}") - - server = { - "format_version": 2, - "min_clients": len(list(clients.keys())), - "num_rounds": experiment["num_rounds"], - "task_data_filters": [], - "task_result_filters": [], - "components": [ - { - "id": "persistor", - "path": "monai_nvflare.monai_bundle_persistor.MonaiBundlePersistor", - "args": { - "bundle_root": experiment["server_bundle_root"], - "config_train_filename": "configs/train.yaml", - "network_def_key": "network_def_fl", - }, - }, - {"id": "shareable_generator", "name": "FullModelShareableGenerator", "args": {}}, - { - "id": "aggregator", - "name": "InTimeAccumulateWeightedAggregator", - "args": {"expected_data_kind": "WEIGHT_DIFF"}, - }, - {"id": "model_selector", "name": "IntimeModelSelector", "args": {}}, - {"id": "model_locator", "name": "PTFileModelLocator", "args": {"pt_persistor_id": "persistor"}}, - {"id": "json_generator", "name": "ValidationJsonGenerator", "args": {}}, - ], - "workflows": [ - { - "id": "scatter_gather_ctl", - "name": "ScatterAndGather", - "args": { - "min_clients": "{min_clients}", - "num_rounds": "{num_rounds}", - "start_round": experiment["start_round"], - "wait_time_after_min_received": 10, - "aggregator_id": "aggregator", - "persistor_id": "persistor", - "shareable_generator_id": "shareable_generator", - "train_task_name": "train", - "train_timeout": 0, - }, - }, - { - "id": "cross_site_model_eval", - "name": "CrossSiteModelEval", - "args": { - "model_locator_id": "model_locator", - "submit_model_timeout": 600, - "validation_timeout": 6000, - "cleanup_models": True, - }, - }, - ], - } - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server").mkdir(parents=True, exist_ok=True) - with open(Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-server", "config_fed_server.conf"), "w") as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(server))) - f.write("\n}") - - for client_name, client_config in clients.items(): - client = { - "format_version": 2, - "task_result_filters": [], - "task_data_filters": [], - "executors": [ - { - "tasks": ["train", "submit_model", "validate"], - "executor": { - "id": "executor", - 
# "path": "monai_algo.ClientnnUNetAlgoExecutor", - "path": "monai_nvflare.client_algo_executor.ClientAlgoExecutor", - "args": {"client_algo_id": "client_algo", "key_metric": "Val_Dice"}, - }, - } - ], - "components": [ - { - "id": "client_algo", - # "path": "monai_algo.MonaiAlgonnUNet", - "path": "monai.fl.client.monai_algo.MonaiAlgo", - "args": { - "bundle_root": client_config["bundle_root"], - "config_train_filename": "configs/train.yaml", - "save_dict_key": "network_weights", - "local_epochs": experiment["local_epochs"], - "train_kwargs": {"nnunet_root_folder": client_config["nnunet_root_folder"]}, - }, - } - ], - } - - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-{client_name}").mkdir(parents=True, exist_ok=True) - with open( - Path(root_dir).joinpath(task_name).joinpath(f"{task_name}-{client_name}", "config_fed_client.conf"), "w" - ) as f: - f.write("{\n") - f.write(HOCONConverter.to_hocon(ConfigFactory.from_dict(client))) - f.write("\n}") - - subprocess.run( - [ - nvflare_exec, - "job", - "create", - "-j", - Path(root_dir).joinpath("jobs", task_name), - "-w", - Path(root_dir).joinpath(task_name), - "-sd", - script_dir, - "--force", - ] - ) - - -def generate_configs(client_files, experiment_file, script_dir, job_dir, nvflare_exec="nvflare"): - """ - Generate configuration files for NVFlare job. - - Parameters - ---------- - client_files : list of str - List of file paths to client configuration files. - experiment_file : str - File path to the experiment configuration file. - script_dir : str - Directory path where the scripts are located. - job_dir : str - Directory path where the job configurations will be saved. - nvflare_exec : str, optional - NVFlare executable command, by default "nvflare". - - Returns - ------- - None - """ - clients = {} - for client_id in client_files: - with open(client_id) as f: - client_name = Path(client_id).name - clients[client_name.split(".")[0]] = yaml.safe_load(f) - - with open(experiment_file) as f: - experiment = yaml.safe_load(f) - - check_client_packages_config(clients, experiment, job_dir, script_dir, nvflare_exec) - prepare_config(clients, experiment, job_dir, script_dir, nvflare_exec) - plan_and_preprocess_config(clients, experiment, job_dir, script_dir, nvflare_exec) - preprocess_config(clients, experiment, job_dir, script_dir, nvflare_exec) - train_config(clients, experiment, job_dir, script_dir, nvflare_exec) - prepare_bundle_config(clients, experiment, job_dir, script_dir, nvflare_exec) - train_fl_config(clients, experiment, job_dir, script_dir, nvflare_exec) diff --git a/monai/nvflare/nvflare_nnunet.py b/monai/nvflare/nvflare_nnunet.py deleted file mode 100644 index 3325d26ec8..0000000000 --- a/monai/nvflare/nvflare_nnunet.py +++ /dev/null @@ -1,696 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from __future__ import annotations - -import json -import logging -import multiprocessing -import os -import pathlib -import random -import re -import shutil -import subprocess -from importlib.metadata import version -from pathlib import Path - -import mlflow -import numpy as np -import pandas as pd -import psutil -import yaml - -import monai -from monai.apps.nnunet import nnUNetV2Runner -from monai.apps.nnunet.nnunet_bundle import convert_monai_bundle_to_nnunet -from monai.bundle import ConfigParser - - -def train( - nnunet_root_dir, - experiment_name, - client_name, - tracking_uri, - dataset_name_or_id, - trainer_class_name="nnUNetTrainer", - nnunet_plans_name="nnUNetPlans", - run_with_bundle=False, - fold=0, - bundle_root=None, - mlflow_token=None, -): - """ - - Train a nnUNet model and log metrics to MLflow. - - Parameters - ---------- - nnunet_root_dir : str - Root directory for nnUNet. - experiment_name : str - Name of the MLflow experiment. - client_name : str - Name of the client. - tracking_uri : str - URI for MLflow tracking server. - dataset_name_or_id : str - Name or ID of the dataset. - trainer_class_name : str, optional - Name of the nnUNet trainer class, by default "nnUNetTrainer". - nnunet_plans_name : str, optional - Name of the nnUNet plans, by default "nnUNetPlans". - run_with_bundle : bool, optional - Whether to run with MONAI bundle, by default False. - fold : int, optional - Fold number for cross-validation, by default 0. - bundle_root : str, optional - Root directory for MONAI bundle, by default None. - mlflow_token : str, optional - Token for MLflow authentication, by default None. - - Returns - ------- - dict - Dictionary containing validation summary metrics. - """ - data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") - runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) - - if not run_with_bundle: - runner.train_single_model(config="3d_fullres", fold=fold) - else: - os.environ["BUNDLE_ROOT"] = bundle_root - os.environ["PYTHONPATH"] = os.environ["PYTHONPATH"] + ":" + bundle_root - monai.bundle.run( - config_file=Path(bundle_root).joinpath("configs/train.yaml"), - bundle_root=bundle_root, - nnunet_trainer_class_name=trainer_class_name, - mlflow_experiment_name=experiment_name, - mlflow_run_name="run_" + client_name, - tracking_uri=tracking_uri, - fold_id=fold, - nnunet_root_folder=nnunet_root_dir, - ) - nnunet_config = {"dataset_name_or_id": dataset_name_or_id, "nnunet_trainer": trainer_class_name} - convert_monai_bundle_to_nnunet(nnunet_config, bundle_root) - runner.train_single_model(config="3d_fullres", fold=fold, val="") - - if mlflow_token is not None: - os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token - if tracking_uri is not None: - mlflow.set_tracking_uri(tracking_uri) - - try: - mlflow.create_experiment(experiment_name) - except Exception as e: - print(e) - mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) - - filter = f""" - tags."client" = "{client_name}" - """ - - runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) - - validation_summary = os.path.join( - runner.nnunet_results, - runner.dataset_name, - f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", - f"fold_{fold}", - "validation", - "summary.json", - ) - - dataset_file = os.path.join( - runner.nnunet_results, - runner.dataset_name, - f"{trainer_class_name}__{nnunet_plans_name}__3d_fullres", - 
"dataset.json", - ) - - with open(dataset_file, "r") as f: - dataset_dict = json.load(f) - labels = dataset_dict["labels"] - labels = {str(v): k for k, v in labels.items()} - - with open(validation_summary, "r") as f: - validation_summary_dict = json.load(f) - - if len(runs) == 0: - with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): - for label in validation_summary_dict["mean"]: - for metric in validation_summary_dict["mean"][label]: - label_name = labels[label] - mlflow.log_metric(f"{label_name}_{metric}", float(validation_summary_dict["mean"][label][metric])) - - else: - with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): - for label in validation_summary_dict["mean"]: - for metric in validation_summary_dict["mean"][label]: - label_name = labels[label] - mlflow.log_metric(f"{label_name}_{metric}", float(validation_summary_dict["mean"][label][metric])) - - return validation_summary_dict - - -def preprocess(nnunet_root_dir, dataset_name_or_id, nnunet_plans_file_path=None, trainer_class_name="nnUNetTrainer"): - """ - Preprocess the dataset for nnUNet training. - - Parameters - ---------- - nnunet_root_dir : str - The root directory of the nnUNet project. - dataset_name_or_id : str or int - The name or ID of the dataset to preprocess. - nnunet_plans_file_path : Path, optional - The file path to the nnUNet plans file. If None, default plans will be used. Default is None. - trainer_class_name : str, optional - The name of the trainer class to use. Default is "nnUNetTrainer". - - Returns - ------- - dict - The nnUNet plans dictionary. - """ - - data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") - runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) - - nnunet_plans_name = nnunet_plans_file_path.name.split(".")[0] - from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name - - dataset_name = maybe_convert_to_dataset_name(int(dataset_name_or_id)) - - Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name).mkdir(parents=True, exist_ok=True) - - shutil.copy( - Path(nnunet_root_dir).joinpath("nnUNet_raw_data_base", dataset_name, "dataset.json"), - Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name, "dataset.json"), - ) - if nnunet_plans_file_path is not None: - with open(nnunet_plans_file_path, "r") as f: - nnunet_plans = json.load(f) - nnunet_plans["original_dataset_name"] = nnunet_plans["dataset_name"] - nnunet_plans["dataset_name"] = dataset_name - json.dump( - nnunet_plans, - open( - Path(nnunet_root_dir).joinpath("nnUNet_preprocessed", dataset_name, f"{nnunet_plans_name}.json"), - "w", - ), - indent=4, - ) - - runner.extract_fingerprints(npfp=2, verify_dataset_integrity=True) - runner.preprocess(c=["3d_fullres"], n_proc=[2], overwrite_plans_name=nnunet_plans_name) - - return nnunet_plans - - -def plan_and_preprocess( - nnunet_root_dir, - dataset_name_or_id, - client_name, - experiment_name, - tracking_uri, - mlflow_token=None, - nnunet_plans_name="nnUNetPlans", - trainer_class_name="nnUNetTrainer", -): - """ - Plan and preprocess the dataset using nnUNetV2Runner and log the plans to MLflow. - - Parameters - ---------- - nnunet_root_dir : str - The root directory of nnUNet. - dataset_name_or_id : str or int - The name or ID of the dataset to be processed. - client_name : str - The name of the client. - experiment_name : str - The name of the MLflow experiment. 
- tracking_uri : str - The URI of the MLflow tracking server. - mlflow_token : str, optional - The token for MLflow authentication (default is None). - nnunet_plans_name : str, optional - The name of the nnUNet plans (default is "nnUNetPlans"). - trainer_class_name : str, optional - The name of the nnUNet trainer class (default is "nnUNetTrainer"). - - Returns - ------- - dict - The nnUNet plans as a dictionary. - """ - - data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") - - runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir) - - runner.plan_and_process( - npfp=2, verify_dataset_integrity=True, c=["3d_fullres"], n_proc=[2], overwrite_plans_name=nnunet_plans_name - ) - - preprocessed_folder = runner.nnunet_preprocessed - - from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name - - dataset_name = maybe_convert_to_dataset_name(int(dataset_name_or_id)) - - with open(Path(preprocessed_folder).joinpath(f"{dataset_name}", nnunet_plans_name + ".json"), "r") as f: - nnunet_plans = json.load(f) - - if mlflow_token is not None: - os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token - if tracking_uri is not None: - mlflow.set_tracking_uri(tracking_uri) - - try: - mlflow.create_experiment(experiment_name) - except Exception as e: - print(e) - mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) - - filter = f""" - tags."client" = "{client_name}" - """ - - runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) - - if len(runs) == 0: - with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): - mlflow.log_dict(nnunet_plans, nnunet_plans_name + ".json") - - else: - with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): - mlflow.log_dict(nnunet_plans, nnunet_plans_name + ".json") - - return nnunet_plans - - -def prepare_data_folder( - data_dir, - nnunet_root_dir, - dataset_name_or_id, - modality_dict, - experiment_name, - client_name, - dataset_format, - modality_list = None, - tracking_uri=None, - mlflow_token=None, - subfolder_suffix=None, - patient_id_in_file_identifier=True, - trainer_class_name="nnUNetTrainer", -): - """ - Prepare the data folder for nnUNet training and log the data to MLflow. - - Parameters - ---------- - data_dir : str - Directory containing the dataset. - nnunet_root_dir : str - Root directory for nnUNet. - dataset_name_or_id : str - Name or ID of the dataset. - modality_dict : dict - Dictionary mapping modality IDs to file suffixes. - experiment_name : str - Name of the MLflow experiment. - client_name : str - Name of the client. - dataset_format : str - Format of the dataset. Supported formats are "subfolders", "decathlon", and "nnunet". - tracking_uri : str, optional - URI for MLflow tracking server. - modality_list : list, optional - List of modalities. Default is None. - mlflow_token : str, optional - Token for MLflow authentication. - subfolder_suffix : str, optional - Suffix for subfolder names. - patient_id_in_file_identifier : bool, optional - Whether patient ID is included in file identifier. Default is True. - trainer_class_name : str, optional - Name of the nnUNet trainer class. Default is "nnUNetTrainer". - - Returns - ------- - dict - Dictionary containing the training and testing data lists. 
- """ - if dataset_format == "subfolders": - if subfolder_suffix is not None: - data_list = { - "training": [ - { - modality_id: ( - str( - pathlib.Path(f.name).joinpath( - f.name[: -len(subfolder_suffix)] + modality_dict[modality_id] - ) - ) - if patient_id_in_file_identifier - else str(pathlib.Path(f.name).joinpath(modality_dict[modality_id])) - ) - for modality_id in modality_dict - } - for f in os.scandir(data_dir) - if f.is_dir() - ], - "testing": [], - } - else: - data_list = { - "training": [ - { - modality_id: ( - str(pathlib.Path(f.name).joinpath(f.name + modality_dict[modality_id])) - if patient_id_in_file_identifier - else str(pathlib.Path(f.name).joinpath(modality_dict[modality_id])) - ) - for modality_id in modality_dict - } - for f in os.scandir(data_dir) - if f.is_dir() - ], - "testing": [], - } - elif dataset_format == "decathlon" or dataset_format == "nnunet": - cases = [] - - for f in os.scandir(Path(data_dir).joinpath("imagesTr")): - if f.is_file(): - for modality_suffix in list(modality_dict.values()): - if f.name.endswith(modality_suffix) and modality_suffix != ".nii.gz": - cases.append(f.name[: -len(modality_suffix)]) - if len(np.unique(list(modality_dict.values()))) == 1 and ".nii.gz" in list(modality_dict.values()): - cases.append(f.name[: -len(".nii.gz")]) - cases = np.unique(cases) - data_list = { - "training": [ - { - modality_id: str(Path("imagesTr").joinpath(case + modality_dict[modality_id])) - for modality_id in modality_dict - if modality_id != "label" - } - for case in cases - ], - "testing": [], - } - for idx, case in enumerate(data_list["training"]): - modality_id = list(modality_dict.keys())[0] - case_id = Path(case[modality_id]).name[: -len(modality_dict[modality_id])] - data_list["training"][idx]["label"] = str(Path("labelsTr").joinpath(case_id + modality_dict["label"])) - else: - raise ValueError("Dataset format not supported") - - for idx, train_case in enumerate(data_list["training"]): - for modality_id in modality_dict: - data_list["training"][idx][modality_id + "_is_file"] = ( - Path(data_dir).joinpath(data_list["training"][idx][modality_id]).is_file() - ) - if "image" not in data_list["training"][idx] and modality_id != "label": - data_list["training"][idx]["image"] = data_list["training"][idx][modality_id] - data_list["training"][idx]["fold"] = 0 - - random.seed(42) - random.shuffle(data_list["training"]) - - data_list["testing"] = [data_list["training"][0]] - - num_folds = 5 - fold_size = len(data_list["training"]) // num_folds - for i in range(num_folds): - for j in range(fold_size): - data_list["training"][i * fold_size + j]["fold"] = i - - datalist_file = Path(data_dir).joinpath(f"{experiment_name}_folds.json") - with open(datalist_file, "w", encoding="utf-8") as f: - json.dump(data_list, f, ensure_ascii=False, indent=4) - - os.makedirs(nnunet_root_dir, exist_ok=True) - - if modality_list is None: - modality_list = [k for k in modality_dict.keys() if k != "label"] - - data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml") - data_src = { - "modality": modality_list, - "dataset_name_or_id": dataset_name_or_id, - "datalist": str(datalist_file), - "dataroot": str(data_dir), - } - - ConfigParser.export_config_file(data_src, data_src_cfg) - - if dataset_format != "nnunet": - runner = nnUNetV2Runner( - input_config=data_src_cfg, trainer_class_name=trainer_class_name, work_dir=nnunet_root_dir - ) - runner.convert_dataset() - else: - ... 
- - if mlflow_token is not None: - os.environ["MLFLOW_TRACKING_TOKEN"] = mlflow_token - if tracking_uri is not None: - mlflow.set_tracking_uri(tracking_uri) - - try: - mlflow.create_experiment(experiment_name) - mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) - except Exception as e: - print(e) - mlflow.set_experiment(experiment_id=(mlflow.get_experiment_by_name(experiment_name).experiment_id)) - - filter = f""" - tags."client" = "{client_name}" - """ - - runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=filter, order_by=["start_time DESC"]) - - try: - if len(runs) == 0: - with mlflow.start_run(run_name=f"run_{client_name}", tags={"client": client_name}): - mlflow.log_table(pd.DataFrame.from_records(data_list["training"]), f"{client_name}_train.json") - else: - with mlflow.start_run(run_id=runs.iloc[0].run_id, tags={"client": client_name}): - mlflow.log_table(pd.DataFrame.from_records(data_list["training"]), f"{client_name}_train.json") - except (BrokenPipeError, ConnectionError) as e: - logging.error(f"Failed to log data to MLflow: {e}") - - return data_list - - -def check_packages(packages): - """ - Check if the specified packages are installed and return a report. - - Parameters - ---------- - packages : list - A list of package names (str) or dictionaries with keys "import_name" and "package_name". - - Returns - ------- - dict - A dictionary where the keys are package names and the values are strings indicating whether - the package is installed and its version if applicable. - - Examples - -------- - >>> check_packages(["numpy", "nonexistent_package"]) - {'numpy': 'numpy 1.21.0 is installed.', 'nonexistent_package': 'nonexistent_package is not installed.'} - >>> check_packages([{"import_name": "torch", "package_name": "torch"}]) - {'torch': 'torch 1.9.0 is installed.'} - """ - report = {} - for package in packages: - try: - if isinstance(package, dict): - __import__(package["import_name"]) - package_version = version(package["package_name"]) - name = package["package_name"] - print(f"{name} {package_version} is installed.") - report[name] = f"{name} {package_version} is installed." - else: - - __import__(package) - package_version = version(package) - print(f"{package} {package_version} is installed.") - report[package] = f"{package} {package_version} is installed." - - except ImportError: - print(f"{package} is not installed.") - report[package] = f"{package} is not installed." - - return report - - -def check_host_config(): - """ - Collects and returns the host configuration details including GPU, CPU, and memory information. 
- - Returns - ------- - dict - A dictionary containing the following keys and their corresponding values: - - Config values from `monai.config.deviceconfig.get_config_values()` - - Optional config values from `monai.config.deviceconfig.get_optional_config_values()` - - GPU information including number of GPUs, CUDA version, cuDNN version, and GPU names and memory - - CPU core count - - Total memory in GB - - Memory usage percentage - """ - params_dict = {} - config_values = monai.config.deviceconfig.get_config_values() - for k in config_values: - params_dict[re.sub("[()]", " ", str(k))] = config_values[k] - optional_config_values = monai.config.deviceconfig.get_optional_config_values() - - for k in optional_config_values: - params_dict[re.sub("[()]", " ", str(k))] = optional_config_values[k] - - gpu_info = monai.config.deviceconfig.get_gpu_info() - allowed_keys = ["Num GPUs", "Has Cuda", "CUDA Version", "cuDNN enabled", "cuDNN Version"] - for i in range(gpu_info["Num GPUs"]): - allowed_keys.append(f"GPU {i} Name") - allowed_keys.append(f"GPU {i} Total memory GB ") - - for k in gpu_info: - if re.sub("[()]", " ", str(k)) in allowed_keys: - params_dict[re.sub("[()]", " ", str(k))] = str(gpu_info[k]) - - with open("nvidia-smi.log", "w") as f_e: - subprocess.run("nvidia-smi", stderr=f_e, stdout=f_e) - - params_dict["CPU_Cores"] = multiprocessing.cpu_count() - - vm = psutil.virtual_memory() - - params_dict["Total Memory"] = vm.total / (1024 * 1024 * 1024) - params_dict["Memory Used %"] = vm.percent - - return params_dict - - -def prepare_bundle(bundle_config, train_extra_configs=None): - """ - Prepare the bundle configuration for training and evaluation. - - Parameters - ---------- - bundle_config : dict - Dictionary containing the bundle configuration. Expected keys are: - - "bundle_root": str, root directory of the bundle. - - "tracking_uri": str, URI for tracking. - - "mlflow_experiment_name": str, name of the MLflow experiment. - - "mlflow_run_name": str, name of the MLflow run. - - "nnunet_plans_identifier": str, optional, identifier for nnUNet plans. - - "nnunet_trainer_class_name": str, optional, class name for nnUNet trainer. - train_extra_configs : dict, optional - Additional configurations for training. If provided, expected keys are: - - "resume_epoch": int, epoch to resume training from. - - Any other key-value pairs to be added to the training configuration. 
- - Returns - ------- - None - """ - - with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml")) as f: - train_config = yaml.safe_load(f) - train_config["bundle_root"] = bundle_config["bundle_root"] - train_config["tracking_uri"] = bundle_config["tracking_uri"] - train_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"] - train_config["mlflow_run_name"] = bundle_config["mlflow_run_name"] - - train_config["data_src_cfg"] = "$@nnunet_root_folder+'/data_src_cfg.yaml'" - train_config["runner"] = { - "_target_": "nnUNetV2Runner", - "input_config": "$@data_src_cfg", - "trainer_class_name": "@nnunet_trainer_class_name", - "work_dir": "@nnunet_root_folder", - } - - train_config["network"] = "$@nnunet_trainer.network._orig_mod" - - train_handlers = train_config["train_handlers"]["handlers"] - - for idx, handler in enumerate(train_handlers): - if handler["_target_"] == "ValidationHandler": - train_handlers.pop(idx) - break - - train_config["train_handlers"]["handlers"] = train_handlers - - if train_extra_configs is not None and "resume_epoch" in train_extra_configs: - resume_epoch = train_extra_configs["resume_epoch"] - train_config["initialize"] = [ - "$monai.utils.set_determinism(seed=123)", - "$@runner.dataset_name_or_id", - f"$src.trainer.reload_checkpoint(@train#trainer, {resume_epoch}, @iterations, @ckpt_dir, @lr_scheduler)", - ] - else: - train_config["initialize"] = ["$monai.utils.set_determinism(seed=123)", "$@runner.dataset_name_or_id"] - - if "Val_Dice" in train_config["val_key_metric"]: - train_config["val_key_metric"] = {"Val_Dice_Local": train_config["val_key_metric"]["Val_Dice"]} - - if "Val_Dice_per_class" in train_config["val_additional_metrics"]: - train_config["val_additional_metrics"] = { - "Val_Dice_per_class_Local": train_config["val_additional_metrics"]["Val_Dice_per_class"] - } - if "nnunet_plans_identifier" in bundle_config: - train_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"] - - if "nnunet_trainer_class_name" in bundle_config: - train_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"] - - if train_extra_configs is not None: - for key in train_extra_configs: - train_config[key] = train_extra_configs[key] - - with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.json"), "w") as f: - json.dump(train_config, f) - - with open(Path(bundle_config["bundle_root"]).joinpath("configs", "train.yaml"), "w") as f: - yaml.dump(train_config, f) - - if not Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml").exists(): - shutil.copy( - Path(bundle_config["bundle_root"]).joinpath("nnUNet", "evaluator", "evaluator.yaml"), - Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"), - ) - - with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml")) as f: - evaluate_config = yaml.safe_load(f) - evaluate_config["bundle_root"] = bundle_config["bundle_root"] - - evaluate_config["tracking_uri"] = bundle_config["tracking_uri"] - evaluate_config["mlflow_experiment_name"] = bundle_config["mlflow_experiment_name"] - evaluate_config["mlflow_run_name"] = bundle_config["mlflow_run_name"] - - if "nnunet_plans_identifier" in bundle_config: - evaluate_config["nnunet_plans_identifier"] = bundle_config["nnunet_plans_identifier"] - if "nnunet_trainer_class_name" in bundle_config: - evaluate_config["nnunet_trainer_class_name"] = bundle_config["nnunet_trainer_class_name"] - - with 
open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.json"), "w") as f: - json.dump(evaluate_config, f) - - with open(Path(bundle_config["bundle_root"]).joinpath("configs", "evaluate.yaml"), "w") as f: - yaml.dump(evaluate_config, f) diff --git a/monai/nvflare/response_processor.py b/monai/nvflare/response_processor.py deleted file mode 100644 index a02d307220..0000000000 --- a/monai/nvflare/response_processor.py +++ /dev/null @@ -1,342 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -from nvflare.apis.client import Client -from nvflare.apis.dxo import DataKind, from_shareable -from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable -from nvflare.app_common.abstract.response_processor import ResponseProcessor - - -class nnUNetPrepareProcessor(ResponseProcessor): - """ - A processor class for preparing nnUNet data in a federated learning context. - - Methods - ------- - __init__(): - Initializes the nnUNetPrepareProcessor with an empty data dictionary. - create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable: - Creates and returns a Shareable object for the given task name. - process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - Processes the response from a client. Validates the response and updates the data dictionary if valid. - final_process(fl_ctx: FLContext) -> bool: - Finalizes the processing by setting the client data dictionary in the federated learning context. - """ - - def __init__(self): - ResponseProcessor.__init__(self) - self.data_dict = {} - - def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: - return Shareable() - - def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - if not isinstance(response, Shareable): - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", - ) - return False - - try: - dxo = from_shareable(response) - - except Exception: - self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") - return False - - if dxo.data_kind != DataKind.COLLECTION: - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " - f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", - ) - return False - - data_dict = dxo.data - - if not data_dict: - self.log_error(fl_ctx, f"No dataset_dict found from client {client.name}") - return False - - self.data_dict[client.name] = data_dict - - return True - - def final_process(self, fl_ctx: FLContext) -> bool: - if not self.data_dict: - self.log_error(fl_ctx, "no data_prepare_dict from clients") - return False - - # must set sticky to True so other controllers can get it! 
- fl_ctx.set_prop("client_data_dict", self.data_dict, private=True, sticky=True) - return True - - -class nnUNetPackageReportProcessor(ResponseProcessor): - """ - A processor for handling nnUNet package reports in a federated learning context. - - Attributes - ---------- - package_report : dict - A dictionary to store package reports from clients. - - Methods - ------- - create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable - Creates task data for a given task name and federated learning context. - process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool - Processes the response from a client for a given task name and federated learning context. - final_process(fl_ctx: FLContext) -> bool - Final processing step to handle the collected package reports. - """ - - def __init__(self): - ResponseProcessor.__init__(self) - self.package_report = {} - - def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: - return Shareable() - - def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - if not isinstance(response, Shareable): - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", - ) - return False - - try: - dxo = from_shareable(response) - - except Exception: - self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") - return False - - if dxo.data_kind != DataKind.COLLECTION: - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " - f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", - ) - return False - - package_report = dxo.data - - if not package_report: - self.log_error(fl_ctx, f"No package_report found from client {client.name}") - return False - - self.package_report[client.name] = package_report - return True - - def final_process(self, fl_ctx: FLContext) -> bool: - if not self.package_report: - self.log_error(fl_ctx, "no plan_dict from client") - return False - - # must set sticky to True so other controllers can get it! - fl_ctx.set_prop("package_report", self.package_report, private=True, sticky=True) - return True - - -class nnUNetPlanProcessor(ResponseProcessor): - """ - nnUNetPlanProcessor is a class that processes responses from clients in a federated learning context. - It inherits from the ResponseProcessor class and is responsible for handling and validating the - responses, extracting the necessary data, and storing it for further use. - - Attributes - ---------- - plan_dict : dict - A dictionary to store the plan data received from clients. - - Methods - ------- - create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable - Creates and returns a Shareable object for the given task name. - process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool - Processes the response from a client, validates it, and stores the plan data if valid. - final_process(fl_ctx: FLContext) -> bool - Finalizes the processing by setting the plan data in the federated learning context. 
- """ - - def __init__(self): - ResponseProcessor.__init__(self) - self.plan_dict = {} - - def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: - return Shareable() - - def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - if not isinstance(response, Shareable): - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", - ) - return False - - try: - dxo = from_shareable(response) - - except Exception: - self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") - return False - - if dxo.data_kind != DataKind.COLLECTION: - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " - f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", - ) - return False - - plan_dict = dxo.data - - if not plan_dict: - self.log_error(fl_ctx, f"No plan_dict found from client {client.name}") - return False - - self.plan_dict[client.name] = plan_dict - - return True - - def final_process(self, fl_ctx: FLContext) -> bool: - if not self.plan_dict: - self.log_error(fl_ctx, "no plan_dict from client") - return False - - # must set sticky to True so other controllers can get it! - fl_ctx.set_prop("nnunet_plans", self.plan_dict, private=True, sticky=True) - return True - - -class nnUNetTrainProcessor(ResponseProcessor): - """ - A processor class for handling training responses in the nnUNet framework. - - Attributes - ---------- - val_summary_dict : dict - A dictionary to store validation summaries from clients. - Methods - ------- - create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable - Creates task data for a given task name and FLContext. - process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool - Processes the response from a client for a given task name and FLContext. - final_process(fl_ctx: FLContext) -> bool - Final processing step to handle the collected validation summaries. - """ - - def __init__(self): - ResponseProcessor.__init__(self) - self.val_summary_dict = {} - - def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: - return Shareable() - - def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - if not isinstance(response, Shareable): - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", - ) - return False - - try: - dxo = from_shareable(response) - - except Exception: - self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") - return False - - if dxo.data_kind != DataKind.COLLECTION: - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " - f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", - ) - return False - - val_summary_dict = dxo.data - - if not val_summary_dict: - self.log_error(fl_ctx, f"No val_summary_dict found from client {client.name}") - return False - - self.val_summary_dict[client.name] = val_summary_dict - - return True - - def final_process(self, fl_ctx: FLContext) -> bool: - if not self.val_summary_dict: - self.log_error(fl_ctx, "no val_summary_dict from client") - return False - - # must set sticky to True so other controllers can get it! 
- fl_ctx.set_prop("val_summary_dict", self.val_summary_dict, private=True, sticky=True) - return True - - -class nnUNetBundlePrepareProcessor(ResponseProcessor): - """ - A processor class for preparing nnUNet bundles in a federated learning context. - - Methods - ------- - __init__(): - Initializes the nnUNetBundlePrepareProcessor instance. - create_task_data(task_name: str, fl_ctx: FLContext) -> Shareable: - Creates task data for a given task name and federated learning context. - process_client_response(client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - Processes the response from a client and validates it. - final_process(fl_ctx: FLContext) -> bool: - Final processing step after all client responses have been processed. - """ - - def __init__(self): - ResponseProcessor.__init__(self) - - def create_task_data(self, task_name: str, fl_ctx: FLContext) -> Shareable: - return Shareable() - - def process_client_response(self, client: Client, task_name: str, response: Shareable, fl_ctx: FLContext) -> bool: - if not isinstance(response, Shareable): - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " f"response must be Shareable but got {type(response)}", - ) - return False - - try: - dxo = from_shareable(response) - - except Exception: - self.log_exception(fl_ctx, f"bad response from client {client.name}: " f"it does not contain DXO") - return False - - if dxo.data_kind != DataKind.COLLECTION: - self.log_error( - fl_ctx, - f"bad response from client {client.name}: " - f"data_kind should be DataKind.COLLECTION but got {dxo.data_kind}", - ) - return False - - return True - - def final_process(self, fl_ctx: FLContext) -> bool: - - return True From 6d3fb0c1538b0306ede659f0f90c6445fb846dd1 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Wed, 2 Apr 2025 14:45:34 +0000 Subject: [PATCH 66/67] Update torch version requirement in requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 77221491e6..ad394ce807 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -torch>=2.3.0a0; sys_platform != 'win32' +torch>=2.3.0; sys_platform != 'win32' torch>=2.4.1; sys_platform == 'win32' numpy>=1.24,<3.0 From fed603b0fcdd3b5b47fe6c3177c264eb7d7b5df4 Mon Sep 17 00:00:00 2001 From: Simone Bendazzoli Date: Thu, 3 Apr 2025 06:51:14 +0000 Subject: [PATCH 67/67] DCO Remediation Commit for Simone Bendazzoli I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 7d60fd74e52867c34178e4c530a01b6d2b0b9c07 I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 3b13218c349f123e6b44a5c02310d496afc216db I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: fbf6105cf884438fdd8f9059809b98bcfd48680e I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: d1035ca2a4ed9bf17badf56e5fcc682209a3b9bb I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 47798af8627d66daf5ebd962764aff81d723da87 I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 0578b22e5a49fdd97997529e1270fc14312c7b96 I, Simone Bendazzoli , hereby add my Signed-off-by to this commit: 6d3fb0c1538b0306ede659f0f90c6445fb846dd1 Signed-off-by: Simone Bendazzoli --- monai/apps/nnunet/nnunet_bundle.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index 14ed2636c5..e358cd4b99 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ 
b/monai/apps/nnunet/nnunet_bundle.py
@@ -119,12 +119,7 @@ def get_nnunet_trainer(
     from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint
 
     nnunet_trainer = get_trainer_from_args(
-        str(dataset_name_or_id),
-        configuration,
-        fold,
-        trainer_class_name,
-        plans_identifier,
-        device=torch.device(device),
+        str(dataset_name_or_id), configuration, fold, trainer_class_name, plans_identifier, device=torch.device(device)
     )
     if disable_checkpointing:
         nnunet_trainer.disable_checkpointing = disable_checkpointing
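As a closing illustration of how the pieces in this patch fit together, a generated federated-training job might be smoke-tested locally with NVFlare's FL simulator before being submitted to a real deployment; the job and workspace paths below are hypothetical and follow the earlier sketch:

```python
# Illustrative sketch only: the job and workspace paths are hypothetical.
# The NVFlare FL simulator runs the server and clients in a single process tree,
# which is convenient for checking the generated "train_fl_nnunet_bundle" job.
import subprocess
from pathlib import Path

job_folder = Path("jobs_workspace/jobs/train_fl_nnunet_bundle")  # produced by the job-generation helpers above

subprocess.run(
    ["nvflare", "simulator", str(job_folder), "-w", "/tmp/nnunet_fl_simulation", "-n", "2", "-t", "2"],
    check=True,
)
```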