diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 5cb4d5bef1a..947f098dd0f 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -1398,18 +1398,7 @@ def _crop_bounding_boxes_dispatch(
 
 @_register_kernel_internal(crop, tv_tensors.Mask)
 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
-    if mask.ndim < 3:
-        mask = mask.unsqueeze(0)
-        needs_squeeze = True
-    else:
-        needs_squeeze = False
-
-    output = crop_image(mask, top, left, height, width)
-
-    if needs_squeeze:
-        output = output.squeeze(0)
-
-    return output
+    return crop_image(mask, top, left, height, width)
 
 
 @_register_kernel_internal(crop, tv_tensors.Video)
@@ -2036,18 +2025,7 @@ def _center_crop_bounding_boxes_dispatch(
 
 @_register_kernel_internal(center_crop, tv_tensors.Mask)
 def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
-    if mask.ndim < 3:
-        mask = mask.unsqueeze(0)
-        needs_squeeze = True
-    else:
-        needs_squeeze = False
-
-    output = center_crop_image(image=mask, output_size=output_size)
-
-    if needs_squeeze:
-        output = output.squeeze(0)
-
-    return output
+    return center_crop_image(image=mask, output_size=output_size)
 
 
 @_register_kernel_internal(center_crop, tv_tensors.Video)