|
20 | 20 | import torch
|
21 | 21 | from torch import Tensor
|
22 | 22 |
|
23 |
| -__all__ = ["sample_prompt_pairs"] |
24 |
| - |
25 | 23 | ENABLE_SPECIAL = True
|
26 | 24 | SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
|
27 | 25 | MERGE_LIST = {
|
|
30 | 28 | 132: [57], # overlap with trachea merge into airway
|
31 | 29 | }
|
32 | 30 |
|
| 31 | +__all__ = ["sample_prompt_pairs"] |
| 32 | + |
33 | 33 |
|
34 | 34 | def _get_point_label(id: int) -> tuple[int, int]:
|
35 | 35 | if id in SPECIAL_INDEX and ENABLE_SPECIAL:
|
@@ -66,22 +66,29 @@ def sample_prompt_pairs(
|
66 | 66 | max_backprompt: int, max number of prompt from background.
|
67 | 67 | max_point: maximum number of points for each object.
|
68 | 68 | include_background: if include 0 into training prompt. If included, background 0 is treated
|
69 |
| - the same as foreground. Always be False for multi-partial-dataset training. If needed, |
70 |
| - can be true for finetuning specific dataset, . |
| 69 | + the same as foreground and points will be sampled. Can be true only if user want to segment |
| 70 | + background 0 with point clicks, otherwise always be false. |
71 | 71 | drop_label_prob: probability to drop label prompt.
|
72 | 72 | drop_point_prob: probability to drop point prompt.
|
73 | 73 | point_sampler: sampler to augment masks with supervoxel.
|
74 | 74 | point_sampler_kwargs: arguments for point_sampler.
|
75 | 75 |
|
76 | 76 | Returns:
|
77 |
| - label_prompt: [B, 1]. The classes used for training automatic segmentation. |
78 |
| - point: [B, N, 3]. The corresponding points for each class. |
79 |
| - Note that background label prompt requires matching point as well ([0,0,0] is used). |
80 |
| - point_label: [B, N]. The corresponding point labels for each point (negative or positive). |
81 |
| - -1 is used for padding the background label prompt and will be ignored. |
82 |
| - prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss. |
83 |
| - label_prompt can be None, and prompt_class is used to identify point classes. |
| 77 | + tuple: |
| 78 | + - label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for |
| 79 | + training automatic segmentation. |
| 80 | + - point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points |
| 81 | + for each class. Note that background label prompts require matching points as well |
| 82 | + (e.g., [0, 0, 0] is used). |
| 83 | + - point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point |
| 84 | + labels for each point (negative or positive). -1 is used for padding the background |
| 85 | + label prompt and will be ignored. |
| 86 | + - prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt |
| 87 | + for label indexing during training. If label_prompt is None, prompt_class is used to |
| 88 | + identify point classes. |
| 89 | +
|
84 | 90 | """
|
| 91 | + |
85 | 92 | # class label number
|
86 | 93 | if not labels.shape[0] == 1:
|
87 | 94 | raise ValueError("only support batch size 1")
|
|
0 commit comments