Skip to content

Commit 42bae57

Browse files
Fix hard clamping
1 parent c6b365b commit 42bae57

File tree

3 files changed

+124
-54
lines changed

3 files changed

+124
-54
lines changed

test/common_utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -468,18 +468,6 @@ def sample_position(values, max_value):
468468
else:
469469
raise ValueError(f"Format {format} is not supported")
470470
out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
471-
if tv_tensors.is_rotated_bounding_format(format):
472-
# Rotated bounding boxes are not inherently confined within the canvas, so clamping is applied.
473-
# Transform tests allow a 2-pixel tolerance relative to the canvas size.
474-
# To prevent discrepancies when clamping with different canvas sizes, we add a 2-pixel buffer.
475-
buffer = 4
476-
out_boxes = clamp_bounding_boxes(
477-
out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer)
478-
)
479-
if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR:
480-
out_boxes[:, :2] += buffer // 2
481-
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
482-
out_boxes[:, :] += buffer // 2
483471
return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size)
484472

485473

test/test_transforms_v2.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,7 +1298,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.B
12981298
)
12991299

13001300
helper = (
1301-
functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True)
1301+
functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True, clamp=False)
13021302
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
13031303
else reference_affine_bounding_boxes_helper
13041304
)
@@ -1907,7 +1907,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.Bou
19071907
)
19081908

19091909
helper = (
1910-
functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True)
1910+
functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True, clamp=False)
19111911
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
19121912
else reference_affine_bounding_boxes_helper
19131913
)
@@ -2196,7 +2196,7 @@ def _recenter_bounding_boxes_after_expand(self, bounding_boxes, *, recenter_xy):
21962196
(bounding_boxes.to(torch.float64) - torch.tensor(translate)).to(bounding_boxes.dtype), like=bounding_boxes
21972197
)
21982198

2199-
def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, center):
2199+
def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, center, canvas_size=None):
22002200
if center is None:
22012201
center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
22022202
cx, cy = center
@@ -2222,7 +2222,7 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
22222222
output = helper(
22232223
bounding_boxes,
22242224
affine_matrix=affine_matrix,
2225-
new_canvas_size=new_canvas_size,
2225+
new_canvas_size=new_canvas_size if canvas_size is None else canvas_size,
22262226
clamp=False,
22272227
)
22282228

@@ -2239,9 +2239,10 @@ def test_functional_bounding_boxes_correctness(self, format, angle, expand, cent
22392239

22402240
actual = F.rotate(bounding_boxes, angle=angle, expand=expand, center=center)
22412241
expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center)
2242+
torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
22422243

2244+
expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center, canvas_size=actual.canvas_size)
22432245
torch.testing.assert_close(actual, expected)
2244-
torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
22452246

22462247
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
22472248
@pytest.mark.parametrize("expand", [False, True])
@@ -2259,9 +2260,10 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed
22592260
actual = transform(bounding_boxes)
22602261

22612262
expected = self._reference_rotate_bounding_boxes(bounding_boxes, **params, expand=expand, center=center)
2262-
2263-
torch.testing.assert_close(actual, expected)
22642263
torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
2264+
2265+
expected = self._reference_rotate_bounding_boxes(bounding_boxes, **params, expand=expand, center=center, canvas_size=actual.canvas_size)
2266+
torch.testing.assert_close(actual, expected)
22652267

22662268
def _recenter_keypoints_after_expand(self, keypoints, *, recenter_xy):
22672269
x, y = recenter_xy
@@ -4437,7 +4439,7 @@ def test_functional_bounding_boxes_correctness(self, format):
44374439
bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE
44384440
)
44394441

4440-
torch.testing.assert_close(actual, expected)
4442+
torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
44414443
assert_equal(F.get_size(actual), F.get_size(expected))
44424444

44434445
def _reference_resized_crop_keypoints(self, keypoints, *, top, left, height, width, size):

torchvision/transforms/v2/functional/_meta.py

Lines changed: 114 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -410,16 +410,87 @@ def _order_bounding_boxes_points(
410410
output_xyxyxyxy = bounding_boxes.reshape(-1, 8)
411411
x, y = output_xyxyxyxy[..., 0::2], output_xyxyxyxy[..., 1::2]
412412
y_max = torch.max(y.abs(), dim=1, keepdim=True)[0]
413-
_, x1 = (y / y_max + (x + 1) * 100).min(dim=1)
413+
x_max = torch.max(x.abs(), dim=1, keepdim=True)[0]
414+
_, x1 = (y / y_max + (x / x_max) * 100).min(dim=1)
414415
indices = torch.ones_like(output_xyxyxyxy)
415416
indices[..., 0] = x1.mul(2)
416417
indices.cumsum_(1).remainder_(8)
417418
return indices, bounding_boxes.gather(1, indices.to(torch.int64))
418419

419420

421+
def _get_slope_and_intercept(box: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
422+
"""
423+
Calculate the slope and y-intercept of the lines defined by consecutive vertices in a bounding box.
424+
This function computes the slope (a) and y-intercept (b) for each line segment in a bounding box,
425+
where each line is defined by two consecutive vertices.
426+
"""
427+
x, y = box[..., ::2], box[..., 1::2]
428+
a = y.diff(append=y[..., 0:1]) / x.diff(append=x[..., 0:1])
429+
b = y - a * x
430+
return a, b
431+
432+
433+
def _get_intersection_point(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
434+
"""
435+
Calculate the intersection point of two lines defined by their slopes and y-intercepts.
436+
This function computes the intersection points between pairs of lines, where each line
437+
is defined by the equation y = ax + b (slope and y-intercept form).
438+
"""
439+
batch_size = a.shape[0]
440+
x = b.diff(prepend=b[..., 3:4]).neg() / a.diff(prepend=a[..., 3:4])
441+
y = a * x + b
442+
return torch.cat((x.unsqueeze(-1), y.unsqueeze(-1)), dim=-1).view(batch_size, 8)
443+
444+
445+
def _clamp_y_intercept(
446+
bounding_boxes: torch.Tensor,
447+
original_bounding_boxes: torch.Tensor,
448+
canvas_size: tuple[int, int],
449+
clamping: str = "hard",
450+
) -> torch.Tensor:
451+
"""
452+
Apply clamping to bounding box y-intercepts. This function handles two clamping strategies:
453+
- Hard clamping: Ensures all box vertices stay within canvas boundaries, finding the largest
454+
angle-preserving box enclosed within the original box and the image canvas.
455+
- Soft clamping: Allows some vertices to extend beyond the canvas, finding the smallest
456+
angle-preserving box that encloses the intersection of the original box and the image canvas.
457+
458+
The function first calculates the slopes and y-intercepts of the lines forming the bounding box,
459+
then applies various constraints to ensure the clamping conditions are respected.
460+
"""
461+
462+
a, b = _get_slope_and_intercept(bounding_boxes)
463+
a1, a2, a3, a4 = a.unbind(-1)
464+
b1, b2, b3, b4 = b.unbind(-1)
465+
466+
# Clamp y-intercepts (soft clamping)
467+
b1 = b2.clamp(0).clamp(b1, b3)
468+
b4 = b3.clamp(max=canvas_size[0]).clamp(b2, b4)
469+
470+
if clamping == "hard":
471+
# Get y-intercepts from original bounding boxes
472+
_, b = _get_slope_and_intercept(original_bounding_boxes)
473+
_, b2, b3, _ = b.unbind(-1)
474+
475+
# Set b1 and b4 to the average of their clamped values
476+
b1 = b4 = (b1.clamp(0, canvas_size[0]) + b4.clamp(0, canvas_size[0])) / 2
477+
478+
# Ensure b2 and b3 define the box of maximum area after clamping b1 and b4
479+
b2.clamp_(b1 * a2 / a1, b4).clamp_((a1 - a2) * canvas_size[1] + b1)
480+
b2.clamp_(b3 * a2 / a3, b4).clamp_((a3 - a2) * canvas_size[1] + b3)
481+
b3.clamp_(max=canvas_size[0] * (1 - a3 / a4) + b4 * a3 / a4)
482+
b3.clamp_(max=canvas_size[0] * (1 - a3 / a2) + b2 * a3 / a2)
483+
b3.clamp_(b1, (a2 - a3) * canvas_size[1] + b2)
484+
b3.clamp_(b1, (a4 - a3) * canvas_size[1] + b4)
485+
486+
return torch.stack([b1, b2, b3, b4], dim=-1)
487+
488+
420489
def _clamp_along_y_axis(
421490
bounding_boxes: torch.Tensor,
491+
original_bounding_boxes: torch.Tensor,
422492
canvas_size: tuple[int, int],
493+
clamping: str = "hard",
423494
) -> torch.Tensor:
424495
"""
425496
Adjusts bounding boxes along the y-axis based on specific conditions.
@@ -430,52 +501,53 @@ def _clamp_along_y_axis(
430501
431502
Args:
432503
bounding_boxes (torch.Tensor): A tensor containing bounding box coordinates.
504+
original_bounding_boxes (torch.Tensor): The original bounding boxes before any clamping is applied.
505+
canvas_size (tuple[int, int]): The size of the canvas as (height, width).
506+
clamping (str, optional): The clamping strategy to use. Defaults to "hard".
433507
434508
Returns:
435509
torch.Tensor: The adjusted bounding boxes.
436510
"""
437-
original_dtype = bounding_boxes.dtype
511+
dtype = bounding_boxes.dtype
512+
acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU.
513+
need_cast = dtype not in acceptable_dtypes
514+
eps = 1e-06 # Ensure consistency between CPU and GPU.
438515
original_shape = bounding_boxes.shape
439-
x1, y1, x2, y2, x3, y3, x4, y4 = bounding_boxes.reshape(-1, 8).unbind(-1)
440-
a = (y2 - y1) / (x2 - x1)
441-
b1 = y1 - a * x1
442-
b2 = y2 + x2 / a
443-
b3 = y3 - a * x3
444-
b4 = y4 + x4 / a
445-
c = a / (1 + a**2)
446-
b1 = b2.clamp(0).clamp(b1, b3)
447-
b4 = b3.clamp(max=canvas_size[0]).clamp(b2, b4)
448-
case_a = torch.stack(
449-
(
450-
(b4 - b1) * c,
451-
(b4 - b1) * c * a + b1,
452-
(b2 - b1) * c,
453-
(b1 - b2) * c / a + b2,
454-
x3,
455-
y3,
456-
(b4 - b3) * c,
457-
(b3 - b4) * c / a + b4,
458-
),
459-
dim=-1,
460-
)
516+
bounding_boxes = bounding_boxes.reshape(-1, 8)
517+
original_bounding_boxes = original_bounding_boxes.reshape(-1, 8)
518+
519+
# Calculate slopes (a) and y-intercepts (b) for all lines in the bounding boxes
520+
a, b = _get_slope_and_intercept(bounding_boxes)
521+
x1, y1, x2, y2, x3, y3, x4, y4 = bounding_boxes.unbind(-1)
522+
b = _clamp_y_intercept(bounding_boxes, original_bounding_boxes, canvas_size, clamping)
523+
524+
case_a = _get_intersection_point(a, b)
461525
case_b = bounding_boxes.clone()
462-
case_b[..., 0].clamp_(0)
463-
case_b[..., 6].clamp_(0)
526+
case_b[..., 0].clamp_(0) # Clamp x1 to 0
527+
case_b[..., 6].clamp_(0) # Clamp x4 to 0
464528
case_c = torch.zeros_like(case_b)
465529

466-
cond_a = x1 < 0
467-
cond_b = y1.isclose(y2, rtol=1e-05, atol=1e-05)
468-
cond_c = (x1 <= 0).logical_and(x2 <= 0).logical_and(x3 <= 0).logical_and(x4 <= 0)
469-
for cond, case in zip(
530+
cond_a = (x1 < eps) & ~case_a.isnan().any(-1) # First point is outside left boundary
531+
cond_b = y1.isclose(y2, rtol=eps, atol=eps) | y3.isclose(y4, rtol=eps, atol=eps) # Edge 1-2 or edge 3-4 is nearly horizontal (equal y at both ends)
532+
cond_c = (x1 <= 0) & (x2 <= 0) & (x3 <= 0) & (x4 <= 0) # All points outside left boundary
533+
cond_c = cond_c | y1.isclose(y4, rtol=eps, atol=eps) | y2.isclose(y3, rtol=eps, atol=eps) | (cond_b & x1.isclose(x2, rtol=eps, atol=eps)) # Edge 4-1 or edge 2-3 is nearly horizontal, or cond_b holds with x1 ≈ x2
534+
535+
for (cond, case) in zip(
470536
[cond_a, cond_b, cond_c],
471537
[case_a, case_b, case_c],
472538
):
473539
bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes)
474-
return bounding_boxes.to(original_dtype).reshape(original_shape)
540+
bounding_boxes[..., 0].clamp_(0) # Clamp x1 to 0
541+
542+
if need_cast:
543+
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
544+
bounding_boxes.round_()
545+
bounding_boxes = bounding_boxes.to(dtype)
546+
return bounding_boxes.reshape(original_shape)
475547

476548

477549
def _clamp_rotated_bounding_boxes(
478-
bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int]
550+
bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int], clamping_mode: str = "soft"
479551
) -> torch.Tensor:
480552
"""
481553
Clamp rotated bounding boxes to ensure they stay within the canvas boundaries.
@@ -508,15 +580,22 @@ def _clamp_rotated_bounding_boxes(
508580
)
509581
).reshape(-1, 8)
510582

583+
original_boxes = out_boxes.clone()
511584
for _ in range(4): # Iterate over the 4 vertices.
512585
indices, out_boxes = _order_bounding_boxes_points(out_boxes)
513-
out_boxes = _clamp_along_y_axis(out_boxes, canvas_size)
586+
_, original_boxes = _order_bounding_boxes_points(original_boxes, indices)
587+
out_boxes = _clamp_along_y_axis(out_boxes, original_boxes, canvas_size, clamping_mode)
514588
_, out_boxes = _order_bounding_boxes_points(out_boxes, indices)
589+
_, original_boxes = _order_bounding_boxes_points(original_boxes, indices)
515590
# rotate 90 degrees counterclockwise
516591
out_boxes[:, ::2], out_boxes[:, 1::2] = (
517592
out_boxes[:, 1::2].clone(),
518593
canvas_size[1] - out_boxes[:, ::2].clone(),
519594
)
595+
original_boxes[:, ::2], original_boxes[:, 1::2] = (
596+
original_boxes[:, 1::2].clone(),
597+
canvas_size[1] - original_boxes[:, ::2].clone(),
598+
)
520599
canvas_size = (canvas_size[1], canvas_size[0])
521600

522601
out_boxes = convert_bounding_box_format(
@@ -525,7 +604,8 @@ def _clamp_rotated_bounding_boxes(
525604

526605
if need_cast:
527606
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
528-
out_boxes.round_()
607+
# Adding epsilon to ensure consistency between CPU and GPU rounding.
608+
out_boxes.add_(1e-7).round_()
529609
out_boxes = out_boxes.to(dtype)
530610
return out_boxes
531611

0 commit comments

Comments
 (0)