Skip to content

Commit 1999394

Browse files
committed
[mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA
This patch adds support for lowering vector.outerproduct to the ArmSME MOPA intrinsic for the following types: vector<[8]xf16>, vector<[8]xf16> -> vector<[8]x[8]xf16> vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16> vector<[4]xf32>, vector<[4]xf32> -> vector<[4]x[4]xf32> vector<[2]xf64>, vector<[2]xf64> -> vector<[2]x[2]xf64> The FP variants are lowered to FMOPA (non-widening) [1] and BFloat to BFMOPA (non-widening) [2]. Note at the ISA level these variants are implemented by different architecture features, these are listed below: FMOPA (non-widening) * half-precision - +sme2p1,+sme-f16f16 * single-precision - +sme * double-precision - +sme-f64f64 BFMOPA (non-widening) * half-precision - +sme2p1,+b16b16 There's currently no way to target different features when lowering to ArmSME. Integration tests are added for F32 and F64. We use QEMU to run the integration tests but SME2 support isn't available yet, it's targeted for 9.0, so integration tests for these variants excluded. Masking is currently unsupported. Depends on llvm#65450. [1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate- [2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-
1 parent ca42809 commit 1999394

File tree

7 files changed

+367
-8
lines changed

7 files changed

+367
-8
lines changed

mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
namespace mlir {
2121
namespace arm_sme {
2222

23+
constexpr unsigned MinStreamingVectorLengthInBits = 128;
24+
2325
/// Return minimum number of elements for the given element `type` in
2426
/// a vector of SVL bits.
2527
unsigned getSMETileSliceMinNumElts(Type type);

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,111 @@ struct MoveVectorToTileSliceToArmSMELowering
361361
}
362362
};
363363

364+
/// Lower `vector.outerproduct` to SME MOPA intrinsics.
365+
///
366+
/// Example:
367+
///
368+
/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
369+
/// : vector<[4]xf32>, vector<[4]xf32>
370+
///
371+
/// is converted to:
372+
///
373+
/// "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
374+
/// : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
375+
/// vector<[4]xf32>) -> ()
376+
///
377+
/// Currently only supports FMOPA and BFMOPA (non-widening).
378+
struct VectorOuterProductToArmSMELowering
379+
: public ConvertOpToLLVMPattern<vector::OuterProductOp> {
380+
using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
381+
382+
LogicalResult
383+
matchAndRewrite(vector::OuterProductOp outerProductOp,
384+
vector::OuterProductOp::Adaptor adaptor,
385+
ConversionPatternRewriter &rewriter) const override {
386+
auto isSupportedType = [](VectorType vectorType) {
387+
// TODO: the FP outer product instruction variants are predicated on
388+
// different features:
389+
//
390+
// * FMOPA (non-widening)
391+
// * half-precision - +sme2p1,+sme-f16f16
392+
// * single-precision - +sme
393+
// * double-precision - +sme-f64f64
394+
// * BFMOPA
395+
// * half-precision - +sme2p1,+b16b16
396+
//
397+
// It should be possible to control lowering based on target features.
398+
if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
399+
return false;
400+
401+
auto elementType = vectorType.getElementType();
402+
403+
if (!elementType.isF16() && !elementType.isBF16() &&
404+
!elementType.isF32() && !elementType.isF64())
405+
return false;
406+
407+
unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
408+
vectorType.getElementTypeBitWidth();
409+
if (vectorType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
410+
return false;
411+
412+
return true;
413+
};
414+
415+
auto resultVectorType = outerProductOp.getResultVectorType();
416+
if (!isSupportedType(resultVectorType))
417+
return outerProductOp.emitError("unsupported type");
418+
419+
vector::CombiningKind kind = outerProductOp.getKind();
420+
if (kind != vector::CombiningKind::ADD)
421+
// TODO: support subtract.
422+
return outerProductOp.emitError("unsupported kind");
423+
424+
auto maskableOp =
425+
cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
426+
if (maskableOp.isMasked())
427+
// TODO: support masking.
428+
return outerProductOp.emitError("masking is currently unsupported");
429+
430+
if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
431+
// AXPY operation not suited for SME.
432+
return failure();
433+
434+
auto loc = outerProductOp.getLoc();
435+
436+
Value acc = outerProductOp.getAcc();
437+
if (!acc)
438+
// Initalize accumulator with zero.
439+
acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
440+
441+
unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
442+
auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
443+
loc, rewriter.getIntegerType(elementWidth), acc);
444+
445+
// Create all active predicate mask.
446+
auto one = rewriter.create<arith::ConstantOp>(
447+
loc, rewriter.getI1Type(),
448+
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
449+
auto predTy =
450+
VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
451+
/*scalableDims=*/{true});
452+
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
453+
454+
auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
455+
456+
// Create 'arm_sme.intr.mopa' outer product intrinsic.
457+
rewriter.create<arm_sme::aarch64_sme_mopa>(
458+
loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
459+
outerProductOp.getRhs());
460+
461+
// Create `CastTileToVectorOp` to use as the output.
462+
rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
463+
outerProductOp, resultVectorType, tileId);
464+
465+
return success();
466+
}
467+
};
468+
364469
} // namespace
365470

366471
void mlir::configureArmSMELegalizeForExportTarget(
@@ -374,8 +479,10 @@ void mlir::configureArmSMELegalizeForExportTarget(
374479
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
375480
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
376481
arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
377-
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
482+
arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
483+
arm_sme::aarch64_sme_za_disable>();
378484
target.addLegalOp<GetTileID>();
485+
target.addIllegalOp<vector::OuterProductOp>();
379486

380487
// Mark 'func.func' ops as legal if either:
381488
// 1. no 'arm_za' function attribute is present.
@@ -405,7 +512,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
405512
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
406513
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
407514
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
408-
patterns.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
409-
LoadTileSliceToArmSMELowering,
410-
MoveVectorToTileSliceToArmSMELowering>(converter);
515+
patterns
516+
.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
517+
LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
518+
VectorOuterProductToArmSMELowering>(converter);
411519
}

mlir/lib/Dialect/ArmSME/Utils/Utils.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
using namespace mlir;
1818
using namespace mlir::arm_sme;
1919

20-
static constexpr unsigned MinStreamingVectorLengthInBits = 128;
21-
2220
unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
2321
assert(isValidSMETileElementType(type) && "invalid tile type!");
2422
return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1121,11 +1121,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
11211121

11221122
LogicalResult matchAndRewrite(vector::OuterProductOp op,
11231123
PatternRewriter &rewriter) const override {
1124+
VectorType resType = op.getResultVectorType();
1125+
if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1126+
return failure();
1127+
11241128
auto loc = op.getLoc();
11251129

11261130
VectorType lhsType = op.getOperandVectorTypeLHS();
11271131
VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1128-
VectorType resType = op.getResultVectorType();
11291132
Type eltType = resType.getElementType();
11301133
bool isInt = isa<IntegerType, IndexType>(eltType);
11311134
Value acc = op.getAcc();

mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
2+
3+
//===----------------------------------------------------------------------===//
4+
// vector.transfer_write
5+
//===----------------------------------------------------------------------===//
26

37
// CHECK-LABEL: @transfer_write_2d_zero_i8(
48
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
@@ -33,6 +37,10 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
3337
return
3438
}
3539

40+
//===----------------------------------------------------------------------===//
41+
// vector.load
42+
//===----------------------------------------------------------------------===//
43+
3644
// -----
3745

3846
// Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
@@ -232,6 +240,10 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
232240
return %tile : vector<[1]x[1]xi128>
233241
}
234242

243+
//===----------------------------------------------------------------------===//
244+
// vector.store
245+
//===----------------------------------------------------------------------===//
246+
235247
// -----
236248

237249
// CHECK-LABEL: @vector_store_i8(
@@ -391,3 +403,96 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
391403
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
392404
return
393405
}
406+
407+
//===----------------------------------------------------------------------===//
408+
// vector.outerproduct
409+
//===----------------------------------------------------------------------===//
410+
411+
// -----
412+
413+
// CHECK-LABEL: @vector_outerproduct_add_f16
414+
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
415+
func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
416+
// CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
417+
// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
418+
// CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
419+
// CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
420+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
421+
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
422+
}
423+
424+
// -----
425+
426+
// CHECK-LABEL: @vector_outerproduct_add_bf16
427+
func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
428+
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
429+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
430+
"prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
431+
}
432+
433+
// -----
434+
435+
// CHECK-LABEL: @vector_outerproduct_add_f32
436+
func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
437+
// CHECK-NOT: arith.extui
438+
// CHECK-NOT: arith.trunci
439+
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
440+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
441+
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
442+
}
443+
444+
// -----
445+
446+
// CHECK-LABEL: @vector_outerproduct_add_f64
447+
func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
448+
// CHECK: arith.trunci {{.*}} : i64 to i32
449+
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
450+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
451+
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
452+
}
453+
454+
// -----
455+
456+
// CHECK-LABEL: @vector_outerproduct_no_accumulator
457+
func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
458+
// CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
459+
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
460+
%0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
461+
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
462+
}
463+
464+
// -----
465+
466+
// CHECK-LABEL: @vector_outerproduct_scalar_rhs
467+
func.func @vector_outerproduct_scalar_rhs(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
468+
// CHECK-NOT: arm_sme
469+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
470+
return %0 : vector<[2]xf64>
471+
}
472+
473+
// -----
474+
475+
func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
476+
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
477+
// expected-error@+1 {{unsupported type}}
478+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
479+
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
480+
}
481+
482+
// -----
483+
484+
func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
485+
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
486+
// expected-error@+1 {{unsupported kind}}
487+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
488+
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
489+
}
490+
491+
// -----
492+
493+
func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
494+
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
495+
// expected-error@+1 {{masking is currently unsupported}}
496+
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
497+
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
498+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// DEFINE: %{entry_point} = test_outerproduct_4x4xf32
2+
// DEFINE: %{compile} = mlir-opt %s \
3+
// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
4+
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
5+
// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
6+
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
7+
// DEFINE: %{run} = %mcr_aarch64_cmd \
8+
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
9+
// DEFINE: -e %{entry_point} -entry-point-result=void \
10+
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
11+
12+
// RUN: %{compile} | %{run} | FileCheck %s
13+
14+
// REDEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
15+
// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-NO-ACC
16+
17+
func.func @test_outerproduct_4x4xf32() {
18+
%c0 = arith.constant 0 : index
19+
%f1 = arith.constant 1.0 : f32
20+
%f2 = arith.constant 2.0 : f32
21+
%f10 = arith.constant 10.0 : f32
22+
23+
%a = vector.splat %f1 : vector<[4]xf32>
24+
%b = vector.splat %f2 : vector<[4]xf32>
25+
// TODO: vector.splat doesn't support ArmSME.
26+
%c = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
27+
28+
%tile = vector.outerproduct %a, %b, %c : vector<[4]xf32>, vector<[4]xf32>
29+
30+
// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
31+
%vscale = vector.vscale
32+
%min_elts_s = arith.constant 4 : index
33+
%svl_s = arith.muli %min_elts_s, %vscale : index
34+
%za_s_size = arith.muli %svl_s, %svl_s : index
35+
36+
// Allocate memory.
37+
%mem = memref.alloca(%za_s_size) : memref<?xf32>
38+
39+
// Store the tile to memory.
40+
vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
41+
42+
// Reload and print. The smallest SVL is 128-bits so the tile will be at
43+
// least 4x4xf32.
44+
//
45+
// CHECK: ( 12, 12, 12, 12
46+
// CHECK-NEXT: ( 12, 12, 12, 12
47+
// CHECK-NEXT: ( 12, 12, 12, 12
48+
// CHECK-NEXT: ( 12, 12, 12, 12
49+
scf.for %i = %c0 to %za_s_size step %svl_s {
50+
%tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
51+
vector.print %tileslice : vector<[4]xf32>
52+
}
53+
54+
return
55+
}
56+
57+
func.func @test_outerproduct_no_accumulator_4x4xf32() {
58+
%c0 = arith.constant 0 : index
59+
%f1 = arith.constant 1.0 : f32
60+
%f2 = arith.constant 2.0 : f32
61+
%f10 = arith.constant 10.0 : f32
62+
63+
%a = vector.splat %f1 : vector<[4]xf32>
64+
%b = vector.splat %f2 : vector<[4]xf32>
65+
66+
%tile = vector.outerproduct %a, %b : vector<[4]xf32>, vector<[4]xf32>
67+
68+
// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
69+
%vscale = vector.vscale
70+
%min_elts_s = arith.constant 4 : index
71+
%svl_s = arith.muli %min_elts_s, %vscale : index
72+
%za_s_size = arith.muli %svl_s, %svl_s : index
73+
74+
// Allocate memory.
75+
%mem = memref.alloca(%za_s_size) : memref<?xf32>
76+
77+
// Store the tile to memory.
78+
vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
79+
80+
// Reload and print. The smallest SVL is 128-bits so the tile will be at
81+
// least 4x4xf32.
82+
//
83+
// CHECK-NO-ACC: ( 2, 2, 2, 2
84+
// CHECK-NO-ACC-NEXT: ( 2, 2, 2, 2
85+
// CHECK-NO-ACC-NEXT: ( 2, 2, 2, 2
86+
// CHECK-NO-ACC-NEXT: ( 2, 2, 2, 2
87+
scf.for %i = %c0 to %za_s_size step %svl_s {
88+
%tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
89+
vector.print %tileslice : vector<[4]xf32>
90+
}
91+
92+
return
93+
}

0 commit comments

Comments
 (0)