RT-Detr 2d positional embedding bug #41379

@konstantinos-p

Description

I believe there is a subtle bug in how 2D positional embeddings are generated for the rt_detr model. Specifically, when a meshgrid is generated with `indexing='ij'` and then flattened, the positional embeddings no longer line up with the intended image pixels. The script below demonstrates this:

import torch

# Create 2d grid
x = torch.tensor([0, 1, 2, 3])
y = torch.tensor([0, 1, 2])

grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')

print(grid_x)
#tensor([[0, 0, 0],
#        [1, 1, 1],
#        [2, 2, 2],
#        [3, 3, 3]])
print(grid_y)
#tensor([[0, 1, 2],
#        [0, 1, 2],
#        [0, 1, 2],
#        [0, 1, 2]])

# Create a 3x4 (height x width) image whose pixel values equal their flattened index
# (torch.range is deprecated; torch.arange is the current API)
image = torch.arange(12).reshape((3, 4)).int()
print(image)
#tensor([[ 0,  1,  2,  3],
#        [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]], dtype=torch.int32)

#Print flattened grid and image
print(grid_x.flatten())
# tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
print(grid_y.flatten())
# tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
print(image.flatten())
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

The flattened `grid_x`/`grid_y` pairs do not follow the row-major order of `image.flatten()`, so for non-square images the embeddings do not encode the intended spatial information.

The issue comes from the following line; I believe `indexing='xy'` would have been the correct choice:

grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
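To illustrate the intended behavior, here is a minimal sketch (using the same toy sizes as the script above, not the actual RT-DETR code) showing that with `indexing='xy'` the flattened grids line up with a row-major flatten of a (height, width) image:

```python
import torch

x = torch.tensor([0, 1, 2, 3])  # width (column) coordinates
y = torch.tensor([0, 1, 2])     # height (row) coordinates

# With indexing='xy', the grids have shape (len(y), len(x)) = (3, 4),
# so flattening them row-major matches flattening a (3, 4) image.
grid_x, grid_y = torch.meshgrid(x, y, indexing='xy')

print(grid_x.flatten())
# tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])
print(grid_y.flatten())
# tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])

# Pixel k of a row-major flattened (3, 4) image sits at row k // 4, col k % 4,
# which is exactly (grid_y.flatten()[k], grid_x.flatten()[k]).
image = torch.arange(12).reshape((3, 4))
for k in range(12):
    assert image[grid_y.flatten()[k], grid_x.flatten()[k]] == image.flatten()[k]
```

Compare this with the `indexing='ij'` output in the script above, where `grid_x.flatten()` gives `[0, 0, 0, 1, 1, 1, ...]` and the correspondence to flattened pixels breaks.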

While RT-DETR was designed by its authors around square images, generalizing correctly to non-square images could be important for future research.

I will open a PR with a simple fix that is consistent with the current positional embeddings for square inputs but generalizes correctly to variable image sizes.
