14
14
import torch
15
15
16
16
from monai .metrics .utils import do_metric_reduction , ignore_background
17
- from monai .utils import MetricReduction , Weight , look_up_option
17
+ from monai .utils import MetricReduction , Weight , deprecated_arg , deprecated_arg_default , look_up_option
18
18
19
19
from .metric import CumulativeIterationMetric
20
20
21
21
22
22
class GeneralizedDiceScore (CumulativeIterationMetric ):
23
- """Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in:
23
+ """
24
+ Compute the Generalized Dice Score metric between tensors.
24
25
26
+ This metric is the complement of the Generalized Dice Loss defined in:
25
27
Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
26
- loss function for highly unbalanced segmentations. DLMIA 2017.
28
+ loss function for highly unbalanced segmentations. DLMIA 2017.
27
29
28
- The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first
29
- or batch-first tensors, i.e., CHW[D] or BCHW[D].
30
+ The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].
30
31
31
32
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
32
33
33
34
Args:
34
- include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the
35
+ include_background: Whether to include the background class (assumed to be in channel 0) in the
35
36
score computation. Defaults to True.
36
- reduction (str, optional): define mode of reduction to the metrics. Available reduction modes:
37
- {``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction.
38
- weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
37
+ reduction: Define mode of reduction to the metrics. Available reduction modes:
38
+ {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
39
+ ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
40
+ weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
39
41
ground truth volume into a weight factor. Defaults to ``"square"``.
40
42
41
43
Raises:
42
- ValueError: when the `weight_type ` is not one of {``"none"``, ``"mean"``, ``"sum"``} .
44
+ ValueError: When the `reduction ` is not one of MetricReduction enum .
43
45
"""
44
46
47
+ @deprecated_arg_default (
48
+ "reduction" ,
49
+ old_default = MetricReduction .MEAN_BATCH ,
50
+ new_default = MetricReduction .MEAN ,
51
+ since = "1.4.0" ,
52
+ replaced = "1.5.0" ,
53
+ msg_suffix = (
54
+ "Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, "
55
+ "If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'."
56
+ ),
57
+ )
45
58
def __init__ (
46
59
self ,
47
60
include_background : bool = True ,
@@ -50,79 +63,90 @@ def __init__(
50
63
) -> None :
51
64
super ().__init__ ()
52
65
self .include_background = include_background
53
- reduction_options = [
54
- "none" ,
55
- "mean_batch" ,
56
- "sum_batch" ,
57
- MetricReduction .NONE ,
58
- MetricReduction .MEAN_BATCH ,
59
- MetricReduction .SUM_BATCH ,
60
- ]
61
- self .reduction = reduction
62
- if self .reduction not in reduction_options :
63
- raise ValueError (f"reduction must be one of { reduction_options } " )
66
+ self .reduction = look_up_option (reduction , MetricReduction )
64
67
self .weight_type = look_up_option (weight_type , Weight )
68
+ self .sum_over_classes = self .reduction in {
69
+ MetricReduction .SUM ,
70
+ MetricReduction .MEAN ,
71
+ MetricReduction .MEAN_CHANNEL ,
72
+ MetricReduction .SUM_CHANNEL ,
73
+ }
65
74
66
75
def _compute_tensor (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor : # type: ignore[override]
67
- """Computes the Generalized Dice Score and returns a tensor with its per image values.
76
+ """
77
+ Computes the Generalized Dice Score and returns a tensor with its per image values.
68
78
69
79
Args:
70
- y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
80
+ y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
71
81
where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
72
- y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
82
+ y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
83
+
84
+ Returns:
85
+ torch.Tensor: Generalized Dice Score averaged across batch and class
73
86
74
87
Raises:
75
- ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
88
+ ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
76
89
"""
77
90
return compute_generalized_dice (
78
- y_pred = y_pred , y = y , include_background = self .include_background , weight_type = self .weight_type
91
+ y_pred = y_pred ,
92
+ y = y ,
93
+ include_background = self .include_background ,
94
+ weight_type = self .weight_type ,
95
+ sum_over_classes = self .sum_over_classes ,
79
96
)
80
97
98
+ @deprecated_arg (
99
+ "reduction" ,
100
+ since = "1.3.3" ,
101
+ removed = "1.7.0" ,
102
+ msg_suffix = "Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute" ,
103
+ )
81
104
def aggregate (self , reduction : MetricReduction | str | None = None ) -> torch .Tensor :
82
105
"""
83
106
Execute reduction logic for the output of `compute_generalized_dice`.
84
107
85
- Args:
86
- reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics.
87
- Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}.
88
- Defaults to ``"mean"``. If "none", will not do reduction.
108
+ Returns:
109
+ torch.Tensor: Aggregated metric value.
110
+
111
+ Raises:
112
+ ValueError: If the data to aggregate is not a PyTorch Tensor.
89
113
"""
90
114
data = self .get_buffer ()
91
115
if not isinstance (data , torch .Tensor ):
92
116
raise ValueError ("The data to aggregate must be a PyTorch Tensor." )
93
117
94
- # Validate reduction argument if specified
95
- if reduction is not None :
96
- reduction_options = ["none" , "mean" , "sum" , "mean_batch" , "sum_batch" ]
97
- if reduction not in reduction_options :
98
- raise ValueError (f"reduction must be one of { reduction_options } " )
99
-
100
118
# Do metric reduction and return
101
- f , _ = do_metric_reduction (data , reduction or self .reduction )
119
+ f , _ = do_metric_reduction (data , self .reduction )
102
120
103
121
return f
104
122
105
123
106
124
def compute_generalized_dice (
107
- y_pred : torch .Tensor , y : torch .Tensor , include_background : bool = True , weight_type : Weight | str = Weight .SQUARE
125
+ y_pred : torch .Tensor ,
126
+ y : torch .Tensor ,
127
+ include_background : bool = True ,
128
+ weight_type : Weight | str = Weight .SQUARE ,
129
+ sum_over_classes : bool = False ,
108
130
) -> torch .Tensor :
109
- """Computes the Generalized Dice Score and returns a tensor with its per image values.
131
+ """
132
+ Computes the Generalized Dice Score and returns a tensor with its per image values.
110
133
111
134
Args:
112
- y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
135
+ y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format
113
136
and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
114
137
remaining are the spatial dimensions.
115
- y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
116
- include_background (bool, optional): whether to include score computation on the first channel of the
138
+ y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
139
+ include_background: Whether to include score computation on the first channel of the
117
140
predicted output. Defaults to True.
118
141
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
119
142
transform ground truth volume into a weight factor. Defaults to ``"square"``.
143
+ sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation.
120
144
121
145
Returns:
122
- torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
146
+ torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
123
147
124
148
Raises:
125
- ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
149
+ ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
126
150
or `y_pred` and `y` don't have the same shape.
127
151
"""
128
152
# Ensure tensors have at least 3 dimensions and have the same shape
@@ -158,16 +182,21 @@ def compute_generalized_dice(
158
182
b [infs ] = 0
159
183
b [infs ] = torch .max (b )
160
184
161
- # Compute the weighted numerator and denominator, summing along the class axis
162
- numer = 2.0 * (intersection * w ).sum (dim = 1 )
163
- denom = (denominator * w ).sum (dim = 1 )
185
+ # Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True
186
+ if sum_over_classes :
187
+ numer = 2.0 * (intersection * w ).sum (dim = 1 , keepdim = True )
188
+ denom = (denominator * w ).sum (dim = 1 , keepdim = True )
189
+ y_pred_o = y_pred_o .sum (dim = - 1 , keepdim = True )
190
+ else :
191
+ numer = 2.0 * (intersection * w )
192
+ denom = denominator * w
193
+ y_pred_o = y_pred_o
164
194
165
195
# Compute the score
166
196
generalized_dice_score = numer / denom
167
197
168
198
# Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
169
199
# Where denom == 0 but the prediction volume is not 0, score is 0
170
- y_pred_o = y_pred_o .sum (dim = - 1 )
171
200
denom_zeros = denom == 0
172
201
generalized_dice_score [denom_zeros ] = torch .where (
173
202
(y_pred_o == 0 )[denom_zeros ],
0 commit comments