@@ -7499,7 +7499,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
7499
7499
});
7500
7500
return std::make_pair (dst, new_vregs);
7501
7501
}
7502
- if (src.implicit_dim () == VectorLayout::ImplicitDim::kNone &&
7502
+ if ((src.implicit_dim () == VectorLayout::ImplicitDim::kNone ||
7503
+ src.implicit_dim () == VectorLayout::ImplicitDim::kSecondMinor ) &&
7503
7504
dst_implicit_dim == VectorLayout::ImplicitDim::kMinor &&
7504
7505
src.bitwidth () == 32 && src.hasNativeTiling (ctx.target_shape )) {
7505
7506
// TODO(tlongeri): Make insertImplicitMinorDimension more flexible about
@@ -7517,6 +7518,14 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
7517
7518
insertImplicitMinorDimension (ctx, builder, loc, vregs,
7518
7519
src.implicitShape (vty.getShape ()), src,
7519
7520
dst.offsets ()));
7521
+ if (src.implicit_dim () == VectorLayout::ImplicitDim::kSecondMinor ) {
7522
+ // Remove the original implicit 2nd minor, now implicit 3rd minor
7523
+ SmallVector<int64_t > dst_vregs_shape (dst_vregs.dimensions ().begin (),
7524
+ dst_vregs.dimensions ().end ());
7525
+ CHECK_EQ (*(dst_vregs_shape.end () - 3 ), 1 );
7526
+ dst_vregs_shape.erase (dst_vregs_shape.end () - 3 );
7527
+ dst_vregs.Reshape (dst_vregs_shape);
7528
+ }
7520
7529
return std::make_pair (dst, std::move (dst_vregs));
7521
7530
}
7522
7531
if (src.implicit_dim () == VectorLayout::ImplicitDim::kMinor &&
0 commit comments