Description
Describe the bug
compute_generalized_dice does not return a tensor of the shape stated in the docs.
```python
import torch
from monai.metrics import compute_generalized_dice, GeneralizedDiceScore

# Batch of 20 one-hot volumes with 10 classes, shape (N, C, D, H, W)
a = torch.ones((20, 10, 64, 128, 128))
b = torch.ones((20, 10, 64, 128, 128))
print(compute_generalized_dice(a, b).shape)
```
This returns a tensor of shape torch.Size([20]), whereas according to the documentation it should return one of shape torch.Size([20, 10]), i.e. one score per sample and per class.
https://docs.monai.io/en/stable/metrics.html#generalized-dice-score
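The following minimal sketch, written in plain PyTorch and not taken from MONAI's implementation, traces the tensor shapes through a weighted (generalized) Dice computation on small stand-in tensors. It assumes the "square" class weighting w_c = 1 / |y_c|^2. Reducing only the spatial axes leaves per-sample, per-class terms of shape [N, C]; the weighted sum over the class axis (dim=1) is what collapses the result to [N].

```python
import torch

# Small stand-ins for the tensors in the report: (N=2, C=3, H=4, W=4), one-hot
y_pred = torch.ones((2, 3, 4, 4))
y = torch.ones((2, 3, 4, 4))

# Reduce spatial axes only; batch (dim 0) and class (dim 1) survive
spatial = list(range(2, y_pred.dim()))
intersection = (y * y_pred).sum(dim=spatial)
y_vol = y.sum(dim=spatial)
denominator = y_vol + y_pred.sum(dim=spatial)
print(intersection.shape)  # torch.Size([2, 3])

# Assumed "square" class weights w_c = 1 / |y_c|^2 (every class is non-empty here)
w = 1.0 / (y_vol * y_vol)

# Summing the weighted terms over the class axis (dim=1) is where
# the class dimension is lost
score = (2.0 * intersection * w).sum(dim=1) / (denominator * w).sum(dim=1)
print(score.shape)  # torch.Size([2])
```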
This problem propagates to GeneralizedDiceScore, as shown below:
```python
generalized_dice_score = GeneralizedDiceScore()
generalized_dice_score(a, b)
```
Aggregating the accumulated scores with different reductions produces unexpected results and errors:
Case 1: generalized_dice_score.aggregate(reduction="sum") and generalized_dice_score.aggregate(reduction="mean") both raise the following error:
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
Case 2: generalized_dice_score.aggregate(reduction="sum_batch") returns a single-element tensor, but it should return a tensor of size num_classes containing, for each class, the Dice scores summed over all batches. The same applies to mean_batch.
Case 3: generalized_dice_score.aggregate(reduction="none") should return the Dice scores unreduced over both the batch and class dimensions (matching the documented output of compute_generalized_dice), but it returns values that are already reduced over all classes.
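To make the three cases concrete, the sketch below uses a hypothetical reduce_scores helper (it is not MONAI's own reduction code) that assumes the documented layout, with dim 0 as the batch axis and dim 1 as the class axis, and omits NaN handling. Applied to a [batch, classes] buffer it yields the expected shapes; applied to the [batch] buffer currently produced, any reduction over dim 1 raises an IndexError as in Case 1, and "sum_batch" collapses to a single value as in Case 2.

```python
import torch


def reduce_scores(f: torch.Tensor, reduction: str) -> torch.Tensor:
    """Hypothetical reducer assuming f has shape [batch, classes]; NaN handling omitted."""
    if reduction == "none":
        return f
    if reduction == "mean":
        return f.mean(dim=1).mean(dim=0)  # average over classes, then over the batch
    if reduction == "sum":
        return f.sum(dim=(0, 1))  # sum over batch and classes
    if reduction == "mean_batch":
        return f.mean(dim=0)  # one value per class
    if reduction == "sum_batch":
        return f.sum(dim=0)  # one value per class
    raise ValueError(f"unsupported reduction: {reduction}")


expected = torch.rand(20, 10)  # per-sample, per-class scores, as documented
actual = torch.rand(20)  # per-sample scores only, as currently returned

print(reduce_scores(expected, "sum_batch").shape)  # torch.Size([10])
print(reduce_scores(actual, "sum_batch").shape)  # torch.Size([]), a single value
try:
    reduce_scores(actual, "mean")
except IndexError as err:
    print(err)  # dimension 1 is out of range for a 1-D tensor
```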
Environment
================================
Printing MONAI config...
================================
MONAI version: 1.3.2
Numpy version: 1.24.4
Pytorch version: 2.4.0+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 59a7211070538586369afd4a01eca0a7fe2e742e
MONAI __file__: /home/<username>/miniconda3/lib/python3.10/site-packages/monai/__init__.py
Suggested Solution
It looks like removing the .sum() calls at the following locations should give the expected behaviour:
MONAI/monai/metrics/generalized_dice.py, lines 162 to 163 (commit 59a7211)
MONAI/monai/metrics/generalized_dice.py, line 170 (commit 59a7211)
I tried this out, and the resulting shapes and values look as expected.
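As an illustration of the suggested change, here is a minimal sketch of a per-class variant that keeps the class axis instead of summing over it. It is not MONAI's code or an actual patch: the function name per_class_generalized_dice is hypothetical, it assumes the "square" weighting, and it handles classes that are absent from the ground truth with a simplified rule (weight 0, then a score of 1 if the prediction is also empty, else 0) rather than MONAI's own weight replacement.

```python
import torch


def per_class_generalized_dice(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical per-class variant returning shape [N, C] (no sum over classes).

    Inputs are assumed to be binarized, one-hot tensors of shape (N, C, *spatial).
    """
    y_pred, y = y_pred.float(), y.float()
    spatial = list(range(2, y_pred.dim()))
    intersection = (y * y_pred).sum(dim=spatial)  # [N, C]
    y_vol = y.sum(dim=spatial)  # [N, C]
    pred_vol = y_pred.sum(dim=spatial)  # [N, C]

    # Assumed "square" weights; classes absent from the ground truth get weight 0
    # instead of inf (a simplification, not MONAI's replacement rule)
    w = torch.where(y_vol > 0, 1.0 / (y_vol * y_vol), torch.zeros_like(y_vol))

    # The class-axis .sum(dim=1) is dropped, so both terms keep shape [N, C]
    numer = 2.0 * intersection * w
    denom = (y_vol + pred_vol) * w

    # Zero denominators (empty ground truth): 1 if the prediction is also empty, else 0
    safe_denom = torch.where(denom > 0, denom, torch.ones_like(denom))
    empty_score = (pred_vol == 0).to(numer.dtype)
    return torch.where(denom > 0, numer / safe_denom, empty_score)


y = torch.zeros((1, 3, 4, 4))
y[0, 0, :2] = 1  # class 0: top two rows
y[0, 1, 2:] = 1  # class 1: bottom two rows (class 2 stays empty)
pred = torch.zeros_like(y)
pred[0, 0, :1] = 1  # class 0 under-segmented
pred[0, 1, 1:] = 1  # class 1 over-segmented

scores = per_class_generalized_dice(pred, y)
print(scores.shape)  # torch.Size([1, 3])
print(scores)  # approximately [[0.6667, 0.8000, 1.0000]]
```

Note that once the class-axis sum is removed, each class's weight appears in both the numerator and the denominator of its own entry and cancels, so for classes present in the ground truth each entry equals that class's ordinary Dice score; the weighting only changes the result when the terms are summed over classes.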