From fc220e1497fd42c862b1e830c013db2d6d06601d Mon Sep 17 00:00:00 2001 From: kaibo Date: Tue, 31 Oct 2023 22:13:13 -0400 Subject: [PATCH 01/11] Implemented voxelmorph and passed coding style check Signed-off-by: kaibo --- monai/networks/nets/__init__.py | 1 + monai/networks/nets/voxelmorph.py | 367 ++++++++++++++++++++++++++++++ 2 files changed, 368 insertions(+) create mode 100644 monai/networks/nets/voxelmorph.py diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 1fb0f08ccc..8064b815a3 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -111,3 +111,4 @@ from .vit import ViT from .vitautoenc import ViTAutoEnc from .vnet import VNet +from .voxelmorph import VoxelMorph diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py new file mode 100644 index 0000000000..6118acad38 --- /dev/null +++ b/monai/networks/nets/voxelmorph.py @@ -0,0 +1,367 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks.convolutions import Convolution +from monai.networks.blocks.upsample import UpSample +from monai.networks.blocks.warp import DVF2DDF, Warp +from monai.networks.layers.simplelayers import SkipConnection +from monai.utils import alias, export + +__all__ = ["VoxelMorph", "voxelmorph"] + + +@export("monai.networks.nets") +@alias("voxelmorph") +class VoxelMorph(nn.Module): + """ + Implementation of VoxelMorph network as described in https://arxiv.org/pdf/1809.05231.pdf. + + Overview. A pair of images (moving and fixed) are concatenated along the channel dimension and passed through + a UNet. The output of the UNet is then passed through a series of convolution blocks to produce the final prediction + of the displacement field (DDF) in the non-diffeomorphic variant (i.e. when `int_steps` is set to 0) or the + stationary velocity field (DVF) in the diffeomorphic variant (i.e. when `int_steps` is set to a positive integer). + The DVF is then converted to a DDF using the `DVF2DDF` module. Finally, the DDF is used to warp the moving image + to the fixed image using the `Warp` module. Optionally, the integration from DVF to DDF can be performed on reduced + resolution by specifying `half_res` to be True. + + Args: + in_shape: shape of the input image. + in_channels: number of channels in the input volume after concatenation of moving and fixed images. + unet_out_channels: number of channels in the output of the UNet. + channels: number of channels in each layer of the UNet. See the following example for more details. + final_conv_channels: number of channels in each layer of the final convolution block. + final_conv_act: activation type for the final convolution block. Defaults to LeakyReLU. + Since VoxelMorph was originally implemented in tensorflow where the default negative slope for LeakyReLU + was 0.2, we use the same default value here. + int_steps: number of integration steps. Defaults to 7. If set to 0, the network will be non-diffeomorphic. + kernel_size: kernel size for all convolution layers in the UNet. Defaults to 3. + up_kernel_size: kernel size for all convolution layers in the upsampling path of the UNet. Defaults to 3. + spatial_dims: number of spatial dimensions. Defaults to 3. + act: activation type for all convolution layers in the UNet. Defaults to LeakyReLU with negative slope 0.2. + norm: feature normalization type and arguments for all convolution layers in the UNet. Defaults to None. + dropout: dropout ratio for all convolution layers in the UNet. Defaults to 0.0 (no dropout). + bias: whether to use bias in all convolution layers in the UNet. Defaults to True. + half_res: whether to perform integration on half resolution. Defaults to False. + use_maxpool: whether to use maxpooling in the downsampling path of the UNet. Defaults to True. + Using maxpooling is the consistent with the original implementation of VoxelMorph. + But one can optionally use strided convolution instead (i.e. set `use_maxpool` to False). + adn_ordering: ordering of activation, dropout, and normalization. Defaults to "NDA". + + Examples:: + + from monai.networks.nets import VoxelMorph + + # VoxelMorph network as it is in the original paper https://arxiv.org/pdf/1809.05231.pdf + net = VoxelMorph( + in_shape=(160, 192, 224), + in_channels=2, + unet_out_channels=32, + channels=(16, 32, 32, 32, 32, 32), # this indicates the down block at the top takes 16 channels as + # input, the corresponding up block at the top produces 32 + # channels as output, the second down block takes 32 channels as + # input, and the corresponding up block at the same level + # produces 32 channels as output, etc. + final_conv_channels=(16, 16) + ) + + # A forward pass through the network would look something like this + moving = torch.randn(1, 2, 160, 192, 224) + fixed = torch.randn(1, 2, 160, 192, 224) + warped, ddf = net(moving, fixed) + """ + + def __init__( + self, + in_shape: Sequence[int], + in_channels: int, + unet_out_channels: int, + channels: Sequence[int], + final_conv_channels: Sequence[int], + final_conv_act: tuple | str | None = "LEAKYRELU", + int_steps: int = 7, + kernel_size: Sequence[int] | int = 3, + up_kernel_size: Sequence[int] | int = 3, + spatial_dims: int = 3, + act: tuple | str = "LEAKYRELU", + norm: tuple | str | None = None, + dropout: float = 0.0, + bias: bool = True, + half_res: bool = False, + use_maxpool: bool = True, + adn_ordering: str = "NDA", + ) -> None: + super().__init__() + + if isinstance(kernel_size, Sequence) and len(kernel_size) != spatial_dims: + raise ValueError("the length of `kernel_size` should equal to `dimensions`.") + if isinstance(up_kernel_size, Sequence) and len(up_kernel_size) != spatial_dims: + raise ValueError("the length of `up_kernel_size` should equal to `dimensions`.") + + # UNet args + self.dimensions = spatial_dims + self.in_channels = in_channels + self.channels = channels + self.kernel_size = kernel_size + self.up_kernel_size = up_kernel_size + self.act = ("leakyrelu", {"negative_slope": 0.2, "inplace": True}) if act.upper() == "LEAKYRELU" else act + self.norm = norm + self.dropout = dropout + self.bias = bias + self.adn_ordering = adn_ordering + + # VoxelMorph specific args + self.in_shape = in_shape + self.unet_out_channels = unet_out_channels + self.half_res = half_res + self.use_maxpool = use_maxpool + + # final convolutions args + self.final_conv_channels = final_conv_channels + self.final_conv_act = ( + ("leakyrelu", {"negative_slope": 0.2, "inplace": True}) + if final_conv_act.upper() == "LEAKYRELU" + else final_conv_act + ) + + # integration args + self.int_steps = int_steps + self.diffeomorphic = True if self.int_steps > 0 else False + + def _create_block(inc: int, outc: int, channels: Sequence[int], is_top: bool) -> nn.Module: + """ + Builds the UNet structure recursively. + + Args: + inc: number of input channels. + outc: number of output channels. + channels: sequence of channels for each pair of down and up layers. + is_top: True if this is the top block. + """ + + next_c_in, next_c_out = channels[0:2] + upc = next_c_in + next_c_out + + subblock: nn.Module + + if len(channels) > 2: + subblock = _create_block(next_c_in, next_c_out, channels[2:], is_top=False) # continue recursion down + else: + # the next layer is the bottom so stop recursion, create the bottom layer as the sublock for this layer + subblock = self._get_bottom_layer(next_c_in, next_c_out) + + down = self._get_down_layer(inc, next_c_in, is_top) # create layer in downsampling path + up = self._get_up_layer(upc, outc, is_top) # create layer in upsampling path + + return self._get_connection_block(down, up, subblock) + + self.unet = _create_block(in_channels, unet_out_channels, self.channels, is_top=True) + + def _create_final_conv(inc: int, outc: int, channels: Sequence[int]) -> nn.Module: + """ + Builds the final convolution blocks. + + Args: + inc: number of input channels, should be the same as `unet_out_channels`. + outc: number of output channels, should be 3 for 3D volume registration. + channels: sequence of channels for each convolution layer. + + Note: there is no activation after the last convolution layer as per the original implementation. + """ + + mod: nn.Module = nn.Sequential() + + for i, c in enumerate(channels): + mod.add_module( + f"final_conv_{i}", + Convolution( + self.dimensions, + inc, + c, + kernel_size=self.kernel_size, + act=self.final_conv_act, + norm=self.norm, + dropout=self.dropout, + bias=self.bias, + adn_ordering=self.adn_ordering, + ), + ) + inc = c + + mod.add_module( + "final_conv_out", + Convolution( + self.dimensions, + inc, + outc, + kernel_size=self.kernel_size, + act=None, + norm=self.norm, + dropout=self.dropout, + bias=self.bias, + adn_ordering=self.adn_ordering, + ), + ) + + return mod + + self.final_conv = _create_final_conv(unet_out_channels, 3, self.final_conv_channels) + + # create helpers + self.dvf2ddf = DVF2DDF(num_steps=self.int_steps, mode="bilinear", padding_mode="zeros") + self.warp = Warp(mode="bilinear", padding_mode="zeros") + + def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module: + """ + Returns the block object defining a layer of the UNet structure including the implementation of the skip + between encoding (down) and decoding (up) sides of the network. + + Args: + down_path: encoding half of the layer + up_path: decoding half of the layer + subblock: block defining the next layer in the network. + Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)` + """ + + return nn.Sequential(down_path, SkipConnection(subblock), up_path) + + def _get_down_layer(self, in_channels: int, out_channels: int, is_top: bool) -> nn.Module: + """ + In each down layer, the input is first downsampled using maxpooling, + then passed through a convolution block, unless this is the top layer + in which case the input is passed through a convolution block only + without maxpooling first. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + is_top: True if this is the top block. + """ + + mod: Convolution | nn.Sequential + + strides = 1 if self.use_maxpool else 2 + + mod = Convolution( + self.dimensions, + in_channels, + out_channels, + strides=strides, + kernel_size=self.kernel_size, + act=self.act, + norm=self.norm, + dropout=self.dropout, + bias=self.bias, + adn_ordering=self.adn_ordering, + ) + + if self.use_maxpool and not is_top: + mod = nn.Sequential(nn.MaxPool3d(kernel_size=2, stride=2), mod) + + return mod + + def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: + """ + Bottom layer (bottleneck) in voxelmorph consists of a typical down layer followed by an upsample layer. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + """ + + mod: nn.Module + upsample: nn.Module + + mod = self._get_down_layer(in_channels, out_channels, is_top=False) + + upsample = UpSample( + self.dimensions, + out_channels, + out_channels, + scale_factor=2, + mode="nontrainable", + interp_mode="nearest", + align_corners=None, # required to use with interp_mode="nearest" + ) + + return nn.Sequential(mod, upsample) + + def _get_up_layer(self, in_channels: int, out_channels: int, is_top: bool) -> nn.Module: + """ + In each up layer, the input is passed through a convolution block before upsampled, + unless this is the top layer in which case the input is passed through a convolution block only + without upsampling. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + is_top: True if this is the top block. + """ + + mod: Convolution | nn.Sequential + + strides = 1 + + mod = Convolution( + self.dimensions, + in_channels, + out_channels, + strides=strides, + kernel_size=self.up_kernel_size, + act=self.act, + norm=self.norm, + dropout=self.dropout, + bias=self.bias, + # conv_only=is_top, + is_transposed=False, + adn_ordering=self.adn_ordering, + ) + + if not is_top: + mod = nn.Sequential( + mod, + UpSample( + self.dimensions, + out_channels, + out_channels, + scale_factor=2, + mode="nontrainable", + interp_mode="nearest", + align_corners=None, # required to use with interp_mode="nearest" + ), + ) + + return mod + + def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> (torch.Tensor, torch.Tensor): + x = self.unet(torch.cat([moving, fixed], dim=1)) + x = self.final_conv(x) + + if self.half_res: + x = F.interpolate(x, scale_factor=0.5, mode="trilinear", align_corners=True) * 2.0 + + if self.diffeomorphic: + x = self.dvf2ddf(x) + + if self.half_res: + x = F.interpolate(x * 0.5, scale_factor=2.0, mode="trilinear", align_corners=True) + + return self.warp(moving, x), x + + +voxelmorph = VoxelMorph From bf69c7c53c6bb46e82005ca507be2e89d53abc05 Mon Sep 17 00:00:00 2001 From: kaibo Date: Wed, 1 Nov 2023 18:11:34 -0400 Subject: [PATCH 02/11] Added and passed unittest, fixed static type test error Signed-off-by: kaibo --- monai/networks/nets/voxelmorph.py | 42 ++-- tests/test_voxelmorph.py | 310 ++++++++++++++++++++++++++++++ 2 files changed, 338 insertions(+), 14 deletions(-) create mode 100644 tests/test_voxelmorph.py diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py index 6118acad38..d707af746f 100644 --- a/monai/networks/nets/voxelmorph.py +++ b/monai/networks/nets/voxelmorph.py @@ -41,7 +41,7 @@ class VoxelMorph(nn.Module): resolution by specifying `half_res` to be True. Args: - in_shape: shape of the input image. + spatial_dims: number of spatial dimensions. in_channels: number of channels in the input volume after concatenation of moving and fixed images. unet_out_channels: number of channels in the output of the UNet. channels: number of channels in each layer of the UNet. See the following example for more details. @@ -52,7 +52,6 @@ class VoxelMorph(nn.Module): int_steps: number of integration steps. Defaults to 7. If set to 0, the network will be non-diffeomorphic. kernel_size: kernel size for all convolution layers in the UNet. Defaults to 3. up_kernel_size: kernel size for all convolution layers in the upsampling path of the UNet. Defaults to 3. - spatial_dims: number of spatial dimensions. Defaults to 3. act: activation type for all convolution layers in the UNet. Defaults to LeakyReLU with negative slope 0.2. norm: feature normalization type and arguments for all convolution layers in the UNet. Defaults to None. dropout: dropout ratio for all convolution layers in the UNet. Defaults to 0.0 (no dropout). @@ -69,7 +68,7 @@ class VoxelMorph(nn.Module): # VoxelMorph network as it is in the original paper https://arxiv.org/pdf/1809.05231.pdf net = VoxelMorph( - in_shape=(160, 192, 224), + spatial_dims=3, in_channels=2, unet_out_channels=32, channels=(16, 32, 32, 32, 32, 32), # this indicates the down block at the top takes 16 channels as @@ -88,7 +87,7 @@ class VoxelMorph(nn.Module): def __init__( self, - in_shape: Sequence[int], + spatial_dims: int, in_channels: int, unet_out_channels: int, channels: Sequence[int], @@ -97,7 +96,6 @@ def __init__( int_steps: int = 7, kernel_size: Sequence[int] | int = 3, up_kernel_size: Sequence[int] | int = 3, - spatial_dims: int = 3, act: tuple | str = "LEAKYRELU", norm: tuple | str | None = None, dropout: float = 0.0, @@ -108,6 +106,14 @@ def __init__( ) -> None: super().__init__() + if spatial_dims not in (2, 3): + raise ValueError("spatial_dims must be either 2 or 3.") + if in_channels % 2 != 0: + raise ValueError("in_channels must be divisible by 2.") + if len(channels) < 2: + raise ValueError("the length of `channels` should be no less than 2.") + if len(channels) % 2 != 0: + raise ValueError("the elements of `channels` should be specified in pairs.") if isinstance(kernel_size, Sequence) and len(kernel_size) != spatial_dims: raise ValueError("the length of `kernel_size` should equal to `dimensions`.") if isinstance(up_kernel_size, Sequence) and len(up_kernel_size) != spatial_dims: @@ -119,14 +125,17 @@ def __init__( self.channels = channels self.kernel_size = kernel_size self.up_kernel_size = up_kernel_size - self.act = ("leakyrelu", {"negative_slope": 0.2, "inplace": True}) if act.upper() == "LEAKYRELU" else act + self.act = ( + ("leakyrelu", {"negative_slope": 0.2, "inplace": True}) + if isinstance(act, str) and act.upper() == "LEAKYRELU" + else act + ) self.norm = norm self.dropout = dropout self.bias = bias self.adn_ordering = adn_ordering # VoxelMorph specific args - self.in_shape = in_shape self.unet_out_channels = unet_out_channels self.half_res = half_res self.use_maxpool = use_maxpool @@ -135,7 +144,7 @@ def __init__( self.final_conv_channels = final_conv_channels self.final_conv_act = ( ("leakyrelu", {"negative_slope": 0.2, "inplace": True}) - if final_conv_act.upper() == "LEAKYRELU" + if isinstance(final_conv_act, str) and final_conv_act.upper() == "LEAKYRELU" else final_conv_act ) @@ -178,7 +187,7 @@ def _create_final_conv(inc: int, outc: int, channels: Sequence[int]) -> nn.Modul Args: inc: number of input channels, should be the same as `unet_out_channels`. - outc: number of output channels, should be 3 for 3D volume registration. + outc: number of output channels, should be the same as `spatial_dims`. channels: sequence of channels for each convolution layer. Note: there is no activation after the last convolution layer as per the original implementation. @@ -220,10 +229,11 @@ def _create_final_conv(inc: int, outc: int, channels: Sequence[int]) -> nn.Modul return mod - self.final_conv = _create_final_conv(unet_out_channels, 3, self.final_conv_channels) + self.final_conv = _create_final_conv(unet_out_channels, self.dimensions, self.final_conv_channels) # create helpers - self.dvf2ddf = DVF2DDF(num_steps=self.int_steps, mode="bilinear", padding_mode="zeros") + if self.diffeomorphic: + self.dvf2ddf = DVF2DDF(num_steps=self.int_steps, mode="bilinear", padding_mode="zeros") self.warp = Warp(mode="bilinear", padding_mode="zeros") def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module: @@ -255,7 +265,7 @@ def _get_down_layer(self, in_channels: int, out_channels: int, is_top: bool) -> mod: Convolution | nn.Sequential - strides = 1 if self.use_maxpool else 2 + strides = 1 if self.use_maxpool or is_top else 2 mod = Convolution( self.dimensions, @@ -271,7 +281,11 @@ def _get_down_layer(self, in_channels: int, out_channels: int, is_top: bool) -> ) if self.use_maxpool and not is_top: - mod = nn.Sequential(nn.MaxPool3d(kernel_size=2, stride=2), mod) + mod = ( + nn.Sequential(nn.MaxPool3d(kernel_size=2, stride=2), mod) + if self.dimensions == 3 + else nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), mod) + ) return mod @@ -348,7 +362,7 @@ def _get_up_layer(self, in_channels: int, out_channels: int, is_top: bool) -> nn return mod - def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> (torch.Tensor, torch.Tensor): + def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x = self.unet(torch.cat([moving, fixed], dim=1)) x = self.final_conv(x) diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py new file mode 100644 index 0000000000..ea71f04fec --- /dev/null +++ b/tests/test_voxelmorph.py @@ -0,0 +1,310 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import VoxelMorph +from tests.utils import test_script_save + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASE_0 = [ # single channel 3D, batch 1, non-diffeomorphic + # i.e., VoxelMorph as it is in the original paper + # https://arxiv.org/pdf/1809.05231.pdf + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "int_steps": 0, + }, + ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), + ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), +] + +TEST_CASE_1 = [ # single channel 3D, batch 1, diffeomorphic (default) + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + }, + ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), + ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), +] + +TEST_CASE_2 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "int_steps": 7, + "half_res": True, + }, + ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), + ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), +] + +TEST_CASE_3 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution, + # using strided convolution for downsampling instead of maxpooling + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "int_steps": 7, + "half_res": True, + "use_maxpool": False, + }, + ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), + ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), +] + +TEST_CASE_4 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution, + # using strided convolution for downsampling instead of maxpooling, + # explicitly specify leakyrelu with a different negative slope for final convolutions + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "final_conv_act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), + "int_steps": 7, + "half_res": True, + "use_maxpool": False, + }, + ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), + ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), +] + +TEST_CASE_5 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution, + # using strided convolution for downsampling instead of maxpooling, + # explicitly specify leakyrelu with a different negative slope for both unet and final convolutions. + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "final_conv_act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), + "act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), + "int_steps": 7, + "half_res": True, + "use_maxpool": False, + }, + ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), + ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), +] + +TEST_CASE_6 = [ # 2-channel 3D, batch 1, diffeomorphic + # i.e., possible use case where the input contains both modalities (e.g., T1 and T2) + { + "spatial_dims": 3, + "in_channels": 4, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + }, + ((1, 2, 160, 192, 224), (1, 2, 160, 192, 224)), + ((1, 2, 160, 192, 224), (1, 3, 160, 192, 224)), +] + +TEST_CASE_7 = [ # single channel 3D, batch 2, diffeomorphic + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + }, + ((2, 1, 160, 192, 224), (2, 1, 160, 192, 224)), + ((2, 1, 160, 192, 224), (2, 3, 160, 192, 224)), +] + +TEST_CASE_8 = [ # single channel 2D, batch 1 + { + "spatial_dims": 2, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + }, + ((1, 1, 160, 192), (1, 1, 160, 192)), + ((1, 1, 160, 192), (1, 2, 160, 192)), +] + +TEST_CASE_9 = [ # single channel 3D, batch 2, diffeomorphic, + # one additional level in the UNet with 32 channels in both down and up branch. + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + }, + ((2, 1, 160, 192, 224), (2, 1, 160, 192, 224)), + ((2, 1, 160, 192, 224), (2, 3, 160, 192, 224)), +] + +CASES = [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, +] + +TEST_CASE_10 = [ # single channel 3D, batch 2, diffeomorphic, + # one additional level in the UNet with 32 channels in both down and up branch. + # and removed one of the two final convolution blocks. + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32, 32, 32), + "final_conv_channels": (16,), + }, + ((2, 1, 160, 192, 224), (2, 1, 160, 192, 224)), + ((2, 1, 160, 192, 224), (2, 3, 160, 192, 224)), +] + +TEST_CASE_11 = [ # single channel 3D, batch 2, diffeomorphic, + # only one level in the UNet + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32), + "final_conv_channels": (16, 16), + }, + ((2, 1, 160, 192, 224), (2, 1, 160, 192, 224)), + ((2, 1, 160, 192, 224), (2, 3, 160, 192, 224)), +] + +CASES = [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + TEST_CASE_11, +] + +ILL_CASE_0 = [ # spatial_dims = 1 + { + "spatial_dims": 1, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + } +] + +ILL_CASE_1 = [ # in_channels = 3 (not divisible by 2) + { + "spatial_dims": 3, + "in_channels": 3, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + } +] + +ILL_CASE_2 = [ # len(channels) = 0 + {"spatial_dims": 3, "in_channels": 2, "unet_out_channels": 32, "channels": (), "final_conv_channels": (16, 16)} +] + +ILL_CASE_3 = [ # channels not in pairs + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + } +] + +ILL_CASE_4 = [ # len(kernel_size) = 3, spatial_dims = 2 + { + "spatial_dims": 2, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "kernel_size": (3, 3, 3), + } +] + +ILL_CASE_5 = [ # len(up_kernel_size) = 2, spatial_dims = 3 + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "up_kernel_size": (3, 3), + } +] + +ILL_CASES = [ILL_CASE_0, ILL_CASE_1, ILL_CASE_2, ILL_CASE_3, ILL_CASE_4, ILL_CASE_5] + + +class TestUNET(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape): + net = VoxelMorph(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape[0]).to(device), torch.randn(input_shape[1]).to(device)) + self.assertEqual(result[0].shape, expected_shape[0]) + self.assertEqual(result[1].shape, expected_shape[1]) + + def test_script(self): + net = VoxelMorph( + spatial_dims=3, + in_channels=2, + unet_out_channels=32, + channels=(16, 32, 32, 32, 32, 32), + final_conv_channels=(16, 16), + ) + test_data = torch.randn(1, 1, 160, 192, 224) + test_script_save(net, test_data) + + @parameterized.expand(ILL_CASES) + def test_ill_input_hyper_params(self, input_param): + with self.assertRaises(ValueError): + _ = VoxelMorph(**input_param) + + +if __name__ == "__main__": + unittest.main() From 178c9bed021eabf8076c82931c1ee0c92216143b Mon Sep 17 00:00:00 2001 From: kaibo Date: Wed, 1 Nov 2023 18:56:51 -0400 Subject: [PATCH 03/11] Fixed error caused by jit Signed-off-by: kaibo --- monai/networks/nets/voxelmorph.py | 10 +++++----- tests/test_voxelmorph.py | 14 ++++++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py index d707af746f..73986ce557 100644 --- a/monai/networks/nets/voxelmorph.py +++ b/monai/networks/nets/voxelmorph.py @@ -179,8 +179,6 @@ def _create_block(inc: int, outc: int, channels: Sequence[int], is_top: bool) -> return self._get_connection_block(down, up, subblock) - self.unet = _create_block(in_channels, unet_out_channels, self.channels, is_top=True) - def _create_final_conv(inc: int, outc: int, channels: Sequence[int]) -> nn.Module: """ Builds the final convolution blocks. @@ -229,7 +227,10 @@ def _create_final_conv(inc: int, outc: int, channels: Sequence[int]) -> nn.Modul return mod - self.final_conv = _create_final_conv(unet_out_channels, self.dimensions, self.final_conv_channels) + self.net = nn.Sequential( + _create_block(in_channels, unet_out_channels, self.channels, is_top=True), + _create_final_conv(unet_out_channels, self.dimensions, self.final_conv_channels) + ) # create helpers if self.diffeomorphic: @@ -363,8 +364,7 @@ def _get_up_layer(self, in_channels: int, out_channels: int, is_top: bool) -> nn return mod def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - x = self.unet(torch.cat([moving, fixed], dim=1)) - x = self.final_conv(x) + x = self.net(torch.cat([moving, fixed], dim=1)) if self.half_res: x = F.interpolate(x, scale_factor=0.5, mode="trilinear", align_corners=True) * 2.0 diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py index ea71f04fec..8b0bfdaa48 100644 --- a/tests/test_voxelmorph.py +++ b/tests/test_voxelmorph.py @@ -242,7 +242,13 @@ ] ILL_CASE_2 = [ # len(channels) = 0 - {"spatial_dims": 3, "in_channels": 2, "unet_out_channels": 32, "channels": (), "final_conv_channels": (16, 16)} + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (), + "final_conv_channels": (16, 16), + } ] ILL_CASE_3 = [ # channels not in pairs @@ -280,7 +286,7 @@ ILL_CASES = [ILL_CASE_0, ILL_CASE_1, ILL_CASE_2, ILL_CASE_3, ILL_CASE_4, ILL_CASE_5] -class TestUNET(unittest.TestCase): +class TestVOXELMORPH(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): net = VoxelMorph(**input_param).to(device) @@ -296,8 +302,8 @@ def test_script(self): unet_out_channels=32, channels=(16, 32, 32, 32, 32, 32), final_conv_channels=(16, 16), - ) - test_data = torch.randn(1, 1, 160, 192, 224) + ).net + test_data = torch.randn(1, 2, 160, 192, 224) test_script_save(net, test_data) @parameterized.expand(ILL_CASES) From 4f2faaf8b0ede1f8e6eb6c06c57d9835efb629a1 Mon Sep 17 00:00:00 2001 From: kaibo Date: Wed, 1 Nov 2023 19:33:22 -0400 Subject: [PATCH 04/11] Fixed code format issue Signed-off-by: kaibo --- monai/networks/nets/voxelmorph.py | 2 +- tests/test_voxelmorph.py | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py index 73986ce557..26bba35aeb 100644 --- a/monai/networks/nets/voxelmorph.py +++ b/monai/networks/nets/voxelmorph.py @@ -229,7 +229,7 @@ def _create_final_conv(inc: int, outc: int, channels: Sequence[int]) -> nn.Modul self.net = nn.Sequential( _create_block(in_channels, unet_out_channels, self.channels, is_top=True), - _create_final_conv(unet_out_channels, self.dimensions, self.final_conv_channels) + _create_final_conv(unet_out_channels, self.dimensions, self.final_conv_channels), ) # create helpers diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py index 8b0bfdaa48..68e8b446dc 100644 --- a/tests/test_voxelmorph.py +++ b/tests/test_voxelmorph.py @@ -242,13 +242,7 @@ ] ILL_CASE_2 = [ # len(channels) = 0 - { - "spatial_dims": 3, - "in_channels": 2, - "unet_out_channels": 32, - "channels": (), - "final_conv_channels": (16, 16), - } + {"spatial_dims": 3, "in_channels": 2, "unet_out_channels": 32, "channels": (), "final_conv_channels": (16, 16)} ] ILL_CASE_3 = [ # channels not in pairs From 7e1736e7c070998b440d50148600becbb0ceca4c Mon Sep 17 00:00:00 2001 From: kaibo Date: Wed, 1 Nov 2023 20:23:28 -0400 Subject: [PATCH 05/11] Previous tests failed (probably) due to OOM. Removed testing cases for some big models configs and big inputs Signed-off-by: kaibo --- tests/test_voxelmorph.py | 37 ++++++++++++------------------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py index 68e8b446dc..ab186002e5 100644 --- a/tests/test_voxelmorph.py +++ b/tests/test_voxelmorph.py @@ -141,7 +141,7 @@ ((2, 1, 160, 192, 224), (2, 3, 160, 192, 224)), ] -TEST_CASE_8 = [ # single channel 2D, batch 1 +TEST_CASE_8 = [ # single channel 2D, batch 2, diffeomorphic { "spatial_dims": 2, "in_channels": 2, @@ -149,8 +149,8 @@ "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), }, - ((1, 1, 160, 192), (1, 1, 160, 192)), - ((1, 1, 160, 192), (1, 2, 160, 192)), + ((2, 1, 160, 192), (2, 1, 160, 192)), + ((2, 1, 160, 192), (2, 2, 160, 192)), ] TEST_CASE_9 = [ # single channel 3D, batch 2, diffeomorphic, @@ -166,19 +166,6 @@ ((2, 1, 160, 192, 224), (2, 3, 160, 192, 224)), ] -CASES = [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - TEST_CASE_9, -] - TEST_CASE_10 = [ # single channel 3D, batch 2, diffeomorphic, # one additional level in the UNet with 32 channels in both down and up branch. # and removed one of the two final convolution blocks. @@ -193,7 +180,7 @@ ((2, 1, 160, 192, 224), (2, 3, 160, 192, 224)), ] -TEST_CASE_11 = [ # single channel 3D, batch 2, diffeomorphic, +TEST_CASE_11 = [ # single channel 3D, batch 1, diffeomorphic, # only one level in the UNet { "spatial_dims": 3, @@ -202,8 +189,8 @@ "channels": (16, 32), "final_conv_channels": (16, 16), }, - ((2, 1, 160, 192, 224), (2, 1, 160, 192, 224)), - ((2, 1, 160, 192, 224), (2, 3, 160, 192, 224)), + ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), + ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), ] CASES = [ @@ -214,10 +201,10 @@ TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, - TEST_CASE_7, + # TEST_CASE_7, TEST_CASE_8, - TEST_CASE_9, - TEST_CASE_10, + # TEST_CASE_9, + # TEST_CASE_10, TEST_CASE_11, ] @@ -250,7 +237,7 @@ "spatial_dims": 3, "in_channels": 2, "unet_out_channels": 32, - "channels": (16, 32, 32, 32, 32, 32, 32), + "channels": (16, 32, 32, 32, 32), "final_conv_channels": (16, 16), } ] @@ -260,7 +247,7 @@ "spatial_dims": 2, "in_channels": 2, "unet_out_channels": 32, - "channels": (16, 32, 32, 32, 32, 32, 32), + "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), "kernel_size": (3, 3, 3), } @@ -271,7 +258,7 @@ "spatial_dims": 3, "in_channels": 2, "unet_out_channels": 32, - "channels": (16, 32, 32, 32, 32, 32, 32), + "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), "up_kernel_size": (3, 3), } From 0ef81ac7624120cbf24bd6b9a4b0429eac338ce3 Mon Sep 17 00:00:00 2001 From: kaibo Date: Wed, 1 Nov 2023 21:02:49 -0400 Subject: [PATCH 06/11] Previous tests indeed failed due to oom, changed the network in test_script to even smaller to avoid process end with code 139 Signed-off-by: kaibo --- tests/test_voxelmorph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py index ab186002e5..830ff79edc 100644 --- a/tests/test_voxelmorph.py +++ b/tests/test_voxelmorph.py @@ -278,13 +278,13 @@ def test_shape(self, input_param, input_shape, expected_shape): def test_script(self): net = VoxelMorph( - spatial_dims=3, + spatial_dims=2, in_channels=2, unet_out_channels=32, channels=(16, 32, 32, 32, 32, 32), final_conv_channels=(16, 16), ).net - test_data = torch.randn(1, 2, 160, 192, 224) + test_data = torch.randn(1, 2, 160, 192) test_script_save(net, test_data) @parameterized.expand(ILL_CASES) From 38ccf33945308e27b7593f0d718cb0850fce92b2 Mon Sep 17 00:00:00 2001 From: kaibo Date: Wed, 1 Nov 2023 22:50:19 -0400 Subject: [PATCH 07/11] Updated docs and added more explanations in the code Signed-off-by: kaibo --- docs/source/networks.rst | 6 ++++++ monai/networks/nets/voxelmorph.py | 23 +++++++++++++++++------ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 7b7888732f..00271f2922 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -710,6 +710,12 @@ Nets .. autoclass:: Quicknat :members: +`VoxelMorph` +~~~~~~~~~~~~ +.. autoclass:: VoxelMorph + :members: + +.. autoclass:: voxelmorph Utilities --------- diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py index 26bba35aeb..5effc2e70d 100644 --- a/monai/networks/nets/voxelmorph.py +++ b/monai/networks/nets/voxelmorph.py @@ -30,15 +30,26 @@ @alias("voxelmorph") class VoxelMorph(nn.Module): """ - Implementation of VoxelMorph network as described in https://arxiv.org/pdf/1809.05231.pdf. + VoxelMorph network for medical image registration as described in https://arxiv.org/pdf/1809.05231.pdf. + For more details, please refer to VoxelMorph: A Learning Framework for Deformable Medical Image Registration + Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca + IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. - Overview. A pair of images (moving and fixed) are concatenated along the channel dimension and passed through + A pair of images (moving and fixed) are concatenated along the channel dimension and passed through a UNet. The output of the UNet is then passed through a series of convolution blocks to produce the final prediction of the displacement field (DDF) in the non-diffeomorphic variant (i.e. when `int_steps` is set to 0) or the stationary velocity field (DVF) in the diffeomorphic variant (i.e. when `int_steps` is set to a positive integer). The DVF is then converted to a DDF using the `DVF2DDF` module. Finally, the DDF is used to warp the moving image to the fixed image using the `Warp` module. Optionally, the integration from DVF to DDF can be performed on reduced - resolution by specifying `half_res` to be True. + resolution by specifying `half_res` to be True, in which case the output DVF from the UNet is first linearly + interpolated to half resolution before being passed to the `DVF2DDF` module. The output DDF is then linearly + interpolated again back to full resolution before being used in the `Warp` module. + + In the original implementation, downsample is achieved through maxpooling, here one has the option to use either + maxpooling or strided convolution for downsampling. The default is to use maxpooling as it is consistent with the + original implementation. Note that for upsampling, the authors of VoxelMorph used nearest neighbor interpolation + instead of transposed convolution. In this implementation, only nearest neighbor interpolation is supported in order + to be consistent with the original implementation. Args: spatial_dims: number of spatial dimensions. @@ -62,7 +73,7 @@ class VoxelMorph(nn.Module): But one can optionally use strided convolution instead (i.e. set `use_maxpool` to False). adn_ordering: ordering of activation, dropout, and normalization. Defaults to "NDA". - Examples:: + Example:: from monai.networks.nets import VoxelMorph @@ -80,8 +91,8 @@ class VoxelMorph(nn.Module): ) # A forward pass through the network would look something like this - moving = torch.randn(1, 2, 160, 192, 224) - fixed = torch.randn(1, 2, 160, 192, 224) + moving = torch.randn(1, 1, 160, 192, 224) + fixed = torch.randn(1, 1, 160, 192, 224) warped, ddf = net(moving, fixed) """ From cb91654b360327916384feb0e63950adbafda455 Mon Sep 17 00:00:00 2001 From: kaibo Date: Thu, 2 Nov 2023 09:50:15 -0400 Subject: [PATCH 08/11] Renamed int_steps to integration_steps, explained how integration works more carefully, and changed all 3d test cases to have smaller sizes Signed-off-by: kaibo --- monai/networks/nets/voxelmorph.py | 36 ++++++++-------- tests/test_voxelmorph.py | 68 +++++++++++++++---------------- 2 files changed, 53 insertions(+), 51 deletions(-) diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py index 5effc2e70d..4ba301b198 100644 --- a/monai/networks/nets/voxelmorph.py +++ b/monai/networks/nets/voxelmorph.py @@ -30,20 +30,21 @@ @alias("voxelmorph") class VoxelMorph(nn.Module): """ - VoxelMorph network for medical image registration as described in https://arxiv.org/pdf/1809.05231.pdf. + A re-implementation of VoxelMorph network for medical image registration as described in https://arxiv.org/pdf/1809.05231.pdf. For more details, please refer to VoxelMorph: A Learning Framework for Deformable Medical Image Registration Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. - A pair of images (moving and fixed) are concatenated along the channel dimension and passed through - a UNet. The output of the UNet is then passed through a series of convolution blocks to produce the final prediction - of the displacement field (DDF) in the non-diffeomorphic variant (i.e. when `int_steps` is set to 0) or the - stationary velocity field (DVF) in the diffeomorphic variant (i.e. when `int_steps` is set to a positive integer). - The DVF is then converted to a DDF using the `DVF2DDF` module. Finally, the DDF is used to warp the moving image - to the fixed image using the `Warp` module. Optionally, the integration from DVF to DDF can be performed on reduced - resolution by specifying `half_res` to be True, in which case the output DVF from the UNet is first linearly - interpolated to half resolution before being passed to the `DVF2DDF` module. The output DDF is then linearly - interpolated again back to full resolution before being used in the `Warp` module. + A pair of images (moving and fixed) are concatenated along the channel dimension and passed through a UNet. The + output of the UNet is then passed through a series of convolution blocks to produce the final prediction + of the displacement field (DDF) in the non-diffeomorphic variant (i.e. when `integration_steps` is set to 0) or the + stationary velocity field (DVF) in the diffeomorphic variant (i.e. when `integration_steps` is set to a positive + integer). The DVF is then integrated using a scaling-and-squaring approach via the `DVF2DDF` module to produce the + DDF. Finally, the DDF is used to warp the moving image to the fixed image using the `Warp` module. Optionally, the + integration from DVF to DDF can be performed on reduced resolution by specifying `half_res` to be True, in which + case the output DVF from the UNet is first linearly interpolated to half resolution before being passed to the + `DVF2DDF` module. The output DDF is then linearly interpolated again back to full resolution before being used in + the `Warp` module. In the original implementation, downsample is achieved through maxpooling, here one has the option to use either maxpooling or strided convolution for downsampling. The default is to use maxpooling as it is consistent with the @@ -58,9 +59,10 @@ class VoxelMorph(nn.Module): channels: number of channels in each layer of the UNet. See the following example for more details. final_conv_channels: number of channels in each layer of the final convolution block. final_conv_act: activation type for the final convolution block. Defaults to LeakyReLU. - Since VoxelMorph was originally implemented in tensorflow where the default negative slope for LeakyReLU - was 0.2, we use the same default value here. - int_steps: number of integration steps. Defaults to 7. If set to 0, the network will be non-diffeomorphic. + Since VoxelMorph was originally implemented in tensorflow where the default negative slope for + LeakyReLU was 0.2, we use the same default value here. + integration_steps: number of integration steps used for obtaining DDF from DVF via scaling-and-squaring. + Defaults to 7. If set to 0, the network will be non-diffeomorphic. kernel_size: kernel size for all convolution layers in the UNet. Defaults to 3. up_kernel_size: kernel size for all convolution layers in the upsampling path of the UNet. Defaults to 3. act: activation type for all convolution layers in the UNet. Defaults to LeakyReLU with negative slope 0.2. @@ -104,7 +106,7 @@ def __init__( channels: Sequence[int], final_conv_channels: Sequence[int], final_conv_act: tuple | str | None = "LEAKYRELU", - int_steps: int = 7, + integration_steps: int = 7, kernel_size: Sequence[int] | int = 3, up_kernel_size: Sequence[int] | int = 3, act: tuple | str = "LEAKYRELU", @@ -160,8 +162,8 @@ def __init__( ) # integration args - self.int_steps = int_steps - self.diffeomorphic = True if self.int_steps > 0 else False + self.integration_steps = integration_steps + self.diffeomorphic = True if self.integration_steps > 0 else False def _create_block(inc: int, outc: int, channels: Sequence[int], is_top: bool) -> nn.Module: """ @@ -245,7 +247,7 @@ def _create_final_conv(inc: int, outc: int, channels: Sequence[int]) -> nn.Modul # create helpers if self.diffeomorphic: - self.dvf2ddf = DVF2DDF(num_steps=self.int_steps, mode="bilinear", padding_mode="zeros") + self.dvf2ddf = DVF2DDF(num_steps=self.integration_steps, mode="bilinear", padding_mode="zeros") self.warp = Warp(mode="bilinear", padding_mode="zeros") def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module: diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py index 830ff79edc..b88f25ab75 100644 --- a/tests/test_voxelmorph.py +++ b/tests/test_voxelmorph.py @@ -31,10 +31,10 @@ "unet_out_channels": 32, "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), - "int_steps": 0, + "integration_steps": 0, }, - ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), - ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), + ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), + ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), ] TEST_CASE_1 = [ # single channel 3D, batch 1, diffeomorphic (default) @@ -45,8 +45,8 @@ "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), }, - ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), - ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), + ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), + ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), ] TEST_CASE_2 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution @@ -56,11 +56,11 @@ "unet_out_channels": 32, "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), - "int_steps": 7, + "integration_steps": 7, "half_res": True, }, - ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), - ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), + ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), + ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), ] TEST_CASE_3 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution, @@ -71,12 +71,12 @@ "unet_out_channels": 32, "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), - "int_steps": 7, + "integration_steps": 7, "half_res": True, "use_maxpool": False, }, - ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), - ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), + ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), + ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), ] TEST_CASE_4 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution, @@ -89,12 +89,12 @@ "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), "final_conv_act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), - "int_steps": 7, + "integration_steps": 7, "half_res": True, "use_maxpool": False, }, - ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), - ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), + ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), + ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), ] TEST_CASE_5 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution, @@ -108,12 +108,12 @@ "final_conv_channels": (16, 16), "final_conv_act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), "act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), - "int_steps": 7, + "integration_steps": 7, "half_res": True, "use_maxpool": False, }, - ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), - ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), + ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), + ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), ] TEST_CASE_6 = [ # 2-channel 3D, batch 1, diffeomorphic @@ -125,8 +125,8 @@ "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), }, - ((1, 2, 160, 192, 224), (1, 2, 160, 192, 224)), - ((1, 2, 160, 192, 224), (1, 3, 160, 192, 224)), + ((1, 2, 96, 96, 48), (1, 2, 96, 96, 48)), + ((1, 2, 96, 96, 48), (1, 3, 96, 96, 48)), ] TEST_CASE_7 = [ # single channel 3D, batch 2, diffeomorphic @@ -137,8 +137,8 @@ "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), }, - ((2, 1, 160, 192, 224), (2, 1, 160, 192, 224)), - ((2, 1, 160, 192, 224), (2, 3, 160, 192, 224)), + ((2, 1, 96, 96, 48), (2, 1, 96, 96, 48)), + ((2, 1, 96, 96, 48), (2, 3, 96, 96, 48)), ] TEST_CASE_8 = [ # single channel 2D, batch 2, diffeomorphic @@ -149,8 +149,8 @@ "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), }, - ((2, 1, 160, 192), (2, 1, 160, 192)), - ((2, 1, 160, 192), (2, 2, 160, 192)), + ((2, 1, 96, 96), (2, 1, 96, 96)), + ((2, 1, 96, 96), (2, 2, 96, 96)), ] TEST_CASE_9 = [ # single channel 3D, batch 2, diffeomorphic, @@ -162,8 +162,8 @@ "channels": (16, 32, 32, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), }, - ((2, 1, 160, 192, 224), (2, 1, 160, 192, 224)), - ((2, 1, 160, 192, 224), (2, 3, 160, 192, 224)), + ((2, 1, 96, 96, 48), (2, 1, 96, 96, 48)), + ((2, 1, 96, 96, 48), (2, 3, 96, 96, 48)), ] TEST_CASE_10 = [ # single channel 3D, batch 2, diffeomorphic, @@ -176,8 +176,8 @@ "channels": (16, 32, 32, 32, 32, 32, 32, 32), "final_conv_channels": (16,), }, - ((2, 1, 160, 192, 224), (2, 1, 160, 192, 224)), - ((2, 1, 160, 192, 224), (2, 3, 160, 192, 224)), + ((2, 1, 96, 96, 48), (2, 1, 96, 96, 48)), + ((2, 1, 96, 96, 48), (2, 3, 96, 96, 48)), ] TEST_CASE_11 = [ # single channel 3D, batch 1, diffeomorphic, @@ -189,8 +189,8 @@ "channels": (16, 32), "final_conv_channels": (16, 16), }, - ((1, 1, 160, 192, 224), (1, 1, 160, 192, 224)), - ((1, 1, 160, 192, 224), (1, 3, 160, 192, 224)), + ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), + ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), ] CASES = [ @@ -201,10 +201,10 @@ TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, - # TEST_CASE_7, + TEST_CASE_7, TEST_CASE_8, - # TEST_CASE_9, - # TEST_CASE_10, + TEST_CASE_9, + TEST_CASE_10, TEST_CASE_11, ] @@ -278,13 +278,13 @@ def test_shape(self, input_param, input_shape, expected_shape): def test_script(self): net = VoxelMorph( - spatial_dims=2, + spatial_dims=3, in_channels=2, unet_out_channels=32, channels=(16, 32, 32, 32, 32, 32), final_conv_channels=(16, 16), ).net - test_data = torch.randn(1, 2, 160, 192) + test_data = torch.randn(1, 2, 96, 96, 48) test_script_save(net, test_data) @parameterized.expand(ILL_CASES) From 37d4b453fbe9c9569177e1ecc2d5b7195eedf8f1 Mon Sep 17 00:00:00 2001 From: kaibo Date: Thu, 2 Nov 2023 16:47:03 -0400 Subject: [PATCH 09/11] Decoupled voxelmorph (the framework) from voxelmorph (the unet subnetwork), rewrote some of the docstrings, and modified the unit tests accordingly. Specifically, in the voxelmorph framework class, made a few assertions to make sure data line up well with the specified backbone, and added corresponding illegal cases in unit test. Signed-off-by: kaibo --- docs/source/networks.rst | 6 + monai/networks/nets/__init__.py | 2 +- monai/networks/nets/voxelmorph.py | 204 +++++++++++++++++++++--------- tests/test_voxelmorph.py | 138 +++++++++----------- 4 files changed, 213 insertions(+), 137 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 00271f2922..8eada7933f 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -712,6 +712,12 @@ Nets `VoxelMorph` ~~~~~~~~~~~~ +.. autoclass:: VoxelMorphUNet + :members: + +.. autoclass:: voxelmorphunet + :members: + .. autoclass:: VoxelMorph :members: diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 8064b815a3..9247aaee85 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -111,4 +111,4 @@ from .vit import ViT from .vitautoenc import ViTAutoEnc from .vnet import VNet -from .voxelmorph import VoxelMorph +from .voxelmorph import VoxelMorph, VoxelMorphUNet diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py index 4ba301b198..86ec41af82 100644 --- a/monai/networks/nets/voxelmorph.py +++ b/monai/networks/nets/voxelmorph.py @@ -23,28 +23,18 @@ from monai.networks.layers.simplelayers import SkipConnection from monai.utils import alias, export -__all__ = ["VoxelMorph", "voxelmorph"] +__all__ = ["VoxelMorphUNet", "voxelmorphunet", "VoxelMorph", "voxelmorph"] @export("monai.networks.nets") -@alias("voxelmorph") -class VoxelMorph(nn.Module): +@alias("voxelmorphunet") +class VoxelMorphUNet(nn.Module): """ - A re-implementation of VoxelMorph network for medical image registration as described in https://arxiv.org/pdf/1809.05231.pdf. - For more details, please refer to VoxelMorph: A Learning Framework for Deformable Medical Image Registration - Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca - IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. + The backbone network used in VoxelMorph. See :py:class:`monai.networks.nets.VoxelMorph` for more details. - A pair of images (moving and fixed) are concatenated along the channel dimension and passed through a UNet. The - output of the UNet is then passed through a series of convolution blocks to produce the final prediction - of the displacement field (DDF) in the non-diffeomorphic variant (i.e. when `integration_steps` is set to 0) or the - stationary velocity field (DVF) in the diffeomorphic variant (i.e. when `integration_steps` is set to a positive - integer). The DVF is then integrated using a scaling-and-squaring approach via the `DVF2DDF` module to produce the - DDF. Finally, the DDF is used to warp the moving image to the fixed image using the `Warp` module. Optionally, the - integration from DVF to DDF can be performed on reduced resolution by specifying `half_res` to be True, in which - case the output DVF from the UNet is first linearly interpolated to half resolution before being passed to the - `DVF2DDF` module. The output DDF is then linearly interpolated again back to full resolution before being used in - the `Warp` module. + A concatenated pair of images (moving and fixed) is first passed through a UNet. The output of the UNet is then + passed through a series of convolution blocks to produce the final prediction of the displacement field (DDF) or the + stationary velocity field (DVF). In the original implementation, downsample is achieved through maxpooling, here one has the option to use either maxpooling or strided convolution for downsampling. The default is to use maxpooling as it is consistent with the @@ -52,6 +42,10 @@ class VoxelMorph(nn.Module): instead of transposed convolution. In this implementation, only nearest neighbor interpolation is supported in order to be consistent with the original implementation. + An instance of this class can be used as a backbone network for constructing a VoxelMorph network. See the + documentation of :py:class:`monai.networks.nets.VoxelMorph` for more details and an example on how to construct a + VoxelMorph network. + Args: spatial_dims: number of spatial dimensions. in_channels: number of channels in the input volume after concatenation of moving and fixed images. @@ -61,41 +55,16 @@ class VoxelMorph(nn.Module): final_conv_act: activation type for the final convolution block. Defaults to LeakyReLU. Since VoxelMorph was originally implemented in tensorflow where the default negative slope for LeakyReLU was 0.2, we use the same default value here. - integration_steps: number of integration steps used for obtaining DDF from DVF via scaling-and-squaring. - Defaults to 7. If set to 0, the network will be non-diffeomorphic. kernel_size: kernel size for all convolution layers in the UNet. Defaults to 3. up_kernel_size: kernel size for all convolution layers in the upsampling path of the UNet. Defaults to 3. act: activation type for all convolution layers in the UNet. Defaults to LeakyReLU with negative slope 0.2. norm: feature normalization type and arguments for all convolution layers in the UNet. Defaults to None. dropout: dropout ratio for all convolution layers in the UNet. Defaults to 0.0 (no dropout). bias: whether to use bias in all convolution layers in the UNet. Defaults to True. - half_res: whether to perform integration on half resolution. Defaults to False. use_maxpool: whether to use maxpooling in the downsampling path of the UNet. Defaults to True. Using maxpooling is the consistent with the original implementation of VoxelMorph. But one can optionally use strided convolution instead (i.e. set `use_maxpool` to False). adn_ordering: ordering of activation, dropout, and normalization. Defaults to "NDA". - - Example:: - - from monai.networks.nets import VoxelMorph - - # VoxelMorph network as it is in the original paper https://arxiv.org/pdf/1809.05231.pdf - net = VoxelMorph( - spatial_dims=3, - in_channels=2, - unet_out_channels=32, - channels=(16, 32, 32, 32, 32, 32), # this indicates the down block at the top takes 16 channels as - # input, the corresponding up block at the top produces 32 - # channels as output, the second down block takes 32 channels as - # input, and the corresponding up block at the same level - # produces 32 channels as output, etc. - final_conv_channels=(16, 16) - ) - - # A forward pass through the network would look something like this - moving = torch.randn(1, 1, 160, 192, 224) - fixed = torch.randn(1, 1, 160, 192, 224) - warped, ddf = net(moving, fixed) """ def __init__( @@ -106,14 +75,12 @@ def __init__( channels: Sequence[int], final_conv_channels: Sequence[int], final_conv_act: tuple | str | None = "LEAKYRELU", - integration_steps: int = 7, kernel_size: Sequence[int] | int = 3, up_kernel_size: Sequence[int] | int = 3, act: tuple | str = "LEAKYRELU", norm: tuple | str | None = None, dropout: float = 0.0, bias: bool = True, - half_res: bool = False, use_maxpool: bool = True, adn_ordering: str = "NDA", ) -> None: @@ -135,6 +102,7 @@ def __init__( # UNet args self.dimensions = spatial_dims self.in_channels = in_channels + self.unet_out_channels = unet_out_channels self.channels = channels self.kernel_size = kernel_size self.up_kernel_size = up_kernel_size @@ -146,12 +114,8 @@ def __init__( self.norm = norm self.dropout = dropout self.bias = bias - self.adn_ordering = adn_ordering - - # VoxelMorph specific args - self.unet_out_channels = unet_out_channels - self.half_res = half_res self.use_maxpool = use_maxpool + self.adn_ordering = adn_ordering # final convolutions args self.final_conv_channels = final_conv_channels @@ -161,10 +125,6 @@ def __init__( else final_conv_act ) - # integration args - self.integration_steps = integration_steps - self.diffeomorphic = True if self.integration_steps > 0 else False - def _create_block(inc: int, outc: int, channels: Sequence[int], is_top: bool) -> nn.Module: """ Builds the UNet structure recursively. @@ -245,11 +205,6 @@ def _create_final_conv(inc: int, outc: int, channels: Sequence[int]) -> nn.Modul _create_final_conv(unet_out_channels, self.dimensions, self.final_conv_channels), ) - # create helpers - if self.diffeomorphic: - self.dvf2ddf = DVF2DDF(num_steps=self.integration_steps, mode="bilinear", padding_mode="zeros") - self.warp = Warp(mode="bilinear", padding_mode="zeros") - def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module: """ Returns the block object defining a layer of the UNet structure including the implementation of the skip @@ -376,8 +331,139 @@ def _get_up_layer(self, in_channels: int, out_channels: int, is_top: bool) -> nn return mod + def forward(self, concatenated_pairs: torch.Tensor) -> torch.Tensor: + x = self.net(concatenated_pairs) + assert isinstance(x, torch.Tensor) # won't pass mypy check without this line + return x + + +voxelmorphunet = VoxelMorphUNet + + +@export("monai.networks.nets") +@alias("voxelmorph") +class VoxelMorph(nn.Module): + """ + A re-implementation of VoxelMorph framework for medical image registration as described in + https://arxiv.org/pdf/1809.05231.pdf. For more details, please refer to VoxelMorph: A Learning Framework for + Deformable Medical Image Registration, Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca + IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. + + This class is intended to be a general framework, based on which a deformable image registration + network can be built. Given a user-specified backbone network (e.g., UNet in the original VoxelMorph paper), this + class serves as a wrapper that concatenates the input pair of moving and fixed images, passes through the backbone + network, integrate the predicted stationary velocity field (DVF) from the backbone network to obtain the + displacement field (DDF), and, finally, warp the moving image using the DDF. + + To construct a VoxelMorph network, one need to first construct a backbone network + (e.g., a :py:class:`monai.networks.nets.VoxelMorphUNet`) and pass it to the constructor of + :py:class:`monai.networks.nets.VoxelMorph`. The backbone network should be able to take a pair of moving and fixed + images as input and produce a DVF (or DDF, details to be discussed later) as output. + + When `forward` is called, the input moving and fixed images are first concatenated along the channel dimension and + passed through the specified backbone network to produce the prediction of the displacement field (DDF) in the + non-diffeomorphic variant (i.e. when `integration_steps` is set to 0) or the stationary velocity field (DVF) in the + diffeomorphic variant (i.e. when `integration_steps` is set to a positive integer). The DVF is then integrated using + a scaling-and-squaring approach via a :py:class:`monai.networks.blocks.warp.DVF2DDF` module to produce the DDF. + Finally, the DDF is used to warp the moving image to the fixed image using a + :py:class:`monai.networks.blocks.warp.Warp` module. Optionally, the integration from DVF to DDF can be + performed on reduced resolution by specifying `half_res` to be True, in which case the output DVF from the backbone + network is first linearly interpolated to half resolution before integration. The output DDF is then linearly + interpolated again back to full resolution before being used to warp the moving image. + + Args: + spatial_dims: number of spatial dimensions. + backbone: a backbone network. + integration_steps: number of integration steps used for obtaining DDF from DVF via scaling-and-squaring. + Defaults to 7. If set to 0, the network will be non-diffeomorphic. + half_res: whether to perform integration on half resolution. Defaults to False. + + Example:: + + from monai.networks.nets import VoxelMorphUNet, VoxelMorph + + # The following example construct an instance of VoxelMorph that matches the original VoxelMorph paper + # https://arxiv.org/pdf/1809.05231.pdf + + # First, a backbone network is constructed. In this case, we use a VoxelMorphUNet as the backbone network. + backbone = VoxelMorphUNet( + spatial_dims=3, + in_channels=2, + unet_out_channels=32, + channels=(16, 32, 32, 32, 32, 32), # this indicates the down block at the top takes 16 channels as + # input, the corresponding up block at the top produces 32 + # channels as output, the second down block takes 32 channels as + # input, and the corresponding up block at the same level + # produces 32 channels as output, etc. + final_conv_channels=(16, 16) + ) + + # Then, a full VoxelMorph network is constructed using the specified backbone network. + net = VoxelMorph( + backbone=backbone, + integration_steps=7, + half_res=False + ) + + # A forward pass through the network would look something like this + moving = torch.randn(1, 1, 160, 192, 224) + fixed = torch.randn(1, 1, 160, 192, 224) + warped, ddf = net(moving, fixed) + """ + + def __init__( + self, + spatial_dims: int, + backbone: VoxelMorphUNet | nn.Module | None = None, + integration_steps: int = 7, + half_res: bool = False, + ) -> None: + super().__init__() + + # specified backbone network + self.backbone = ( + backbone + if backbone is not None + else VoxelMorphUNet( + spatial_dims=3, + in_channels=2, + unet_out_channels=32, + channels=(16, 32, 32, 32, 32, 32), + final_conv_channels=(16, 16), + ) + ) + + # helper attributes + self.spatial_dims = spatial_dims + self.half_res = half_res + self.integration_steps = integration_steps + self.diffeomorphic = True if self.integration_steps > 0 else False + + # create helpers + if self.diffeomorphic: + self.dvf2ddf = DVF2DDF(num_steps=self.integration_steps, mode="bilinear", padding_mode="zeros") + self.warp = Warp(mode="bilinear", padding_mode="zeros") + def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - x = self.net(torch.cat([moving, fixed], dim=1)) + if moving.shape != fixed.shape: + raise ValueError( + f"The spatial shape of the moving image should be the same as the spatial shape of the fixed image." + f" Got {moving.shape} and {fixed.shape} instead." + ) + + x = self.backbone(torch.cat([moving, fixed], dim=1)) + + if x.shape[1] != self.spatial_dims: + raise ValueError( + f"The number of channels in the output of the backbone network should be equal to the" + f" number of spatial dimensions. Got {x.shape[1]} channels instead." + ) + + if x.shape[2:] != moving.shape[2:]: + raise ValueError( + "The spatial shape of the output of the backbone network should be equal to the" + f" spatial shape of the input images. Got {x.shape[2:]} instead of {moving.shape[2:]}." + ) if self.half_res: x = F.interpolate(x, scale_factor=0.5, mode="trilinear", align_corners=True) * 2.0 diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py index b88f25ab75..ce9a776032 100644 --- a/tests/test_voxelmorph.py +++ b/tests/test_voxelmorph.py @@ -17,53 +17,24 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets import VoxelMorph +from monai.networks.nets import VoxelMorph, VoxelMorphUNet from tests.utils import test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASE_0 = [ # single channel 3D, batch 1, non-diffeomorphic - # i.e., VoxelMorph as it is in the original paper - # https://arxiv.org/pdf/1809.05231.pdf +TEST_CASE_0 = [ # single channel 3D, batch 1, { "spatial_dims": 3, "in_channels": 2, "unet_out_channels": 32, "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), - "integration_steps": 0, }, - ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), - ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), ] -TEST_CASE_1 = [ # single channel 3D, batch 1, diffeomorphic (default) - { - "spatial_dims": 3, - "in_channels": 2, - "unet_out_channels": 32, - "channels": (16, 32, 32, 32, 32, 32), - "final_conv_channels": (16, 16), - }, - ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), - ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), -] - -TEST_CASE_2 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution - { - "spatial_dims": 3, - "in_channels": 2, - "unet_out_channels": 32, - "channels": (16, 32, 32, 32, 32, 32), - "final_conv_channels": (16, 16), - "integration_steps": 7, - "half_res": True, - }, - ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), - ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), -] - -TEST_CASE_3 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution, +TEST_CASE_1 = [ # single channel 3D, batch 1, # using strided convolution for downsampling instead of maxpooling { "spatial_dims": 3, @@ -71,15 +42,13 @@ "unet_out_channels": 32, "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), - "integration_steps": 7, - "half_res": True, "use_maxpool": False, }, - ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), - ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), ] -TEST_CASE_4 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution, +TEST_CASE_2 = [ # single channel 3D, batch 1, # using strided convolution for downsampling instead of maxpooling, # explicitly specify leakyrelu with a different negative slope for final convolutions { @@ -89,15 +58,13 @@ "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), "final_conv_act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), - "integration_steps": 7, - "half_res": True, "use_maxpool": False, }, - ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), - ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), ] -TEST_CASE_5 = [ # single channel 3D, batch 1, diffeomorphic, integration at half resolution, +TEST_CASE_3 = [ # single channel 3D, batch 1, # using strided convolution for downsampling instead of maxpooling, # explicitly specify leakyrelu with a different negative slope for both unet and final convolutions. { @@ -108,15 +75,13 @@ "final_conv_channels": (16, 16), "final_conv_act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), "act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), - "integration_steps": 7, - "half_res": True, "use_maxpool": False, }, - ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), - ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), ] -TEST_CASE_6 = [ # 2-channel 3D, batch 1, diffeomorphic +TEST_CASE_4 = [ # 2-channel 3D, batch 1, # i.e., possible use case where the input contains both modalities (e.g., T1 and T2) { "spatial_dims": 3, @@ -125,11 +90,11 @@ "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), }, - ((1, 2, 96, 96, 48), (1, 2, 96, 96, 48)), - ((1, 2, 96, 96, 48), (1, 3, 96, 96, 48)), + (1, 4, 96, 96, 48), + (1, 3, 96, 96, 48), ] -TEST_CASE_7 = [ # single channel 3D, batch 2, diffeomorphic +TEST_CASE_5 = [ # single channel 3D, batch 2, { "spatial_dims": 3, "in_channels": 2, @@ -137,11 +102,11 @@ "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), }, - ((2, 1, 96, 96, 48), (2, 1, 96, 96, 48)), - ((2, 1, 96, 96, 48), (2, 3, 96, 96, 48)), + (2, 2, 96, 96, 48), + (2, 3, 96, 96, 48), ] -TEST_CASE_8 = [ # single channel 2D, batch 2, diffeomorphic +TEST_CASE_6 = [ # single channel 2D, batch 2, { "spatial_dims": 2, "in_channels": 2, @@ -149,11 +114,11 @@ "channels": (16, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), }, - ((2, 1, 96, 96), (2, 1, 96, 96)), - ((2, 1, 96, 96), (2, 2, 96, 96)), + (2, 2, 96, 96), + (2, 2, 96, 96), ] -TEST_CASE_9 = [ # single channel 3D, batch 2, diffeomorphic, +TEST_CASE_7 = [ # single channel 3D, batch 1, # one additional level in the UNet with 32 channels in both down and up branch. { "spatial_dims": 3, @@ -162,11 +127,11 @@ "channels": (16, 32, 32, 32, 32, 32, 32, 32), "final_conv_channels": (16, 16), }, - ((2, 1, 96, 96, 48), (2, 1, 96, 96, 48)), - ((2, 1, 96, 96, 48), (2, 3, 96, 96, 48)), + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), ] -TEST_CASE_10 = [ # single channel 3D, batch 2, diffeomorphic, +TEST_CASE_8 = [ # single channel 3D, batch 1, # one additional level in the UNet with 32 channels in both down and up branch. # and removed one of the two final convolution blocks. { @@ -176,11 +141,11 @@ "channels": (16, 32, 32, 32, 32, 32, 32, 32), "final_conv_channels": (16,), }, - ((2, 1, 96, 96, 48), (2, 1, 96, 96, 48)), - ((2, 1, 96, 96, 48), (2, 3, 96, 96, 48)), + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), ] -TEST_CASE_11 = [ # single channel 3D, batch 1, diffeomorphic, +TEST_CASE_9 = [ # single channel 3D, batch 1, # only one level in the UNet { "spatial_dims": 3, @@ -189,8 +154,8 @@ "channels": (16, 32), "final_conv_channels": (16, 16), }, - ((1, 1, 96, 96, 48), (1, 1, 96, 96, 48)), - ((1, 1, 96, 96, 48), (1, 3, 96, 96, 48)), + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), ] CASES = [ @@ -204,8 +169,6 @@ TEST_CASE_7, TEST_CASE_8, TEST_CASE_9, - TEST_CASE_10, - TEST_CASE_11, ] ILL_CASE_0 = [ # spatial_dims = 1 @@ -267,30 +230,51 @@ ILL_CASES = [ILL_CASE_0, ILL_CASE_1, ILL_CASE_2, ILL_CASE_3, ILL_CASE_4, ILL_CASE_5] +ILL_CASES_IN_SHAPE_0 = [ # moving and fixed image shape not match + {"spatial_dims": 3}, + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), +] + +ILL_CASES_IN_SHAPE_1 = [ # spatial_dims = 2, ddf has 3 channels + {"spatial_dims": 2}, + (1, 1, 96, 96, 96), + (1, 1, 96, 96, 96), +] + +ILL_CASES_IN_SHAPE = [ILL_CASES_IN_SHAPE_0, ILL_CASES_IN_SHAPE_1] + + class TestVOXELMORPH(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): - net = VoxelMorph(**input_param).to(device) + net = VoxelMorphUNet(**input_param).to(device) with eval_mode(net): - result = net.forward(torch.randn(input_shape[0]).to(device), torch.randn(input_shape[1]).to(device)) - self.assertEqual(result[0].shape, expected_shape[0]) - self.assertEqual(result[1].shape, expected_shape[1]) + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) def test_script(self): - net = VoxelMorph( - spatial_dims=3, + net = VoxelMorphUNet( + spatial_dims=2, in_channels=2, unet_out_channels=32, channels=(16, 32, 32, 32, 32, 32), final_conv_channels=(16, 16), - ).net - test_data = torch.randn(1, 2, 96, 96, 48) + ) + test_data = torch.randn(1, 2, 96, 96) test_script_save(net, test_data) @parameterized.expand(ILL_CASES) def test_ill_input_hyper_params(self, input_param): with self.assertRaises(ValueError): - _ = VoxelMorph(**input_param) + _ = VoxelMorphUNet(**input_param) + + @parameterized.expand(ILL_CASES_IN_SHAPE) + def test_ill_input_shape(self, input_param, moving_shape, fixed_shape): + with self.assertRaises(ValueError): + net = VoxelMorph(**input_param).to(device) + with eval_mode(net): + _ = net.forward(torch.randn(moving_shape).to(device), torch.randn(fixed_shape).to(device)) if __name__ == "__main__": From 6cd223b8fb7f7b7d1cc09c2358b263091185ac7a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 3 Nov 2023 08:59:04 +0000 Subject: [PATCH 10/11] fixes docstring/type format Signed-off-by: Wenqi Li --- monai/networks/nets/voxelmorph.py | 151 +++++++++++++++--------------- 1 file changed, 76 insertions(+), 75 deletions(-) diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py index 86ec41af82..0496cfc8f8 100644 --- a/monai/networks/nets/voxelmorph.py +++ b/monai/networks/nets/voxelmorph.py @@ -47,24 +47,24 @@ class VoxelMorphUNet(nn.Module): VoxelMorph network. Args: - spatial_dims: number of spatial dimensions. - in_channels: number of channels in the input volume after concatenation of moving and fixed images. - unet_out_channels: number of channels in the output of the UNet. - channels: number of channels in each layer of the UNet. See the following example for more details. - final_conv_channels: number of channels in each layer of the final convolution block. - final_conv_act: activation type for the final convolution block. Defaults to LeakyReLU. - Since VoxelMorph was originally implemented in tensorflow where the default negative slope for - LeakyReLU was 0.2, we use the same default value here. - kernel_size: kernel size for all convolution layers in the UNet. Defaults to 3. - up_kernel_size: kernel size for all convolution layers in the upsampling path of the UNet. Defaults to 3. - act: activation type for all convolution layers in the UNet. Defaults to LeakyReLU with negative slope 0.2. - norm: feature normalization type and arguments for all convolution layers in the UNet. Defaults to None. - dropout: dropout ratio for all convolution layers in the UNet. Defaults to 0.0 (no dropout). - bias: whether to use bias in all convolution layers in the UNet. Defaults to True. - use_maxpool: whether to use maxpooling in the downsampling path of the UNet. Defaults to True. - Using maxpooling is the consistent with the original implementation of VoxelMorph. - But one can optionally use strided convolution instead (i.e. set `use_maxpool` to False). - adn_ordering: ordering of activation, dropout, and normalization. Defaults to "NDA". + spatial_dims: number of spatial dimensions. + in_channels: number of channels in the input volume after concatenation of moving and fixed images. + unet_out_channels: number of channels in the output of the UNet. + channels: number of channels in each layer of the UNet. See the following example for more details. + final_conv_channels: number of channels in each layer of the final convolution block. + final_conv_act: activation type for the final convolution block. Defaults to LeakyReLU. + Since VoxelMorph was originally implemented in tensorflow where the default negative slope for + LeakyReLU was 0.2, we use the same default value here. + kernel_size: kernel size for all convolution layers in the UNet. Defaults to 3. + up_kernel_size: kernel size for all convolution layers in the upsampling path of the UNet. Defaults to 3. + act: activation type for all convolution layers in the UNet. Defaults to LeakyReLU with negative slope 0.2. + norm: feature normalization type and arguments for all convolution layers in the UNet. Defaults to None. + dropout: dropout ratio for all convolution layers in the UNet. Defaults to 0.0 (no dropout). + bias: whether to use bias in all convolution layers in the UNet. Defaults to True. + use_maxpool: whether to use maxpooling in the downsampling path of the UNet. Defaults to True. + Using maxpooling is the consistent with the original implementation of VoxelMorph. + But one can optionally use strided convolution instead (i.e. set `use_maxpool` to False). + adn_ordering: ordering of activation, dropout, and normalization. Defaults to "NDA". """ def __init__( @@ -130,10 +130,10 @@ def _create_block(inc: int, outc: int, channels: Sequence[int], is_top: bool) -> Builds the UNet structure recursively. Args: - inc: number of input channels. - outc: number of output channels. - channels: sequence of channels for each pair of down and up layers. - is_top: True if this is the top block. + inc: number of input channels. + outc: number of output channels. + channels: sequence of channels for each pair of down and up layers. + is_top: True if this is the top block. """ next_c_in, next_c_out = channels[0:2] @@ -157,9 +157,9 @@ def _create_final_conv(inc: int, outc: int, channels: Sequence[int]) -> nn.Modul Builds the final convolution blocks. Args: - inc: number of input channels, should be the same as `unet_out_channels`. - outc: number of output channels, should be the same as `spatial_dims`. - channels: sequence of channels for each convolution layer. + inc: number of input channels, should be the same as `unet_out_channels`. + outc: number of output channels, should be the same as `spatial_dims`. + channels: sequence of channels for each convolution layer. Note: there is no activation after the last convolution layer as per the original implementation. """ @@ -211,9 +211,10 @@ def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblo between encoding (down) and decoding (up) sides of the network. Args: - down_path: encoding half of the layer - up_path: decoding half of the layer - subblock: block defining the next layer in the network. + down_path: encoding half of the layer + up_path: decoding half of the layer + subblock: block defining the next layer in the network. + Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)` """ @@ -227,9 +228,9 @@ def _get_down_layer(self, in_channels: int, out_channels: int, is_top: bool) -> without maxpooling first. Args: - in_channels: number of input channels. - out_channels: number of output channels. - is_top: True if this is the top block. + in_channels: number of input channels. + out_channels: number of output channels. + is_top: True if this is the top block. """ mod: Convolution | nn.Sequential @@ -263,8 +264,8 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: Bottom layer (bottleneck) in voxelmorph consists of a typical down layer followed by an upsample layer. Args: - in_channels: number of input channels. - out_channels: number of output channels. + in_channels: number of input channels. + out_channels: number of output channels. """ mod: nn.Module @@ -291,9 +292,9 @@ def _get_up_layer(self, in_channels: int, out_channels: int, is_top: bool) -> nn without upsampling. Args: - in_channels: number of input channels. - out_channels: number of output channels. - is_top: True if this is the top block. + in_channels: number of input channels. + out_channels: number of output channels. + is_top: True if this is the top block. """ mod: Convolution | nn.Sequential @@ -333,8 +334,7 @@ def _get_up_layer(self, in_channels: int, out_channels: int, is_top: bool) -> nn def forward(self, concatenated_pairs: torch.Tensor) -> torch.Tensor: x = self.net(concatenated_pairs) - assert isinstance(x, torch.Tensor) # won't pass mypy check without this line - return x + return x # type: ignore voxelmorphunet = VoxelMorphUNet @@ -372,51 +372,52 @@ class serves as a wrapper that concatenates the input pair of moving and fixed i interpolated again back to full resolution before being used to warp the moving image. Args: - spatial_dims: number of spatial dimensions. - backbone: a backbone network. - integration_steps: number of integration steps used for obtaining DDF from DVF via scaling-and-squaring. - Defaults to 7. If set to 0, the network will be non-diffeomorphic. - half_res: whether to perform integration on half resolution. Defaults to False. + backbone: a backbone network. + integration_steps: number of integration steps used for obtaining DDF from DVF via scaling-and-squaring. + Defaults to 7. If set to 0, the network will be non-diffeomorphic. + half_res: whether to perform integration on half resolution. Defaults to False. + spatial_dims: number of spatial dimensions, defaults to 3. Example:: - from monai.networks.nets import VoxelMorphUNet, VoxelMorph - - # The following example construct an instance of VoxelMorph that matches the original VoxelMorph paper - # https://arxiv.org/pdf/1809.05231.pdf - - # First, a backbone network is constructed. In this case, we use a VoxelMorphUNet as the backbone network. - backbone = VoxelMorphUNet( - spatial_dims=3, - in_channels=2, - unet_out_channels=32, - channels=(16, 32, 32, 32, 32, 32), # this indicates the down block at the top takes 16 channels as - # input, the corresponding up block at the top produces 32 - # channels as output, the second down block takes 32 channels as - # input, and the corresponding up block at the same level - # produces 32 channels as output, etc. - final_conv_channels=(16, 16) - ) + from monai.networks.nets import VoxelMorphUNet, VoxelMorph + + # The following example construct an instance of VoxelMorph that matches the original VoxelMorph paper + # https://arxiv.org/pdf/1809.05231.pdf + + # First, a backbone network is constructed. In this case, we use a VoxelMorphUNet as the backbone network. + backbone = VoxelMorphUNet( + spatial_dims=3, + in_channels=2, + unet_out_channels=32, + channels=(16, 32, 32, 32, 32, 32), # this indicates the down block at the top takes 16 channels as + # input, the corresponding up block at the top produces 32 + # channels as output, the second down block takes 32 channels as + # input, and the corresponding up block at the same level + # produces 32 channels as output, etc. + final_conv_channels=(16, 16) + ) - # Then, a full VoxelMorph network is constructed using the specified backbone network. - net = VoxelMorph( - backbone=backbone, - integration_steps=7, - half_res=False - ) + # Then, a full VoxelMorph network is constructed using the specified backbone network. + net = VoxelMorph( + backbone=backbone, + integration_steps=7, + half_res=False + ) + + # A forward pass through the network would look something like this + moving = torch.randn(1, 1, 160, 192, 224) + fixed = torch.randn(1, 1, 160, 192, 224) + warped, ddf = net(moving, fixed) - # A forward pass through the network would look something like this - moving = torch.randn(1, 1, 160, 192, 224) - fixed = torch.randn(1, 1, 160, 192, 224) - warped, ddf = net(moving, fixed) """ def __init__( self, - spatial_dims: int, backbone: VoxelMorphUNet | nn.Module | None = None, integration_steps: int = 7, half_res: bool = False, + spatial_dims: int = 3, ) -> None: super().__init__() @@ -425,7 +426,7 @@ def __init__( backbone if backbone is not None else VoxelMorphUNet( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=2, unet_out_channels=32, channels=(16, 32, 32, 32, 32, 32), @@ -447,7 +448,7 @@ def __init__( def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if moving.shape != fixed.shape: raise ValueError( - f"The spatial shape of the moving image should be the same as the spatial shape of the fixed image." + "The spatial shape of the moving image should be the same as the spatial shape of the fixed image." f" Got {moving.shape} and {fixed.shape} instead." ) @@ -455,8 +456,8 @@ def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tens if x.shape[1] != self.spatial_dims: raise ValueError( - f"The number of channels in the output of the backbone network should be equal to the" - f" number of spatial dimensions. Got {x.shape[1]} channels instead." + "The number of channels in the output of the backbone network should be equal to the" + f" number of spatial dimensions {self.spatial_dims}. Got {x.shape[1]} channels instead." ) if x.shape[2:] != moving.shape[2:]: From ab9f595cd4d45ec28464ddd6f3329928fcf4d97f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 3 Nov 2023 10:11:45 +0000 Subject: [PATCH 11/11] update test case Signed-off-by: Wenqi Li --- tests/test_voxelmorph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py index ce9a776032..c51f70cbf5 100644 --- a/tests/test_voxelmorph.py +++ b/tests/test_voxelmorph.py @@ -271,7 +271,7 @@ def test_ill_input_hyper_params(self, input_param): @parameterized.expand(ILL_CASES_IN_SHAPE) def test_ill_input_shape(self, input_param, moving_shape, fixed_shape): - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, RuntimeError)): net = VoxelMorph(**input_param).to(device) with eval_mode(net): _ = net.forward(torch.randn(moving_shape).to(device), torch.randn(fixed_shape).to(device))