From b554a104d2546d44a6da3e72847da7e65f7a3844 Mon Sep 17 00:00:00 2001
From: Billy Luu
Date: Thu, 1 May 2025 14:15:09 -0400
Subject: [PATCH 1/4] Add U-Mamba architecture to MONAI

---
 networks/nets/__init__.py      | 147 +++++++++++++++++++++++++++++++++
 networks/nets/u_mamba.py       |  99 ++++++++++++++++++++++
 tests/test_networks_u_mamba.py |  26 ++++++
 3 files changed, 272 insertions(+)
 create mode 100644 networks/nets/__init__.py
 create mode 100644 networks/nets/u_mamba.py
 create mode 100644 tests/test_networks_u_mamba.py

diff --git a/networks/nets/__init__.py b/networks/nets/__init__.py
new file mode 100644
index 0000000000..0bf159a851
--- /dev/null
+++ b/networks/nets/__init__.py
@@ -0,0 +1,147 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from .ahnet import AHnet, Ahnet, AHNet
+from .attentionunet import AttentionUnet
+from .autoencoder import AutoEncoder
+from .autoencoderkl import AutoencoderKL
+from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet
+from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus
+from .classifier import Classifier, Critic, Discriminator
+from .controlnet import ControlNet
+from .daf3d import DAF3D
+from .densenet import (
+    DenseNet,
+    Densenet,
+    DenseNet121,
+    Densenet121,
+    DenseNet169,
+    Densenet169,
+    DenseNet201,
+    Densenet201,
+    DenseNet264,
+    Densenet264,
+    densenet121,
+    densenet169,
+    densenet201,
+    densenet264,
+)
+from .diffusion_model_unet import DiffusionModelUNet
+from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch
+from .dynunet import DynUNet, DynUnet, Dynunet
+from .efficientnet import (
+    BlockArgs,
+    EfficientNet,
+    EfficientNetBN,
+    EfficientNetBNFeatures,
+    EfficientNetEncoder,
+    drop_connect,
+    get_efficientnet_image_size,
+)
+from .flexible_unet import FLEXUNET_BACKBONE, FlexibleUNet, FlexUNet, FlexUNetEncoderRegister
+from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
+from .generator import Generator
+from .highresnet import HighResBlock, HighResNet
+from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
+from .masked_autoencoder_vit import MaskedAutoEncoderViT
+from .mednext import (
+    MedNeXt,
+    MedNext,
+    MedNextB,
+    MedNeXtB,
+    MedNextBase,
+    MedNextL,
+    MedNeXtL,
+    MedNeXtLarge,
+    MedNextLarge,
+    MedNextM,
+    MedNeXtM,
+    MedNeXtMedium,
+    MedNextMedium,
+    MedNextS,
+    MedNeXtS,
+    MedNeXtSmall,
+    MedNextSmall,
+)
+from .milmodel import MILModel
+from .netadapter import NetAdapter
+from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
+from .quicknat import Quicknat
+from .regressor import Regressor
+from .regunet import GlobalNet, LocalNet, RegUNet
+from .resnet import (
+    ResNet,
+    ResNetBlock,
+    ResNetBottleneck,
+    ResNetEncoder,
+    ResNetFeatures,
+    get_medicalnet_pretrained_resnet_args,
+    get_pretrained_resnet_medicalnet,
+    resnet10,
+    resnet18,
+    resnet34,
+    resnet50,
+    resnet101,
+    resnet152,
+    resnet200,
+)
+from .segresnet import SegResNet, SegResNetVAE
+from .segresnet_ds import SegResNetDS, SegResNetDS2
+from .senet import (
+    SENet,
+    SEnet,
+    Senet,
+    SENet154,
+    SEnet154,
+    Senet154,
+    SEResNet50,
+    SEresnet50,
+    Seresnet50,
+    SEResNet101,
+    SEresnet101,
+    Seresnet101,
+    SEResNet152,
+    SEresnet152,
+    Seresnet152,
+    SEResNext50,
+    SEResNeXt50,
+    SEresnext50,
+    Seresnext50,
+    SEResNext101,
+    SEResNeXt101,
+    SEresnext101,
+    Seresnext101,
+    senet154,
+    seresnet50,
+    seresnet101,
+    seresnet152,
+    seresnext50,
+    seresnext101,
+)
+from .spade_autoencoderkl import SPADEAutoencoderKL
+from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
+from .spade_network import SPADENet
+from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
+from .torchvision_fc import TorchVisionFCModel
+from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
+from .transformer import DecoderOnlyTransformer
+from .unet import UNet, Unet
+from .unetr import UNETR
+from .varautoencoder import VarAutoEncoder
+from .vista3d import VISTA3D, vista3d132
+from .vit import ViT
+from .vitautoenc import ViTAutoEnc
+from .vnet import VNet
+from .voxelmorph import VoxelMorph, VoxelMorphUNet
+from .vqvae import VQVAE
+from .u_mamba import UMamba
diff --git a/networks/nets/u_mamba.py b/networks/nets/u_mamba.py
new file mode 100644
index 0000000000..1d37be242d
--- /dev/null
+++ b/networks/nets/u_mamba.py
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Simple placeholder for the SSM (Mamba-like block)
+class SSMBlock(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.linear1 = nn.Linear(dim, dim)
+        self.linear2 = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        # x: (B, L, C)
+        return self.linear2(torch.silu(self.linear1(x)))
+
+class UMambaBlock(nn.Module):
+    def __init__(self, in_channels, hidden_channels):
+        super().__init__()
+        self.conv_res1 = nn.Sequential(
+            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
+            nn.InstanceNorm3d(in_channels),
+            nn.LeakyReLU(),
+        )
+        self.conv_res2 = nn.Sequential(
+            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
+            nn.InstanceNorm3d(in_channels),
+            nn.LeakyReLU(),
+        )
+
+        self.layernorm = nn.LayerNorm(hidden_channels)
+        self.linear1 = nn.Linear(in_channels, hidden_channels)
+        self.linear2 = nn.Linear(hidden_channels, in_channels)
+        self.conv1d = nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
+        self.ssm = SSMBlock(hidden_channels)
+
+    def forward(self, x):
+        # x: (B, C, H, W, D)
+        residual = x
+        x = self.conv_res1(x)
+        x = self.conv_res2(x) + residual
+
+        B, C, H, W, D = x.shape
+        x_flat = x.view(B, C, -1).permute(0, 2, 1)  # (B, L, C)
+        x_norm = self.layernorm(x_flat)
+        x_proj = self.linear1(x_norm)
+
+        x_silu = torch.silu(x_proj)
+        x_ssm = self.ssm(x_silu)
+        x_conv1d = self.conv1d(x_proj.permute(0, 2, 1)).permute(0, 2, 1)
+
+        x_combined = torch.silu(x_conv1d) * torch.silu(x_ssm)
+        x_out = self.linear2(x_combined)
+        x_out = x_out.permute(0, 2, 1).view(B, C, H, W, D)
+
+        return x + x_out  # Residual connection
+
+class ResidualBlock(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
+            nn.BatchNorm3d(channels),
+            nn.ReLU(),
+            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
+            nn.BatchNorm3d(channels),
+        )
+
+    def forward(self, x):
+        return F.relu(x + self.block(x))
+
+class UMambaUNet(nn.Module):
+    def __init__(self, in_channels=1, out_channels=1, base_channels=32):
+        super().__init__()
+        self.enc1 = UMambaBlock(in_channels, base_channels)
+        self.down1 = nn.Conv3d(base_channels, base_channels*2, kernel_size=3, stride=2, padding=1)
+
+        self.enc2 = UMambaBlock(base_channels*2, base_channels*2)
+        self.down2 = nn.Conv3d(base_channels*2, base_channels*4, kernel_size=3, stride=2, padding=1)
+
+        self.bottleneck = UMambaBlock(base_channels*4, base_channels*4)
+
+        self.up2 = nn.ConvTranspose3d(base_channels*4, base_channels*2, kernel_size=2, stride=2)
+        self.dec2 = ResidualBlock(base_channels*4)
+
+        self.up1 = nn.ConvTranspose3d(base_channels*2, base_channels, kernel_size=2, stride=2)
+        self.dec1 = ResidualBlock(base_channels*2)
+
+        self.final = nn.Conv3d(base_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        x1 = self.enc1(x)
+        x2 = self.enc2(self.down1(x1))
+        x3 = self.bottleneck(self.down2(x2))
+
+        x = self.up2(x3)
+        x = self.dec2(torch.cat([x, x2], dim=1))
+        x = self.up1(x)
+        x = self.dec1(torch.cat([x, x1], dim=1))
+        return self.final(x)
diff --git a/tests/test_networks_u_mamba.py b/tests/test_networks_u_mamba.py
new file mode 100644
index 0000000000..0e0a1f4599
--- /dev/null
+++ b/tests/test_networks_u_mamba.py
@@ -0,0 +1,26 @@
+import unittest
+import torch
+from monai.networks.nets import UMamba
+
+class TestUMamba(unittest.TestCase):
+    def test_forward_shape(self):
+        # Set up input dimensions and model
+        input_tensor = torch.randn(2, 1, 64, 64)  # (batch_size, channels, H, W)
+        model = UMamba(in_channels=1, out_channels=2)  # example args
+
+        # Forward pass
+        output = model(input_tensor)
+
+        # Assert output shape matches expectation
+        self.assertEqual(output.shape, (2, 2, 64, 64))  # adjust if necessary
+
+    def test_script(self):
+        # Test JIT scripting if supported
+        model = UMamba(in_channels=1, out_channels=2)
+        scripted = torch.jit.script(model)
+        x = torch.randn(1, 1, 64, 64)
+        out = scripted(x)
+        self.assertEqual(out.shape, (1, 2, 64, 64))
+
+if __name__ == "__main__":
+    unittest.main()

From 3ab921380adc3d906af3b42e2ba0a44fd832d02e Mon Sep 17 00:00:00 2001
From: Billy Luu
Date: Thu, 1 May 2025 14:55:38 -0400
Subject: [PATCH 2/4] Fix test input shape and class name for UMambaUNet

---
 networks/nets/__init__.py      |  1 -
 networks/nets/u_mamba.py       | 99 ----------------------------------
 tests/test_networks_u_mamba.py | 14 ++----
 3 files changed, 5 insertions(+), 109 deletions(-)
 delete mode 100644 networks/nets/u_mamba.py

diff --git a/networks/nets/__init__.py b/networks/nets/__init__.py
index 0bf159a851..c1917e5293 100644
--- a/networks/nets/__init__.py
+++ b/networks/nets/__init__.py
@@ -144,4 +144,3 @@
 from .vnet import VNet
 from .voxelmorph import VoxelMorph, VoxelMorphUNet
 from .vqvae import VQVAE
-from .u_mamba import UMamba
diff --git a/networks/nets/u_mamba.py b/networks/nets/u_mamba.py
deleted file mode 100644
index 1d37be242d..0000000000
--- a/networks/nets/u_mamba.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-# Simple placeholder for the SSM (Mamba-like block)
-class SSMBlock(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.linear1 = nn.Linear(dim, dim)
-        self.linear2 = nn.Linear(dim, dim)
-
-    def forward(self, x):
-        # x: (B, L, C)
-        return self.linear2(torch.silu(self.linear1(x)))
-
-class UMambaBlock(nn.Module):
-    def __init__(self, in_channels, hidden_channels):
-        super().__init__()
-        self.conv_res1 = nn.Sequential(
-            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
-            nn.InstanceNorm3d(in_channels),
-            nn.LeakyReLU(),
-        )
-        self.conv_res2 = nn.Sequential(
-            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
-            nn.InstanceNorm3d(in_channels),
-            nn.LeakyReLU(),
-        )
-
-        self.layernorm = nn.LayerNorm(hidden_channels)
-        self.linear1 = nn.Linear(in_channels, hidden_channels)
-        self.linear2 = nn.Linear(hidden_channels, in_channels)
-        self.conv1d = nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
-        self.ssm = SSMBlock(hidden_channels)
-
-    def forward(self, x):
-        # x: (B, C, H, W, D)
-        residual = x
-        x = self.conv_res1(x)
-        x = self.conv_res2(x) + residual
-
-        B, C, H, W, D = x.shape
-        x_flat = x.view(B, C, -1).permute(0, 2, 1)  # (B, L, C)
-        x_norm = self.layernorm(x_flat)
-        x_proj = self.linear1(x_norm)
-
-        x_silu = torch.silu(x_proj)
-        x_ssm = self.ssm(x_silu)
-        x_conv1d = self.conv1d(x_proj.permute(0, 2, 1)).permute(0, 2, 1)
-
-        x_combined = torch.silu(x_conv1d) * torch.silu(x_ssm)
-        x_out = self.linear2(x_combined)
-        x_out = x_out.permute(0, 2, 1).view(B, C, H, W, D)
-
-        return x + x_out  # Residual connection
-
-class ResidualBlock(nn.Module):
-    def __init__(self, channels):
-        super().__init__()
-        self.block = nn.Sequential(
-            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
-            nn.BatchNorm3d(channels),
-            nn.ReLU(),
-            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
-            nn.BatchNorm3d(channels),
-        )
-
-    def forward(self, x):
-        return F.relu(x + self.block(x))
-
-class UMambaUNet(nn.Module):
-    def __init__(self, in_channels=1, out_channels=1, base_channels=32):
-        super().__init__()
-        self.enc1 = UMambaBlock(in_channels, base_channels)
-        self.down1 = nn.Conv3d(base_channels, base_channels*2, kernel_size=3, stride=2, padding=1)
-
-        self.enc2 = UMambaBlock(base_channels*2, base_channels*2)
-        self.down2 = nn.Conv3d(base_channels*2, base_channels*4, kernel_size=3, stride=2, padding=1)
-
-        self.bottleneck = UMambaBlock(base_channels*4, base_channels*4)
-
-        self.up2 = nn.ConvTranspose3d(base_channels*4, base_channels*2, kernel_size=2, stride=2)
-        self.dec2 = ResidualBlock(base_channels*4)
-
-        self.up1 = nn.ConvTranspose3d(base_channels*2, base_channels, kernel_size=2, stride=2)
-        self.dec1 = ResidualBlock(base_channels*2)
-
-        self.final = nn.Conv3d(base_channels, out_channels, kernel_size=1)
-
-    def forward(self, x):
-        x1 = self.enc1(x)
-        x2 = self.enc2(self.down1(x1))
-        x3 = self.bottleneck(self.down2(x2))
-
-        x = self.up2(x3)
-        x = self.dec2(torch.cat([x, x2], dim=1))
-        x = self.up1(x)
-        x = self.dec1(torch.cat([x, x1], dim=1))
-        return self.final(x)
diff --git a/tests/test_networks_u_mamba.py b/tests/test_networks_u_mamba.py
index 0e0a1f4599..02f75226da 100644
--- a/tests/test_networks_u_mamba.py
+++ b/tests/test_networks_u_mamba.py
@@ -1,22 +1,18 @@
 import unittest
 import torch
-from monai.networks.nets import UMamba
+from monai.networks.nets import UMambaUNet
 
 class TestUMamba(unittest.TestCase):
     def test_forward_shape(self):
         # Set up input dimensions and model
-        input_tensor = torch.randn(2, 1, 64, 64)  # (batch_size, channels, H, W)
-        model = UMamba(in_channels=1, out_channels=2)  # example args
-
-        # Forward pass
+        input_tensor = torch.randn(2, 1, 16, 64, 64)
+        model = UMambaUNet(in_channels=1, out_channels=2)
         output = model(input_tensor)
-
-        # Assert output shape matches expectation
-        self.assertEqual(output.shape, (2, 2, 64, 64))  # adjust if necessary
+        self.assertEqual(output.shape, (2, 2, 16, 64, 64))
 
     def test_script(self):
         # Test JIT scripting if supported
-        model = UMamba(in_channels=1, out_channels=2)
+        model = UMambaUNet(in_channels=1, out_channels=2)
         scripted = torch.jit.script(model)
         x = torch.randn(1, 1, 64, 64)
         out = scripted(x)

From 9b190cfddb601094776f049ebc66dfa779a7aae4 Mon Sep 17 00:00:00 2001
From: Billy Luu <114883074+billyluu5704@users.noreply.github.com>
Date: Thu, 1 May 2025 14:57:07 -0400
Subject: [PATCH 3/4] Delete networks/nets directory

Signed-off-by: Billy Luu <114883074+billyluu5704@users.noreply.github.com>
---
 networks/nets/__init__.py | 146 --------------------------------------------
 1 file changed, 146 deletions(-)
 delete mode 100644 networks/nets/__init__.py

diff --git a/networks/nets/__init__.py b/networks/nets/__init__.py
deleted file mode 100644
index c1917e5293..0000000000
--- a/networks/nets/__init__.py
+++ /dev/null
@@ -1,146 +0,0 @@
-# Copyright (c) MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#     http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import annotations
-
-from .ahnet import AHnet, Ahnet, AHNet
-from .attentionunet import AttentionUnet
-from .autoencoder import AutoEncoder
-from .autoencoderkl import AutoencoderKL
-from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet
-from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus
-from .classifier import Classifier, Critic, Discriminator
-from .controlnet import ControlNet
-from .daf3d import DAF3D
-from .densenet import (
-    DenseNet,
-    Densenet,
-    DenseNet121,
-    Densenet121,
-    DenseNet169,
-    Densenet169,
-    DenseNet201,
-    Densenet201,
-    DenseNet264,
-    Densenet264,
-    densenet121,
-    densenet169,
-    densenet201,
-    densenet264,
-)
-from .diffusion_model_unet import DiffusionModelUNet
-from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch
-from .dynunet import DynUNet, DynUnet, Dynunet
-from .efficientnet import (
-    BlockArgs,
-    EfficientNet,
-    EfficientNetBN,
-    EfficientNetBNFeatures,
-    EfficientNetEncoder,
-    drop_connect,
-    get_efficientnet_image_size,
-)
-from .flexible_unet import FLEXUNET_BACKBONE, FlexibleUNet, FlexUNet, FlexUNetEncoderRegister
-from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
-from .generator import Generator
-from .highresnet import HighResBlock, HighResNet
-from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
-from .masked_autoencoder_vit import MaskedAutoEncoderViT
-from .mednext import (
-    MedNeXt,
-    MedNext,
-    MedNextB,
-    MedNeXtB,
-    MedNextBase,
-    MedNextL,
-    MedNeXtL,
-    MedNeXtLarge,
-    MedNextLarge,
-    MedNextM,
-    MedNeXtM,
-    MedNeXtMedium,
-    MedNextMedium,
-    MedNextS,
-    MedNeXtS,
-    MedNeXtSmall,
-    MedNextSmall,
-)
-from .milmodel import MILModel
-from .netadapter import NetAdapter
-from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
-from .quicknat import Quicknat
-from .regressor import Regressor
-from .regunet import GlobalNet, LocalNet, RegUNet
-from .resnet import (
-    ResNet,
-    ResNetBlock,
-    ResNetBottleneck,
-    ResNetEncoder,
-    ResNetFeatures,
-    get_medicalnet_pretrained_resnet_args,
-    get_pretrained_resnet_medicalnet,
-    resnet10,
-    resnet18,
-    resnet34,
-    resnet50,
-    resnet101,
-    resnet152,
-    resnet200,
-)
-from .segresnet import SegResNet, SegResNetVAE
-from .segresnet_ds import SegResNetDS, SegResNetDS2
-from .senet import (
-    SENet,
-    SEnet,
-    Senet,
-    SENet154,
-    SEnet154,
-    Senet154,
-    SEResNet50,
-    SEresnet50,
-    Seresnet50,
-    SEResNet101,
-    SEresnet101,
-    Seresnet101,
-    SEResNet152,
-    SEresnet152,
-    Seresnet152,
-    SEResNext50,
-    SEResNeXt50,
-    SEresnext50,
-    Seresnext50,
-    SEResNext101,
-    SEResNeXt101,
-    SEresnext101,
-    Seresnext101,
-    senet154,
-    seresnet50,
-    seresnet101,
-    seresnet152,
-    seresnext50,
-    seresnext101,
-)
-from .spade_autoencoderkl import SPADEAutoencoderKL
-from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
-from .spade_network import SPADENet
-from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
-from .torchvision_fc import TorchVisionFCModel
-from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
-from .transformer import DecoderOnlyTransformer
-from .unet import UNet, Unet
-from .unetr import UNETR
-from .varautoencoder import VarAutoEncoder
-from .vista3d import VISTA3D, vista3d132
-from .vit import ViT
-from .vitautoenc import ViTAutoEnc
-from .vnet import VNet
-from .voxelmorph import VoxelMorph, VoxelMorphUNet
-from .vqvae import VQVAE

From 49648504502e8f5e5c6e55a5a296444ad2eb1ba1 Mon Sep 17 00:00:00 2001
From: Billy Luu
Date: Thu, 1 May 2025 15:05:30 -0400
Subject: [PATCH 4/4] Register UMambaUNet in monai.networks.nets.__init__

---
 monai/networks/nets/__init__.py |   1 +
 monai/networks/nets/u_mamba.py  | 115 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 116 insertions(+)
 create mode 100644 monai/networks/nets/u_mamba.py

diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index c1917e5293..0d02f95efa 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -144,3 +144,4 @@
 from .vnet import VNet
 from .voxelmorph import VoxelMorph, VoxelMorphUNet
 from .vqvae import VQVAE
+from .u_mamba import UMambaUNet
diff --git a/monai/networks/nets/u_mamba.py b/monai/networks/nets/u_mamba.py
new file mode 100644
index 0000000000..4697ad4a66
--- /dev/null
+++ b/monai/networks/nets/u_mamba.py
@@ -0,0 +1,115 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Simple placeholder for the SSM (Mamba-like block)
+class SSMBlock(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.linear1 = nn.Linear(dim, dim)
+        self.linear2 = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        # x: (B, L, C)
+        return self.linear2(F.silu(self.linear1(x)))
+
+class UMambaBlock(nn.Module):
+    def __init__(self, in_channels, hidden_channels):
+        super().__init__()
+        self.conv_res1 = nn.Sequential(
+            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
+            nn.InstanceNorm3d(in_channels),
+            nn.LeakyReLU(),
+        )
+        self.conv_res2 = nn.Sequential(
+            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
+            nn.InstanceNorm3d(in_channels),
+            nn.LeakyReLU(),
+        )
+
+        # LayerNorm runs over the channel dimension of the flattened (B, L, C) tokens
+        self.layernorm = nn.LayerNorm(in_channels)
+        self.linear1 = nn.Linear(in_channels, hidden_channels)
+        self.linear2 = nn.Linear(hidden_channels, in_channels)
+        self.conv1d = nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
+        self.ssm = SSMBlock(hidden_channels)
+
+    def forward(self, x):
+        # x: (B, C, H, W, D)
+        residual = x
+        x = self.conv_res1(x)
+        x = self.conv_res2(x) + residual
+
+        B, C, H, W, D = x.shape
+        x_flat = x.view(B, C, -1).permute(0, 2, 1)  # (B, L, C) with L = H * W * D
+        x_norm = self.layernorm(x_flat)
+        x_proj = self.linear1(x_norm)
+
+        x_silu = F.silu(x_proj)
+        x_ssm = self.ssm(x_silu)
+        x_conv1d = self.conv1d(x_proj.permute(0, 2, 1)).permute(0, 2, 1)
+
+        # gate the SSM branch with the 1D-convolution branch, then project back to C channels
+        x_combined = F.silu(x_conv1d) * F.silu(x_ssm)
+        x_out = self.linear2(x_combined)
+        x_out = x_out.permute(0, 2, 1).reshape(B, C, H, W, D)
+
+        return x + x_out  # Residual connection
+
+class ResidualBlock(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
+            nn.BatchNorm3d(channels),
+            nn.ReLU(),
+            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
+            nn.BatchNorm3d(channels),
+        )
+
+    def forward(self, x):
+        return F.relu(x + self.block(x))
+
+class UMambaUNet(nn.Module):
+    def __init__(self, in_channels=1, out_channels=1, base_channels=32):
+        super().__init__()
+        # stem brings the input up to the base channel width expected by the first encoder block
+        self.stem = nn.Conv3d(in_channels, base_channels, kernel_size=3, padding=1)
+        self.enc1 = UMambaBlock(base_channels, base_channels)
+        self.down1 = nn.Conv3d(base_channels, base_channels*2, kernel_size=3, stride=2, padding=1)
+
+        self.enc2 = UMambaBlock(base_channels*2, base_channels*2)
+        self.down2 = nn.Conv3d(base_channels*2, base_channels*4, kernel_size=3, stride=2, padding=1)
+
+        self.bottleneck = UMambaBlock(base_channels*4, base_channels*4)
+
+        # each decoder stage: upsample, concatenate the skip, reduce channels with a 1x1 conv, refine
+        self.up2 = nn.ConvTranspose3d(base_channels*4, base_channels*2, kernel_size=2, stride=2)
+        self.dec2 = nn.Sequential(nn.Conv3d(base_channels*4, base_channels*2, kernel_size=1), ResidualBlock(base_channels*2))
+
+        self.up1 = nn.ConvTranspose3d(base_channels*2, base_channels, kernel_size=2, stride=2)
+        self.dec1 = nn.Sequential(nn.Conv3d(base_channels*2, base_channels, kernel_size=1), ResidualBlock(base_channels))
+
+        self.final = nn.Conv3d(base_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        x1 = self.enc1(self.stem(x))
+        x2 = self.enc2(self.down1(x1))
+        x3 = self.bottleneck(self.down2(x2))
+
+        x = self.up2(x3)
+        x = self.dec2(torch.cat([x, x2], dim=1))
+        x = self.up1(x)
+        x = self.dec1(torch.cat([x, x1], dim=1))
+        return self.final(x)
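
Usage sketch (not part of the patches): a minimal smoke test of the UMambaUNet added in patch 4, assuming the patches above are applied so that the class is importable from monai.networks.nets. Spatial sizes are assumed to be divisible by 4, since the network downsamples twice with stride-2 convolutions and recombines the skip connections by concatenation.

    import torch
    from monai.networks.nets import UMambaUNet

    model = UMambaUNet(in_channels=1, out_channels=2, base_channels=32)
    x = torch.randn(1, 1, 16, 64, 64)  # (B, C, H, W, D), each spatial dim divisible by 4
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 2, 16, 64, 64])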