Skip to content

Commit 08f3ce3

Browse files
committed
Refactor SwinUNETR constructor to accept patch_size and additional parameters
1 parent 13b96ae commit 08f3ce3

File tree

1 file changed

+24
-12
lines changed

1 file changed

+24
-12
lines changed

monai/networks/nets/swin_unetr.py

Lines changed: 24 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313

1414
import itertools
1515
from collections.abc import Sequence
16-
from typing import Final
1716

1817
import numpy as np
1918
import torch
@@ -51,8 +50,6 @@ class SwinUNETR(nn.Module):
5150
<https://arxiv.org/abs/2201.01266>"
5251
"""
5352

54-
patch_size: Final[int] = 2
55-
5653
@deprecated_arg(
5754
name="img_size",
5855
since="1.3",
@@ -65,18 +62,24 @@ def __init__(
6562
img_size: Sequence[int] | int,
6663
in_channels: int,
6764
out_channels: int,
65+
patch_size: int = 2,
6866
depths: Sequence[int] = (2, 2, 2, 2),
6967
num_heads: Sequence[int] = (3, 6, 12, 24),
68+
window_size: Sequence[int] | int = 7,
69+
qkv_bias: bool = True,
70+
mlp_ratio: float = 4.0,
7071
feature_size: int = 24,
7172
norm_name: tuple | str = "instance",
7273
drop_rate: float = 0.0,
7374
attn_drop_rate: float = 0.0,
7475
dropout_path_rate: float = 0.0,
7576
normalize: bool = True,
77+
norm_layer: type[LayerNorm] = nn.LayerNorm,
78+
patch_norm: bool = True,
7679
use_checkpoint: bool = False,
7780
spatial_dims: int = 3,
78-
downsample="merging",
79-
use_v2=False,
81+
downsample: str | nn.Module = "merging",
82+
use_v2: bool = False,
8083
) -> None:
8184
"""
8285
Args:
@@ -86,14 +89,20 @@ def __init__(
8689
It will be removed in an upcoming version.
8790
in_channels: dimension of input channels.
8891
out_channels: dimension of output channels.
92+
patch_size: size of the patch token.
8993
feature_size: dimension of network feature size.
9094
depths: number of layers in each stage.
9195
num_heads: number of attention heads.
96+
window_size: local window size.
97+
qkv_bias: add a learnable bias to query, key, value.
98+
mlp_ratio: ratio of mlp hidden dim to embedding dim.
9299
norm_name: feature normalization type and arguments.
93100
drop_rate: dropout rate.
94101
attn_drop_rate: attention dropout rate.
95102
dropout_path_rate: drop path rate.
96103
normalize: normalize output intermediate features in each stage.
104+
norm_layer: normalization layer.
105+
patch_norm: whether to apply normalization to the patch embedding.
97106
use_checkpoint: use gradient checkpointing for reduced memory usage.
98107
spatial_dims: number of spatial dims.
99108
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
@@ -116,13 +125,15 @@ def __init__(
116125

117126
super().__init__()
118127

119-
img_size = ensure_tuple_rep(img_size, spatial_dims)
120-
patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
121-
window_size = ensure_tuple_rep(7, spatial_dims)
122-
123128
if spatial_dims not in (2, 3):
124129
raise ValueError("spatial dimension should be 2 or 3.")
125130

131+
self.patch_size = patch_size
132+
133+
img_size = ensure_tuple_rep(img_size, spatial_dims)
134+
patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
135+
window_size = ensure_tuple_rep(window_size, spatial_dims)
136+
126137
self._check_input_size(img_size)
127138

128139
if not (0 <= drop_rate <= 1):
@@ -146,12 +157,13 @@ def __init__(
146157
patch_size=patch_sizes,
147158
depths=depths,
148159
num_heads=num_heads,
149-
mlp_ratio=4.0,
150-
qkv_bias=True,
160+
mlp_ratio=mlp_ratio,
161+
qkv_bias=qkv_bias,
151162
drop_rate=drop_rate,
152163
attn_drop_rate=attn_drop_rate,
153164
drop_path_rate=dropout_path_rate,
154-
norm_layer=nn.LayerNorm,
165+
norm_layer=norm_layer,
166+
patch_norm=patch_norm,
155167
use_checkpoint=use_checkpoint,
156168
spatial_dims=spatial_dims,
157169
downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,

0 commit comments

Comments (0)