diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 3af454a6bf52..a440c1867368 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -175,6 +175,7 @@ def remove_dimension(self, dim: int) -> Tiling: for t in self.tiles: assert last_tile_rank >= len(t) dim_in_tile -= last_tile_rank - len(t) + last_tile_rank = len(t) if dim_in_tile >= 0: t = t[:dim_in_tile] + t[dim_in_tile + 1:] if not t: # If this tile is empty, all other tiles will be empty too. diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 57510f090ced..01f436dbd246 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -20,8 +20,8 @@ import itertools import math import operator -import sys import re +import sys import unittest from absl.testing import absltest, parameterized @@ -2596,6 +2596,29 @@ def kernel(ctx, dst, _): def test_layout_reduction_definition(self, layout, expected_reduced_layout, axis): self.assertEqual(layout.reduce((axis,)), expected_reduced_layout) + def test_layout_reduction_handles_tiles_with_three_different_ranks(self): + layout = fa.TiledLayout( + tiling=fa.Tiling(tiles=((1, 2, 64), (2, 16), (8,), (4,), (2,), (1,))), + warp_dims=(-7,), + lane_dims=(-6, -5, -4, -3, -2), + vector_dim=-1, + ) + self.assertEqual( + layout.reduce((2,)), + fa.TiledLayout( + tiling=fa.Tiling(tiles=((1, 2), (1,))), + warp_dims=(fa.Replicated(times=4),), + lane_dims=( + -2, + fa.Replicated(times=2), + fa.Replicated(times=2), + fa.Replicated(times=2), + fa.Replicated(times=2), + ), + vector_dim=-1, + ), + ) + @parameterized.product( op=(arith.addf, arith.maximumf), m=(64, 128),