Skip to content

Commit d02ba11

Browse files
surajpaibKumoLiu
andauthored
Fix generalized dice computation (#7970)
Fixes #7966 ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Suraj Pai <[email protected]> Signed-off-by: Suraj Pai <[email protected]> Co-authored-by: YunLiu <[email protected]>
1 parent b539cbb commit d02ba11

File tree

2 files changed

+201
-94
lines changed

2 files changed

+201
-94
lines changed

monai/metrics/generalized_dice.py

Lines changed: 77 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,47 @@
1414
import torch
1515

1616
from monai.metrics.utils import do_metric_reduction, ignore_background
17-
from monai.utils import MetricReduction, Weight, look_up_option
17+
from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option
1818

1919
from .metric import CumulativeIterationMetric
2020

2121

2222
class GeneralizedDiceScore(CumulativeIterationMetric):
23-
"""Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in:
23+
"""
24+
Compute the Generalized Dice Score metric between tensors.
2425
26+
This metric is the complement of the Generalized Dice Loss defined in:
2527
Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
26-
loss function for highly unbalanced segmentations. DLMIA 2017.
28+
loss function for highly unbalanced segmentations. DLMIA 2017.
2729
28-
The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first
29-
or batch-first tensors, i.e., CHW[D] or BCHW[D].
30+
The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].
3031
3132
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
3233
3334
Args:
34-
include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the
35+
include_background: Whether to include the background class (assumed to be in channel 0) in the
3536
score computation. Defaults to True.
36-
reduction (str, optional): define mode of reduction to the metrics. Available reduction modes:
37-
{``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction.
38-
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
37+
reduction: Define mode of reduction to the metrics. Available reduction modes:
38+
{``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
39+
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
40+
weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
3941
ground truth volume into a weight factor. Defaults to ``"square"``.
4042
4143
Raises:
42-
ValueError: when the `weight_type` is not one of {``"none"``, ``"mean"``, ``"sum"``}.
44+
ValueError: When the `reduction` is not one of MetricReduction enum.
4345
"""
4446

47+
@deprecated_arg_default(
48+
"reduction",
49+
old_default=MetricReduction.MEAN_BATCH,
50+
new_default=MetricReduction.MEAN,
51+
since="1.4.0",
52+
replaced="1.5.0",
53+
msg_suffix=(
54+
"Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, "
55+
"If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'."
56+
),
57+
)
4558
def __init__(
4659
self,
4760
include_background: bool = True,
@@ -50,79 +63,90 @@ def __init__(
5063
) -> None:
5164
super().__init__()
5265
self.include_background = include_background
53-
reduction_options = [
54-
"none",
55-
"mean_batch",
56-
"sum_batch",
57-
MetricReduction.NONE,
58-
MetricReduction.MEAN_BATCH,
59-
MetricReduction.SUM_BATCH,
60-
]
61-
self.reduction = reduction
62-
if self.reduction not in reduction_options:
63-
raise ValueError(f"reduction must be one of {reduction_options}")
66+
self.reduction = look_up_option(reduction, MetricReduction)
6467
self.weight_type = look_up_option(weight_type, Weight)
68+
self.sum_over_classes = self.reduction in {
69+
MetricReduction.SUM,
70+
MetricReduction.MEAN,
71+
MetricReduction.MEAN_CHANNEL,
72+
MetricReduction.SUM_CHANNEL,
73+
}
6574

6675
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
67-
"""Computes the Generalized Dice Score and returns a tensor with its per image values.
76+
"""
77+
Computes the Generalized Dice Score and returns a tensor with its per image values.
6878
6979
Args:
70-
y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
80+
y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
7181
where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
72-
y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
82+
y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
83+
84+
Returns:
85+
torch.Tensor: Generalized Dice Score averaged across batch and class
7386
7487
Raises:
75-
ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
88+
ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
7689
"""
7790
return compute_generalized_dice(
78-
y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type
91+
y_pred=y_pred,
92+
y=y,
93+
include_background=self.include_background,
94+
weight_type=self.weight_type,
95+
sum_over_classes=self.sum_over_classes,
7996
)
8097

98+
@deprecated_arg(
99+
"reduction",
100+
since="1.3.3",
101+
removed="1.7.0",
102+
msg_suffix="Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute",
103+
)
81104
def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor:
82105
"""
83106
Execute reduction logic for the output of `compute_generalized_dice`.
84107
85-
Args:
86-
reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics.
87-
Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}.
88-
Defaults to ``"mean"``. If "none", will not do reduction.
108+
Returns:
109+
torch.Tensor: Aggregated metric value.
110+
111+
Raises:
112+
ValueError: If the data to aggregate is not a PyTorch Tensor.
89113
"""
90114
data = self.get_buffer()
91115
if not isinstance(data, torch.Tensor):
92116
raise ValueError("The data to aggregate must be a PyTorch Tensor.")
93117

94-
# Validate reduction argument if specified
95-
if reduction is not None:
96-
reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"]
97-
if reduction not in reduction_options:
98-
raise ValueError(f"reduction must be one of {reduction_options}")
99-
100118
# Do metric reduction and return
101-
f, _ = do_metric_reduction(data, reduction or self.reduction)
119+
f, _ = do_metric_reduction(data, self.reduction)
102120

103121
return f
104122

105123

106124
def compute_generalized_dice(
107-
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE
125+
y_pred: torch.Tensor,
126+
y: torch.Tensor,
127+
include_background: bool = True,
128+
weight_type: Weight | str = Weight.SQUARE,
129+
sum_over_classes: bool = False,
108130
) -> torch.Tensor:
109-
"""Computes the Generalized Dice Score and returns a tensor with its per image values.
131+
"""
132+
Computes the Generalized Dice Score and returns a tensor with its per image values.
110133
111134
Args:
112-
y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
135+
y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format
113136
and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
114137
remaining are the spatial dimensions.
115-
y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
116-
include_background (bool, optional): whether to include score computation on the first channel of the
138+
y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
139+
include_background: Whether to include score computation on the first channel of the
117140
predicted output. Defaults to True.
118141
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
119142
transform ground truth volume into a weight factor. Defaults to ``"square"``.
143+
sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation.
120144
121145
Returns:
122-
torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
146+
torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
123147
124148
Raises:
125-
ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
149+
ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
126150
or `y_pred` and `y` don't have the same shape.
127151
"""
128152
# Ensure tensors have at least 3 dimensions and have the same shape
@@ -158,16 +182,21 @@ def compute_generalized_dice(
158182
b[infs] = 0
159183
b[infs] = torch.max(b)
160184

161-
# Compute the weighted numerator and denominator, summing along the class axis
162-
numer = 2.0 * (intersection * w).sum(dim=1)
163-
denom = (denominator * w).sum(dim=1)
185+
# Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True
186+
if sum_over_classes:
187+
numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True)
188+
denom = (denominator * w).sum(dim=1, keepdim=True)
189+
y_pred_o = y_pred_o.sum(dim=-1, keepdim=True)
190+
else:
191+
numer = 2.0 * (intersection * w)
192+
denom = denominator * w
193+
y_pred_o = y_pred_o
164194

165195
# Compute the score
166196
generalized_dice_score = numer / denom
167197

168198
# Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
169199
# Where denom == 0 but the prediction volume is not 0, score is 0
170-
y_pred_o = y_pred_o.sum(dim=-1)
171200
denom_zeros = denom == 0
172201
generalized_dice_score[denom_zeros] = torch.where(
173202
(y_pred_o == 0)[denom_zeros],

0 commit comments

Comments
 (0)