Skip to content

Commit ee832c1

Browse files
bchetioui authored and Google-ML-Automation committed
[Mosaic GPU] Fix bug in Tiling.remove_dimension.
We forgot to update `last_tile_rank` when looping over tiles, which could lead to incorrect dimension offsets.

PiperOrigin-RevId: 777473496
1 parent 94a9fff commit ee832c1

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -175,6 +175,7 @@ def remove_dimension(self, dim: int) -> Tiling:
175175
for t in self.tiles:
176176
assert last_tile_rank >= len(t)
177177
dim_in_tile -= last_tile_rank - len(t)
178+
last_tile_rank = len(t)
178179
if dim_in_tile >= 0:
179180
t = t[:dim_in_tile] + t[dim_in_tile + 1:]
180181
if not t: # If this tile is empty, all other tiles will be empty too.

tests/mosaic/gpu_test.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,8 @@
2020
import itertools
2121
import math
2222
import operator
23-
import sys
2423
import re
24+
import sys
2525
import unittest
2626

2727
from absl.testing import absltest, parameterized
@@ -2596,6 +2596,29 @@ def kernel(ctx, dst, _):
25962596
def test_layout_reduction_definition(self, layout, expected_reduced_layout, axis):
25972597
self.assertEqual(layout.reduce((axis,)), expected_reduced_layout)
25982598

2599+
def test_layout_reduction_handles_tiles_with_three_different_ranks(self):
2600+
layout = fa.TiledLayout(
2601+
tiling=fa.Tiling(tiles=((1, 2, 64), (2, 16), (8,), (4,), (2,), (1,))),
2602+
warp_dims=(-7,),
2603+
lane_dims=(-6, -5, -4, -3, -2),
2604+
vector_dim=-1,
2605+
)
2606+
self.assertEqual(
2607+
layout.reduce((2,)),
2608+
fa.TiledLayout(
2609+
tiling=fa.Tiling(tiles=((1, 2), (1,))),
2610+
warp_dims=(fa.Replicated(times=4),),
2611+
lane_dims=(
2612+
-2,
2613+
fa.Replicated(times=2),
2614+
fa.Replicated(times=2),
2615+
fa.Replicated(times=2),
2616+
fa.Replicated(times=2),
2617+
),
2618+
vector_dim=-1,
2619+
),
2620+
)
2621+
25992622
@parameterized.product(
26002623
op=(arith.addf, arith.maximumf),
26012624
m=(64, 128),

0 commit comments

Comments
 (0)