Skip to content

Commit c040d85

Browse files
committed
removed from transform functions
1 parent 960c59b commit c040d85

File tree

2 files changed

+4
-16
lines changed

2 files changed

+4
-16
lines changed

monai/transforms/utility/array.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666
optional_import,
6767
)
6868
from monai.utils.enums import TransformBackends
69-
from monai.utils.misc import is_module_ver_at_least
7069
from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype
7170

7271
PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
@@ -939,19 +938,10 @@ def __call__(
939938
data = img[[*select_labels]]
940939
else:
941940
where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore
942-
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
943-
data = where(in1d(img, select_labels), True, False).reshape(img.shape)
944-
# pre pytorch 1.8.0, need to use 1/0 instead of True/False
945-
else:
946-
data = where(
947-
in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device)
948-
).reshape(img.shape)
941+
data = where(in1d(img, select_labels), True, False).reshape(img.shape)
949942

950943
if merge_channels or self.merge_channels:
951-
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
952-
return data.any(0)[None]
953-
# pre pytorch 1.8.0 compatibility
954-
return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore
944+
return data.any(0)[None]
955945

956946
return data
957947

monai/transforms/utils_pytorch_numpy_unification.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import torch
1919

2020
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
21-
from monai.utils.misc import is_module_ver_at_least
2221
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
2322

2423
__all__ = [
@@ -215,10 +214,9 @@ def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor:
215214
Element-wise floor division between two arrays/tensors.
216215
"""
217216
if isinstance(a, torch.Tensor):
218-
if is_module_ver_at_least(torch, (1, 8, 0)):
219-
return torch.div(a, b, rounding_mode="floor")
220217
return torch.floor_divide(a, b)
221-
return np.floor_divide(a, b)
218+
else:
219+
return np.floor_divide(a, b)
222220

223221

224222
def unravel_index(idx, shape) -> NdarrayOrTensor:

0 commit comments

Comments
 (0)