diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 5cb4d5bef1a..947f098dd0f 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1398,18 +1398,7 @@ def _crop_bounding_boxes_dispatch( @_register_kernel_internal(crop, tv_tensors.Mask) def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: - if mask.ndim < 3: - mask = mask.unsqueeze(0) - needs_squeeze = True - else: - needs_squeeze = False - - output = crop_image(mask, top, left, height, width) - - if needs_squeeze: - output = output.squeeze(0) - - return output + return crop_image(mask, top, left, height, width) @_register_kernel_internal(crop, tv_tensors.Video) @@ -2036,18 +2025,7 @@ def _center_crop_bounding_boxes_dispatch( @_register_kernel_internal(center_crop, tv_tensors.Mask) def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor: - if mask.ndim < 3: - mask = mask.unsqueeze(0) - needs_squeeze = True - else: - needs_squeeze = False - - output = center_crop_image(image=mask, output_size=output_size) - - if needs_squeeze: - output = output.squeeze(0) - - return output + return center_crop_image(image=mask, output_size=output_size) @_register_kernel_internal(center_crop, tv_tensors.Video)