[mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA #65621

Merged Sep 14, 2023 (4 commits)

2 changes: 1 addition & 1 deletion mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -239,7 +239,7 @@ class LLVMTypeConverter : public TypeConverter {
Type convertMemRefToBarePtr(BaseMemRefType type) const;

/// Convert a 1D vector type into an LLVM vector type.
Type convertVectorType(VectorType type) const;
FailureOr<Type> convertVectorType(VectorType type) const;

/// Options for customizing the llvm lowering.
LowerToLLVMOptions options;
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -20,6 +20,8 @@
namespace mlir {
namespace arm_sme {

constexpr unsigned MinStreamingVectorLengthInBits = 128;

/// Return minimum number of elements for the given element `type` in
/// a vector of SVL bits.
unsigned getSMETileSliceMinNumElts(Type type);
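
For reference, with a 128-bit minimum streaming vector length this works out to the following values (they follow directly from the formula in Utils.cpp further down):

  128 /  8-bit elements -> 16 (i8)
  128 / 16-bit elements ->  8 (i16, f16, bf16)
  128 / 32-bit elements ->  4 (i32, f32)
  128 / 64-bit elements ->  2 (i64, f64)
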
19 changes: 11 additions & 8 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -61,7 +61,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addConversion([&](MemRefType type) { return convertMemRefType(type); });
addConversion(
[&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
addConversion([&](VectorType type) { return convertVectorType(type); });
addConversion([&](VectorType type) -> std::optional<Type> {
FailureOr<Type> llvmType = convertVectorType(type);
if (failed(llvmType))
return std::nullopt;
return llvmType;
});

// LLVM-compatible types are legal, so add a pass-through conversion. Do this
// before the conversions below since conversions are attempted in reverse
@@ -490,10 +495,9 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
/// * 1-D `vector<axT>` remains as is while,
/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
/// `!llvm.array<ax...array<jxvector<kxT>>>`.
/// As LLVM does not support arrays of scalable vectors, it is assumed that
/// scalable vectors are always 1-D. This condition could be relaxed once the
/// missing functionality is added in LLVM
Type LLVMTypeConverter::convertVectorType(VectorType type) const {
/// Returns failure for n-D scalable vector types as LLVM does not support
/// arrays of scalable vectors.
FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
auto elementType = convertType(type.getElementType());
if (!elementType)
return {};
@@ -503,9 +507,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const {
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
assert(
(!type.isScalable() || (type.getRank() == 1)) &&
"expected 1-D scalable vector (n-D scalable vectors are not supported)");
if (type.isScalable() && (type.getRank() > 1))
return failure();
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
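
To illustrate the behavioural change (example types only): 1-D scalable vectors and n-D fixed-size vectors convert as before, while n-D scalable vectors now report failure rather than tripping an assert:

  vector<[8]xf32>      ->  vector<[8]xf32>                 (1-D scalable: unchanged)
  vector<4x8xf32>      ->  !llvm.array<4 x vector<8xf32>>  (n-D fixed: nested array)
  vector<[4]x[4]xf32>  ->  failure                         (n-D scalable: no LLVM equivalent)
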
117 changes: 113 additions & 4 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -361,6 +361,112 @@ struct MoveVectorToTileSliceToArmSMELowering
}
};

/// Lower `vector.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
/// : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
/// "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)

Contributor:

these two have the same name?

Contributor:

This intrinsic takes two operands to mask the inputs in the op itself, so to support masking you would have to propagate the masks from the producer ops... That's interesting, because it looks like the op knows how to merge both masks without requiring independent mask-manipulation operations.

How do we plan to implement proper support for this? I see two options:

  1. In one shot: search for the two masks in the use-def chain and use them directly in the intrinsic. Any mask-manipulation operation in between should hopefully become dead and go away.
  2. In two steps: pass the single mask from the masked vector.outerproduct operation to both operands, and later run a pass that replaces it with the two masks from the operands.

I think doing all of that as part of the lowering (1) might be too much for a lowering, especially if finding the masks through the use-def chain is not trivial. (2) seems simpler to me, but I wouldn't implement it on top of an LLVM intrinsic. I think for that we should have a proper SME op, which should be fine now that we have the SME dialect.

Thoughts?

Collaborator (Author):

> these two have the same name?

If you're referring to the predicate (?), that's because they're the same: both are all-active.

Collaborator (Author):

> This intrinsic takes two operands to mask the inputs in the op itself [...] How do we plan to implement proper support for this? I see two options: [...]

The only examples of masking I've seen (from grepping around the codebase) are where the mask is applied to the result of the outer product, e.g. vector.mask { vector.outerproduct ... }. I just figured we'd need some way to correlate this with the inputs, but hadn't given it much thought.

Appreciate your input. I'll add a custom op so there's more flexibility when it comes to masking, and I'll also look into how it would be used.

Contributor:

Given the expected complexity of generating correct masks, I am also leaning towards a custom op. Having said that, IMHO this PR is fine as is and we could iterate in follow-up patches.

> 2. In two steps: pass the single mask from the masked vector.outerproduct operation to both operands, and later run a pass that replaces it with the two masks from the operands.

I guess that for this to work, we'd need something like:

 %res = arm_sme.op %rhs, %lhs <optional_mask_for_rhs_or_result> <optional_mask_for_lhs>

So we'd allow two masks, both of which would be optional:

  • if only one mask is specified, it is a mask for the result (1 x 2-D),
  • if two masks are specified, they are masks for the input vectors (2 x 1-D),
  • if no masks are specified, use ptrue (all lanes active).

WDYT?

Contributor:

Perhaps we could use vector.mask for masking the result, so that we don't have to disambiguate the semantics based on the number of masks...
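
For context, the masked form that exists today (exercised by the new test below) wraps the whole outer product in vector.mask with a single 2-D mask, from which the two 1-D operand masks discussed above would have to be derived:

  %0 = vector.mask %mask {
    vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
      : vector<[4]xf32>, vector<[4]xf32>
  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>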

/// : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
/// vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct VectorOuterProductToArmSMELowering
: public ConvertOpToLLVMPattern<vector::OuterProductOp> {
using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::OuterProductOp outerProductOp,
vector::OuterProductOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto isSupportedType = [](VectorType vectorType) {
// TODO: the FP outer product instruction variants are predicated on
// different features [1]:
//
// * FMOPA (non-widening)
// * half-precision - +sme2p1,+sme-f16f16
// * single-precision - +sme
// * double-precision - +sme-f64f64
// * BFMOPA
// * half-precision - +sme2p1,+b16b16
//
// It should be possible to control lowering based on target features.
// [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
return false;

auto elementType = vectorType.getElementType();

if (!elementType.isF16() && !elementType.isBF16() &&
!elementType.isF32() && !elementType.isF64())
return false;

unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
vectorType.getElementTypeBitWidth();
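// For example, for f32 minNumElts is 128 / 32 = 4, so the only shape accepted
// here is vector<[4]x[4]xf32>: exactly one SME tile at the minimum streaming
// vector length.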
if (vectorType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
return false;

return true;
};

auto resultVectorType = outerProductOp.getResultVectorType();
if (!isSupportedType(resultVectorType))
return outerProductOp.emitError("unsupported type");

vector::CombiningKind kind = outerProductOp.getKind();
if (kind != vector::CombiningKind::ADD)
// TODO: support subtract.
return outerProductOp.emitError("unsupported kind");

auto maskableOp =
cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
if (maskableOp.isMasked())
// TODO: support masking.
return outerProductOp.emitError("masking is currently unsupported");

if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
// AXPY operation not suited for SME.
return failure();

auto loc = outerProductOp.getLoc();

Value acc = outerProductOp.getAcc();
if (!acc)
// Initialize the accumulator with zero.
acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);

unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
loc, rewriter.getIntegerType(elementWidth), acc);

// Create all active predicate mask.
auto one = rewriter.create<arith::ConstantOp>(

Contributor:

We should use a ConstantMaskOp here

Member @MacDue (Sep 11, 2023):

ConstantMaskOp can only create all-false masks for scalable vectors.

// Only zero sizes are accepted here:
vector.constant_mask [0] : vector<[4]xi1>

Could maybe use CreateMaskOp, but I'm not sure if it's much simpler.

Contributor:

Feels like a fairly important gap for us to fill, but not necessarily in this patch.

loc, rewriter.getI1Type(),
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy =
VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
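// Note: an equivalent all-active mask could in principle be built with
// 'vector.create_mask' scaled by 'vector.vscale', e.g. for a vector<[4]xi1>
// predicate:
//
//   %vscale = vector.vscale
//   %c4 = arith.constant 4 : index
//   %len = arith.muli %c4, %vscale : index
//   %mask = vector.create_mask %len : vector<[4]xi1>
//
// As noted in the review thread above, this is not obviously simpler than
// splatting a constant true.
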

auto tileI32 = castTileIDToI32(tileId, loc, rewriter);

// Create 'arm_sme.intr.mopa' outer product intrinsic.
rewriter.create<arm_sme::aarch64_sme_mopa>(
loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
outerProductOp.getRhs());

// Create `CastTileToVectorOp` to use as the output.
rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
outerProductOp, resultVectorType, tileId);

return success();
}
};

} // namespace

void mlir::configureArmSMELegalizeForExportTarget(
@@ -374,8 +480,10 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
target.addIllegalOp<vector::OuterProductOp>();

// Mark 'func.func' ops as legal if either:
// 1. no 'arm_za' function attribute is present.
@@ -405,7 +513,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
patterns.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
LoadTileSliceToArmSMELowering,
MoveVectorToTileSliceToArmSMELowering>(converter);
patterns
.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
VectorOuterProductToArmSMELowering>(converter);
}
2 changes: 0 additions & 2 deletions mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -17,8 +17,6 @@
using namespace mlir;
using namespace mlir::arm_sme;

static constexpr unsigned MinStreamingVectorLengthInBits = 128;

unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
assert(isValidSMETileElementType(type) && "invalid tile type!");
return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
@@ -1121,11 +1121,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {

LogicalResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &rewriter) const override {
VectorType resType = op.getResultVectorType();
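// Scalable 2-D outer products cannot be unrolled lane-by-lane here since the
// number of lanes is not known statically; leave them to dedicated lowerings
// such as the ArmSME path added in this patch.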
if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
return failure();

auto loc = op.getLoc();

VectorType lhsType = op.getOperandVectorTypeLHS();
VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
VectorType resType = op.getResultVectorType();
Type eltType = resType.getElementType();
bool isInt = isa<IntegerType, IndexType>(eltType);
Value acc = op.getAcc();
107 changes: 106 additions & 1 deletion mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,8 @@
// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s

//===----------------------------------------------------------------------===//
// vector.transfer_write
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @transfer_write_2d_zero_i8(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
@@ -33,6 +37,10 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
return
}

//===----------------------------------------------------------------------===//
// vector.load
//===----------------------------------------------------------------------===//

// -----

// Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
@@ -232,6 +240,10 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
return %tile : vector<[1]x[1]xi128>
}

//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @vector_store_i8(
@@ -391,3 +403,96 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}

//===----------------------------------------------------------------------===//
// vector.outerproduct
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @vector_outerproduct_add_f16
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
// CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
// CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
// CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
}

// -----

// CHECK-LABEL: @vector_outerproduct_add_bf16
func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
"prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
}

// -----

// CHECK-LABEL: @vector_outerproduct_add_f32
func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
// CHECK-NOT: arith.extui
// CHECK-NOT: arith.trunci
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}

// -----

// CHECK-LABEL: @vector_outerproduct_add_f64
func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
// CHECK: arith.trunci {{.*}} : i64 to i32
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}

// -----

// CHECK-LABEL: @vector_outerproduct_no_accumulator
func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
// CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
%0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}

// -----

// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
// CHECK-NOT: arm_sme
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
return %0 : vector<[2]xf64>
}

// -----

func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
// expected-error@+1 {{unsupported type}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
}

// -----

func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
// expected-error@+1 {{unsupported kind}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}

// -----

func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
// expected-error@+1 {{masking is currently unsupported}}
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}
@@ -0,0 +1,116 @@
// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITHOUT-ACC

// REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_4x4xf32
// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-ACC

llvm.func @printCString(!llvm.ptr<i8>)

func.func @printTileBegin() {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
: (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
return
}

func.func @printTileEnd() {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
: (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
return
}

func.func @test_outerproduct_no_accumulator_4x4xf32() {
%c0 = arith.constant 0 : index

%vector_i32 = llvm.intr.experimental.stepvector : vector<[4]xi32>
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>

// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
%vscale = vector.vscale
%min_elts_s = arith.constant 4 : index
%svl_s = arith.muli %min_elts_s, %vscale : index
%za_s_size = arith.muli %svl_s, %svl_s : index

// Allocate memory.
%mem = memref.alloca(%za_s_size) : memref<?xf32>

// Store the tile to memory.
vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>

// Reload and print. The smallest SVL is 128-bits so the tile will be at
// least 4x4xf32.
//
// WITHOUT-ACC: TILE BEGIN
// WITHOUT-ACC-NEXT: ( 0, 0, 0, 0
// WITHOUT-ACC-NEXT: ( 0, 1, 2, 3
// WITHOUT-ACC-NEXT: ( 0, 2, 4, 6
// WITHOUT-ACC-NEXT: ( 0, 3, 6, 9
// WITHOUT-ACC: TILE END
func.call @printTileBegin() : () -> ()
scf.for %i = %c0 to %za_s_size step %svl_s {
%tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
vector.print %tileslice : vector<[4]xf32>
}
func.call @printTileEnd() : () -> ()

return
}

func.func @test_outerproduct_with_accumulator_4x4xf32() {
%c0 = arith.constant 0 : index
%f10 = arith.constant 10.0 : f32

%acc = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
%vector_i32 = llvm.intr.experimental.stepvector : vector<[4]xi32>
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>

// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
%vscale = vector.vscale
%min_elts_s = arith.constant 4 : index
%svl_s = arith.muli %min_elts_s, %vscale : index
%za_s_size = arith.muli %svl_s, %svl_s : index

// Allocate memory.
%mem = memref.alloca(%za_s_size) : memref<?xf32>

// Store the tile to memory.
vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>

// Reload and print. The smallest SVL is 128-bits so the tile will be at
// least 4x4xf32.
//
// WITH-ACC: TILE BEGIN
// WITH-ACC-NEXT: ( 10, 10, 10, 10
// WITH-ACC-NEXT: ( 10, 11, 12, 13
// WITH-ACC-NEXT: ( 10, 12, 14, 16
// WITH-ACC-NEXT: ( 10, 13, 16, 19
// WITH-ACC: TILE END
func.call @printTileBegin() : () -> ()
scf.for %i = %c0 to %za_s_size step %svl_s {
%tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
vector.print %tileslice : vector<[4]xf32>
}
func.call @printTileEnd() : () -> ()

return
}

llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
llvm.mlir.global internal constant @str_tile_end("TILE END\0A")
@@ -0,0 +1,77 @@
// DEFINE: %{entry_point} = test_outerproduct_with_accumulator_2x2xf64
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme-f64f64 \
// DEFINE: -e %{entry_point} -entry-point-result=void \
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: %{compile} | %{run} | FileCheck %s

llvm.func @printCString(!llvm.ptr<i8>)

func.func @printTileBegin() {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
: (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
return
}

func.func @printTileEnd() {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
: (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
return
}

func.func @test_outerproduct_with_accumulator_2x2xf64() {
%c0 = arith.constant 0 : index
%f1 = arith.constant 1.0 : f64
%f2 = arith.constant 2.0 : f64
%f10 = arith.constant 10.0 : f64

%a = vector.splat %f1 : vector<[2]xf64>
%b = vector.splat %f2 : vector<[2]xf64>
// TODO: vector.splat doesn't support ArmSME.
%c = vector.broadcast %f10 : f64 to vector<[2]x[2]xf64>

%tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>

// Calculate the size of a 64-bit tile, e.g. ZA{n}.d.
%vscale = vector.vscale
%min_elts_d = arith.constant 2 : index
%svl_d = arith.muli %min_elts_d, %vscale : index
%za_d_size = arith.muli %svl_d, %svl_d : index

// Allocate memory.
%mem = memref.alloca(%za_d_size) : memref<?xf64>

// Store the tile to memory.
vector.store %tile, %mem[%c0] : memref<?xf64>, vector<[2]x[2]xf64>

// Reload and print. The smallest SVL is 128-bits so the tile will be at
// least 2x2xf64.
//
// CHECK: TILE BEGIN
// CHECK-NEXT: ( 12, 12
// CHECK-NEXT: ( 12, 12
// CHECK: TILE END
func.call @printTileBegin() : () -> ()
scf.for %i = %c0 to %za_d_size step %svl_d {
%tileslice = vector.load %mem[%i] : memref<?xf64>, vector<[2]xf64>
vector.print %tileslice : vector<[2]xf64>
}
func.call @printTileEnd() : () -> ()

return
}

llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
llvm.mlir.global internal constant @str_tile_end("TILE END\0A")