Description
I believe there is a subtle bug in how the 2D positional embeddings are generated for the rt_detr model. Specifically, when the meshgrid is generated and then flattened, the positional embeddings no longer line up with the intended image pixels. We can see this in the script below:
import torch
# Create 2d grid
x = torch.tensor([0, 1, 2, 3])
y = torch.tensor([0, 1, 2])
grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
print(grid_x)
#tensor([[0, 0, 0],
# [1, 1, 1],
# [2, 2, 2],
# [3, 3, 3]])
print(grid_y)
#tensor([[0, 1, 2],
# [0, 1, 2],
# [0, 1, 2],
# [0, 1, 2]])
# Create image
image = torch.arange(12).reshape(3, 4)
print(image)
#tensor([[ 0,  1,  2,  3],
#        [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]])
#Print flattened grid and image
print(grid_x.flatten())
# tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
print(grid_y.flatten())
# tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
print(image.flatten())
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
This means that the flattened grid no longer lines up with the row-major flattening of the image, so for images that are not square the embeddings don't encode the intended spatial information.
This issue comes from the following line; I believe indexing='xy' would have been the correct choice here:
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
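For comparison, here is a minimal sketch (variable names are my own, reusing the toy sizes from the script above) of what indexing='xy' produces; the flattened grids then line up with the row-major flattening of the image:
import torch
w = torch.tensor([0, 1, 2, 3])  # column (width) coordinates
h = torch.tensor([0, 1, 2])     # row (height) coordinates
# With indexing='xy' both grids have shape (height, width), matching the image layout
grid_w, grid_h = torch.meshgrid(w, h, indexing='xy')
print(grid_w.flatten())
# tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])
print(grid_h.flatten())
# tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
# Pixel k of the row-major flattened (3, 4) image sits at column k % 4 and row k // 4,
# which is exactly what the flattened grids above encode.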
While RT-DETR was intended by its authors to work on square images, generalizing correctly to non-square images could be important for future research.
I will open a PR with a simple fix that is consistent with the current positional embeddings but generalizes correctly to variable image sizes.