@@ -49,25 +49,25 @@ def get_nnunet_trainer(
49
49
The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,
50
50
optimizer, loss function, DataLoader, etc.
51
51
52
- ```python
53
- from monai.apps import SupervisedTrainer
54
- from monai.bundle.nnunet import get_nnunet_trainer
55
-
56
- dataset_name_or_id = 'Task101_PROSTATE'
57
- fold = 0
58
- configuration = '3d_fullres'
59
- nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
60
-
61
- trainer = SupervisedTrainer(
62
- device=nnunet_trainer.device,
63
- max_epochs =nnunet_trainer.num_epochs ,
64
- train_data_loader =nnunet_trainer.dataloader_train ,
65
- network =nnunet_trainer.network ,
66
- optimizer =nnunet_trainer.optimizer ,
67
- loss_function =nnunet_trainer.loss_function ,
68
- epoch_length =nnunet_trainer.num_iterations_per_epoch ,
69
-
70
- ```
52
+ Example::
53
+
54
+ from monai.apps import SupervisedTrainer
55
+ from monai.bundle.nnunet import get_nnunet_trainer
56
+
57
+ dataset_name_or_id = 'Task101_PROSTATE'
58
+ fold = 0
59
+ configuration = '3d_fullres'
60
+ nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
61
+
62
+ trainer = SupervisedTrainer(
63
+ device =nnunet_trainer.device ,
64
+ max_epochs =nnunet_trainer.num_epochs ,
65
+ train_data_loader =nnunet_trainer.dataloader_train ,
66
+ network =nnunet_trainer.network ,
67
+ optimizer =nnunet_trainer.optimizer ,
68
+ loss_function =nnunet_trainer.loss_function ,
69
+ epoch_length=nnunet_trainer.num_iterations_per_epoch,
70
+ )
71
71
72
72
Parameters
73
73
----------
@@ -162,16 +162,19 @@ class ModelnnUNetWrapper(torch.nn.Module):
162
162
The folder path where the model and related files are stored.
163
163
model_name : str, optional
164
164
The name of the model file, by default "model.pt".
165
+
165
166
Attributes
166
167
----------
167
- predictor : object
168
- The predictor object used for inference.
168
+ predictor : nnUNetPredictor
169
+ The nnUNet predictor object used for inference.
169
170
network_weights : torch.nn.Module
170
171
The network weights of the model.
172
+
171
173
Methods
172
174
-------
173
175
forward(x)
174
176
Perform forward pass and prediction on the input data.
177
+
175
178
Notes
176
179
-----
177
180
This class integrates nnUNet model with MONAI framework by loading necessary configurations,
@@ -183,7 +186,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
183
186
self .predictor = predictor
184
187
185
188
model_training_output_dir = model_folder
186
- use_folds = "0"
189
+ use_folds = [ "0" ]
187
190
188
191
from nnunetv2 .utilities .plans_handling .plans_handler import PlansManager
189
192
@@ -290,27 +293,28 @@ def forward(self, x):
290
293
291
294
def get_nnunet_monai_predictor (model_folder , model_name = "model.pt" ):
292
295
"""
293
- Initializes and returns a nnUNetMONAIModelWrapper with a nnUNetPredictor.
296
+ Initializes and returns a ` nnUNetMONAIModelWrapper` containing the corresponding ` nnUNetPredictor` .
294
297
The model folder should contain the following files, created during training:
295
- - dataset.json: from the nnUNet results folder.
296
- - plans .json: from the nnUNet results folder.
297
- - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration
298
- (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`).
299
- - model.pt: The checkpoint file containing the model weights.
300
-
298
+
299
+ - dataset .json: from the nnUNet results folder
300
+ - plans.json: from the nnUNet results folder
301
+ - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`)
302
+ - model.pt: The checkpoint file containing the model weights.
303
+
301
304
The returned wrapper object can be used for inference with MONAI framework:
302
- ```python
303
- from monai.bundle.nnunet import get_nnunet_monai_predictor
305
+
306
+ Example::
307
+
308
+ from monai.bundle.nnunet import get_nnunet_monai_predictor
304
309
305
- model_folder = 'path/to/monai_bundle/model'
306
- model_name = 'model.pt'
307
- wrapper = get_nnunet_monai_predictor(model_folder, model_name)
310
+ model_folder = 'path/to/monai_bundle/model'
311
+ model_name = 'model.pt'
312
+ wrapper = get_nnunet_monai_predictor(model_folder, model_name)
308
313
309
- # Perform inference
310
- input_data = ...
311
- output = wrapper(input_data)
314
+ # Perform inference
315
+ input_data = ...
316
+ output = wrapper(input_data)
312
317
313
- ```
314
318
315
319
Parameters
316
320
----------
0 commit comments