13
13
14
14
import itertools
15
15
from collections .abc import Sequence
16
- from typing import Final
17
16
18
17
import numpy as np
19
18
import torch
@@ -51,8 +50,6 @@ class SwinUNETR(nn.Module):
51
50
<https://arxiv.org/abs/2201.01266>"
52
51
"""
53
52
54
- patch_size : Final [int ] = 2
55
-
56
53
@deprecated_arg (
57
54
name = "img_size" ,
58
55
since = "1.3" ,
@@ -65,18 +62,24 @@ def __init__(
65
62
img_size : Sequence [int ] | int ,
66
63
in_channels : int ,
67
64
out_channels : int ,
65
+ patch_size : int = 2 ,
68
66
depths : Sequence [int ] = (2 , 2 , 2 , 2 ),
69
67
num_heads : Sequence [int ] = (3 , 6 , 12 , 24 ),
68
+ window_size : Sequence [int ] | int = 7 ,
69
+ qkv_bias : bool = True ,
70
+ mlp_ratio : float = 4.0 ,
70
71
feature_size : int = 24 ,
71
72
norm_name : tuple | str = "instance" ,
72
73
drop_rate : float = 0.0 ,
73
74
attn_drop_rate : float = 0.0 ,
74
75
dropout_path_rate : float = 0.0 ,
75
76
normalize : bool = True ,
77
+ norm_layer : type [LayerNorm ] = nn .LayerNorm ,
78
+ patch_norm : bool = True ,
76
79
use_checkpoint : bool = False ,
77
80
spatial_dims : int = 3 ,
78
- downsample = "merging" ,
79
- use_v2 = False ,
81
+ downsample : str | nn . Module = "merging" ,
82
+ use_v2 : bool = False ,
80
83
) -> None :
81
84
"""
82
85
Args:
@@ -86,14 +89,20 @@ def __init__(
86
89
It will be removed in an upcoming version.
87
90
in_channels: dimension of input channels.
88
91
out_channels: dimension of output channels.
92
+ patch_size: size of the patch token.
89
93
feature_size: dimension of network feature size.
90
94
depths: number of layers in each stage.
91
95
num_heads: number of attention heads.
96
+ window_size: local window size.
97
+ qkv_bias: add a learnable bias to query, key, value.
98
+ mlp_ratio: ratio of mlp hidden dim to embedding dim.
92
99
norm_name: feature normalization type and arguments.
93
100
drop_rate: dropout rate.
94
101
attn_drop_rate: attention dropout rate.
95
102
dropout_path_rate: drop path rate.
96
103
normalize: normalize output intermediate features in each stage.
104
+ norm_layer: normalization layer.
105
+ patch_norm: whether to apply normalization to the patch embedding.
97
106
use_checkpoint: use gradient checkpointing for reduced memory usage.
98
107
spatial_dims: number of spatial dims.
99
108
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
@@ -116,13 +125,15 @@ def __init__(
116
125
117
126
super ().__init__ ()
118
127
119
- img_size = ensure_tuple_rep (img_size , spatial_dims )
120
- patch_sizes = ensure_tuple_rep (self .patch_size , spatial_dims )
121
- window_size = ensure_tuple_rep (7 , spatial_dims )
122
-
123
128
if spatial_dims not in (2 , 3 ):
124
129
raise ValueError ("spatial dimension should be 2 or 3." )
125
130
131
+ self .patch_size = patch_size
132
+
133
+ img_size = ensure_tuple_rep (img_size , spatial_dims )
134
+ patch_sizes = ensure_tuple_rep (self .patch_size , spatial_dims )
135
+ window_size = ensure_tuple_rep (window_size , spatial_dims )
136
+
126
137
self ._check_input_size (img_size )
127
138
128
139
if not (0 <= drop_rate <= 1 ):
@@ -146,12 +157,13 @@ def __init__(
146
157
patch_size = patch_sizes ,
147
158
depths = depths ,
148
159
num_heads = num_heads ,
149
- mlp_ratio = 4.0 ,
150
- qkv_bias = True ,
160
+ mlp_ratio = mlp_ratio ,
161
+ qkv_bias = qkv_bias ,
151
162
drop_rate = drop_rate ,
152
163
attn_drop_rate = attn_drop_rate ,
153
164
drop_path_rate = dropout_path_rate ,
154
- norm_layer = nn .LayerNorm ,
165
+ norm_layer = norm_layer ,
166
+ patch_norm = patch_norm ,
155
167
use_checkpoint = use_checkpoint ,
156
168
spatial_dims = spatial_dims ,
157
169
downsample = look_up_option (downsample , MERGING_MODE ) if isinstance (downsample , str ) else downsample ,
0 commit comments