Skip to content

Commit b62d1e1

Browse files
heyufan1995pre-commit-ci[bot]KumoLiu
authored
Fix transpose and patch coords bug (#8047)
Fixes # . ### Description Fix the bug that causes wrong results in model zoo finetuning. Patch coords was not passed from sliding window to vista3d. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: heyufan1995 <[email protected]> Signed-off-by: YunLiu <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <[email protected]>
1 parent 1a8afd1 commit b62d1e1

File tree

3 files changed

+24
-13
lines changed

3 files changed

+24
-13
lines changed

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@ zarr
4242
huggingface_hub
4343
pyamg>=5.0.0
4444
packaging
45+
polygraphy

monai/apps/vista3d/sampler.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import torch
2121
from torch import Tensor
2222

23-
__all__ = ["sample_prompt_pairs"]
24-
2523
ENABLE_SPECIAL = True
2624
SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
2725
MERGE_LIST = {
@@ -30,6 +28,8 @@
3028
132: [57], # overlap with trachea merge into airway
3129
}
3230

31+
__all__ = ["sample_prompt_pairs"]
32+
3333

3434
def _get_point_label(id: int) -> tuple[int, int]:
3535
if id in SPECIAL_INDEX and ENABLE_SPECIAL:
@@ -66,22 +66,29 @@ def sample_prompt_pairs(
6666
max_backprompt: int, max number of prompt from background.
6767
max_point: maximum number of points for each object.
6868
include_background: if include 0 into training prompt. If included, background 0 is treated
69-
the same as foreground. Always be False for multi-partial-dataset training. If needed,
70-
can be true for finetuning specific dataset, .
69+
the same as foreground and points will be sampled. Can be true only if user want to segment
70+
background 0 with point clicks, otherwise always be false.
7171
drop_label_prob: probability to drop label prompt.
7272
drop_point_prob: probability to drop point prompt.
7373
point_sampler: sampler to augment masks with supervoxel.
7474
point_sampler_kwargs: arguments for point_sampler.
7575
7676
Returns:
77-
label_prompt: [B, 1]. The classes used for training automatic segmentation.
78-
point: [B, N, 3]. The corresponding points for each class.
79-
Note that background label prompt requires matching point as well ([0,0,0] is used).
80-
point_label: [B, N]. The corresponding point labels for each point (negative or positive).
81-
-1 is used for padding the background label prompt and will be ignored.
82-
prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss.
83-
label_prompt can be None, and prompt_class is used to identify point classes.
77+
tuple:
78+
- label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for
79+
training automatic segmentation.
80+
- point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points
81+
for each class. Note that background label prompts require matching points as well
82+
(e.g., [0, 0, 0] is used).
83+
- point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point
84+
labels for each point (negative or positive). -1 is used for padding the background
85+
label prompt and will be ignored.
86+
- prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt
87+
for label indexing during training. If label_prompt is None, prompt_class is used to
88+
identify point classes.
89+
8490
"""
91+
8592
# class label number
8693
if not labels.shape[0] == 1:
8794
raise ValueError("only support batch size 1")

monai/networks/nets/vista3d.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,11 +336,11 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
336336
def forward(
337337
self,
338338
input_images: torch.Tensor,
339+
patch_coords: Sequence[slice] | None = None,
339340
point_coords: torch.Tensor | None = None,
340341
point_labels: torch.Tensor | None = None,
341342
class_vector: torch.Tensor | None = None,
342343
prompt_class: torch.Tensor | None = None,
343-
patch_coords: Sequence[slice] | None = None,
344344
labels: torch.Tensor | None = None,
345345
label_set: Sequence[int] | None = None,
346346
prev_mask: torch.Tensor | None = None,
@@ -421,7 +421,10 @@ def forward(
421421
point_coords, point_labels = None, None
422422

423423
if point_coords is None and class_vector is None:
424-
return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
424+
logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
425+
if transpose:
426+
logits = logits.transpose(1, 0)
427+
return logits
425428

426429
if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
427430
out, out_auto = self.image_embeddings, None

0 commit comments

Comments
 (0)