Skip to content

Commit 94a9fff

Browse files
tlongeriGoogle-ML-Automation
authored andcommitted
[Mosaic:TPU][Relayouts] 2nd minor implicit -> minor implicit
PiperOrigin-RevId: 777454930
1 parent e4c6cfe commit 94a9fff

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7499,7 +7499,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
74997499
});
75007500
return std::make_pair(dst, new_vregs);
75017501
}
7502-
if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
7502+
if ((src.implicit_dim() == VectorLayout::ImplicitDim::kNone ||
7503+
src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor) &&
75037504
dst_implicit_dim == VectorLayout::ImplicitDim::kMinor &&
75047505
src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) {
75057506
// TODO(tlongeri): Make insertImplicitMinorDimension more flexible about
@@ -7517,6 +7518,14 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
75177518
insertImplicitMinorDimension(ctx, builder, loc, vregs,
75187519
src.implicitShape(vty.getShape()), src,
75197520
dst.offsets()));
7521+
if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor) {
7522+
// Remove the original implicit 2nd minor, now implicit 3rd minor
7523+
SmallVector<int64_t> dst_vregs_shape(dst_vregs.dimensions().begin(),
7524+
dst_vregs.dimensions().end());
7525+
CHECK_EQ(*(dst_vregs_shape.end() - 3), 1);
7526+
dst_vregs_shape.erase(dst_vregs_shape.end() - 3);
7527+
dst_vregs.Reshape(dst_vregs_shape);
7528+
}
75207529
return std::make_pair(dst, std::move(dst_vregs));
75217530
}
75227531
if (src.implicit_dim() == VectorLayout::ImplicitDim::kMinor &&

0 commit comments

Comments
 (0)