GeneralizedDiceScore reductions do not work as expected #7966

Closed
@surajpaib

Describe the bug
compute_generalized_dice does not return a tensor of the shape described in the docs.

import torch
from monai.metrics import compute_generalized_dice, GeneralizedDiceScore

a = torch.ones((20, 10, 64, 128, 128))
b = torch.ones((20, 10, 64, 128, 128))


compute_generalized_dice(a, b).shape

This returns a tensor of shape torch.Size([20]), whereas per the documentation it should return one of shape torch.Size([20, 10]), i.e. one score per batch item and class.
https://docs.monai.io/en/stable/metrics.html#generalized-dice-score

The problem propagates to GeneralizedDiceScore, as shown below:

generalized_dice_score = GeneralizedDiceScore()
generalized_dice_score(a, b)

Aggregating this with different reductions gives unexpected results and errors.
Case 1: generalized_dice_score.aggregate(reduction="sum") and generalized_dice_score.aggregate(reduction="mean") both raise IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

Case 2: generalized_dice_score.aggregate(reduction="sum_batch") returns a single-element tensor, but it should return a tensor whose size equals the number of classes, containing the Dice scores summed over the batch for each class. The same applies to mean_batch.

Case 3: generalized_dice_score.aggregate(reduction="none") should return the unreduced Dice scores per batch item and class (the same shape that compute_generalized_dice is documented to return), but instead it returns values already reduced over the classes.
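
To make the three cases concrete, here is a minimal plain-PyTorch sketch of what the reductions would look like on a [batch, class] tensor versus the 1-D tensor reported above. It does not call MONAI. The reductions are simple stand-ins and rest on the assumption (consistent with the error message in Case 1) that the class axis is dim=1.

```python
import torch

batch, classes = 20, 10

# Shape the docs describe: one score per (batch item, class).
per_class = torch.ones(batch, classes)
# Shape reported in this issue: the class axis has already been summed away.
per_item = torch.ones(batch)

# Plain-torch stand-ins for the reductions (assumed, not MONAI's own code).
mean_batch = per_class.mean(dim=0)       # reduce over the batch only
mean_all = per_class.mean(dim=1).mean()  # reduce over classes, then batch

print(mean_batch.shape)  # torch.Size([10])
print(mean_all.shape)    # torch.Size([])

# With the 1-D tensor, a batch reduction collapses to a single value (Case 2).
print(per_item.sum(dim=0).shape)  # torch.Size([])

# Any reduction over dim=1 fails on the 1-D tensor (Case 1).
try:
    per_item.sum(dim=1)
except IndexError as err:
    print("IndexError:", err)
```

With the [20, 10] input, a batch reduction keeps one value per class, which is what Case 2 expects. With the [20] input there is no class axis left to reduce over.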

Environment

Ensuring you use the relevant python executable, please paste the output of:

================================
Printing MONAI config...
================================
MONAI version: 1.3.2
Numpy version: 1.24.4
Pytorch version: 2.4.0+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 59a7211070538586369afd4a01eca0a7fe2e742e
MONAI __file__: /home/<username>/miniconda3/lib/python3.10/site-packages/monai/__init__.py

Suggested Solution

It looks like removing the .sum() calls in the following lines should give the expected behaviour:

numer = 2.0 * (intersection * w).sum(dim=1)
denom = (denominator * w).sum(dim=1)

y_pred_o = y_pred_o.sum(dim=-1)

I tried this out, and the shapes and values look as expected.
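
For illustration, here is a rough standalone sketch of the shape difference that the class-axis sum makes. generalized_dice_sketch and its per_class flag are hypothetical names, not MONAI's API. The inverse-squared-volume weighting and the clamp that avoids division by zero are simplifying assumptions, and the special handling of empty denominators is left out.

```python
import torch

def generalized_dice_sketch(y_pred, y, per_class=False):
    """Hypothetical helper, not MONAI's implementation.

    y_pred, y: binarized tensors of shape [B, C, *spatial].
    """
    spatial_dims = list(range(2, y_pred.dim()))
    y_o = y.sum(dim=spatial_dims).float()            # [B, C]
    y_pred_o = y_pred.sum(dim=spatial_dims).float()  # [B, C]
    intersection = (y * y_pred).sum(dim=spatial_dims).float()  # [B, C]
    denominator = y_o + y_pred_o                     # [B, C]

    # Assumed weighting: inverse squared ground-truth volume per class.
    # The clamp is a simplification to avoid dividing by zero.
    w = 1.0 / y_o.clamp(min=1.0) ** 2                # [B, C]

    numer = 2.0 * intersection * w
    denom = denominator * w
    if not per_class:
        # Summing over dim=1 removes the class axis, giving shape [B].
        numer, denom = numer.sum(dim=1), denom.sum(dim=1)
    return numer / denom

# Small tensors keep the example cheap to run.
a = torch.ones((2, 3, 4, 4, 4))
b = torch.ones((2, 3, 4, 4, 4))
print(generalized_dice_sketch(a, b).shape)                  # torch.Size([2])
print(generalized_dice_sketch(a, b, per_class=True).shape)  # torch.Size([2, 3])
```

In this sketch, once the class sum is skipped the weight w cancels in each per-class ratio, so each entry is the ordinary Dice score for that class.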
