Skip to content

Commit 5f22e1a

Browse files
[mlir][Interfaces] LoopLikeOpInterface: Support ops with multiple regions
This commit implements `LoopLikeOpInterface` on `scf.while`. This enables LICM (and potentially other transforms) on `scf.while`.

`LoopLikeOpInterface::getLoopBody()` is renamed to `getLoopRegions()` and can now return multiple regions (as a `SmallVector<Region *>`).

Also fix a bug in the default implementation of `LoopLikeOpInterface::isDefinedOutsideOfLoop()`, which returned "false" for some values that are defined outside of the loop (in a nested op, in such a way that the value does not dominate the loop). This interface is currently only used for LICM and there is no way to trigger this bug, so no test is added.

BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
1 parent b05d436 commit 5f22e1a

File tree

23 files changed

+101
-71
lines changed

23 files changed

+101
-71
lines changed

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,7 +1947,9 @@ void fir::IterWhileOp::print(mlir::OpAsmPrinter &p) {
19471947
/*printBlockTerminators=*/true);
19481948
}
19491949

1950-
mlir::Region &fir::IterWhileOp::getLoopBody() { return getRegion(); }
1950+
llvm::SmallVector<mlir::Region *> fir::IterWhileOp::getLoopRegions() {
1951+
return {&getRegion()};
1952+
}
19511953

19521954
mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) {
19531955
for (auto i : llvm::enumerate(getInitArgs()))
@@ -2234,7 +2236,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
22342236
printBlockTerminators);
22352237
}
22362238

2237-
mlir::Region &fir::DoLoopOp::getLoopBody() { return getRegion(); }
2239+
llvm::SmallVector<mlir::Region *> fir::DoLoopOp::getLoopRegions() {
2240+
return {&getRegion()};
2241+
}
22382242

22392243
/// Translate a value passed as an iter_arg to the corresponding block
22402244
/// argument in the body of the loop.

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,7 @@ def ReduceReturnOp :
958958
def WhileOp : SCF_Op<"while",
959959
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
960960
["getEntrySuccessorOperands"]>,
961+
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
961962
RecursiveMemoryEffects, SingleBlock]> {
962963
let summary = "a generic 'while' loop";
963964
let description = [{

mlir/include/mlir/Interfaces/LoopLikeInterface.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
3636
/*args=*/(ins "::mlir::Value ":$value),
3737
/*methodBody=*/"",
3838
/*defaultImplementation=*/[{
39-
return value.getParentRegion()->isProperAncestor(&$_op.getLoopBody());
39+
return !$_op->isAncestor(value.getParentRegion()->getParentOp());
4040
}]
4141
>,
4242
InterfaceMethod<[{
43-
Returns the region that makes up the body of the loop and should be
43+
Returns the regions that make up the body of the loop and should be
4444
inspected for loop-invariant operations.
4545
}],
46-
/*retTy=*/"::mlir::Region &",
47-
/*methodName=*/"getLoopBody"
46+
/*retTy=*/"::llvm::SmallVector<::mlir::Region *>",
47+
/*methodName=*/"getLoopRegions"
4848
>,
4949
InterfaceMethod<[{
5050
Moves the given loop-invariant operation out of the loop.

mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111

1212
#include "mlir/Support/LLVM.h"
1313

14+
#include "llvm/ADT/SmallVector.h"
15+
1416
namespace mlir {
1517

1618
class LoopLikeOpInterface;
1719
class Operation;
1820
class Region;
19-
class RegionRange;
2021
class Value;
2122

2223
/// Given a list of regions, perform loop-invariant code motion. An operation is
@@ -61,7 +62,7 @@ class Value;
6162
///
6263
/// Returns the number of operations moved.
6364
size_t moveLoopInvariantCode(
64-
RegionRange regions,
65+
ArrayRef<Region *> regions,
6566
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
6667
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
6768
function_ref<void(Operation *, Region *)> moveOutOfRegion);

mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
163163
signatureConverter.remapInput(0, newIndVar);
164164
for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
165165
signatureConverter.remapInput(i, header->getArgument(i));
166-
body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
166+
body = rewriter.applySignatureConversion(&forOp.getRegion(),
167167
signatureConverter);
168168

169169
// Move the blocks from the forOp into the loopOp. This is the body of the

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,7 +1103,7 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
11031103
}
11041104

11051105
// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1106-
// updated and needs to be updated separatly for the loop to be correct.
1106+
// updated and needs to be updated separately for the loop to be correct.
11071107
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
11081108
scf::ForOp loop,
11091109
ValueRange newInitArgs) {
@@ -1119,9 +1119,8 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
11191119
operands);
11201120
newLoop.getBody()->erase();
11211121

1122-
newLoop.getLoopBody().getBlocks().splice(
1123-
newLoop.getLoopBody().getBlocks().begin(),
1124-
loop.getLoopBody().getBlocks());
1122+
newLoop.getRegion().getBlocks().splice(
1123+
newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
11251124
for (Value operand : newInitArgs)
11261125
newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
11271126

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,8 +2380,7 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
23802380
/// induction variable. AffineForOp only has one region, so zero is the only
23812381
/// valid value for `index`.
23822382
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2383-
assert((point.isParent() || point == getLoopBody()) &&
2384-
"invalid region point");
2383+
assert((point.isParent() || point == getRegion()) && "invalid region point");
23852384

23862385
// The initial operands map to the loop arguments after the induction
23872386
// variable or are forwarded to the results when the trip count is zero.
@@ -2395,16 +2394,15 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
23952394
/// not a constant.
23962395
void AffineForOp::getSuccessorRegions(
23972396
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2398-
assert((point.isParent() || point == getLoopBody()) &&
2399-
"expected loop region");
2397+
assert((point.isParent() || point == getRegion()) && "expected loop region");
24002398
// The loop may typically branch back to its body or to the parent operation.
24012399
// If the predecessor is the parent op and the trip count is known to be at
24022400
// least one, branch into the body using the iterator arguments. And in cases
24032401
// we know the trip count is zero, it can only branch back to its parent.
24042402
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
24052403
if (point.isParent() && tripCount.has_value()) {
24062404
if (tripCount.value() > 0) {
2407-
regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
2405+
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
24082406
return;
24092407
}
24102408
if (tripCount.value() == 0) {
@@ -2422,7 +2420,7 @@ void AffineForOp::getSuccessorRegions(
24222420

24232421
// In all other cases, the loop may branch back to itself or the parent
24242422
// operation.
2425-
regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
2423+
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
24262424
regions.push_back(RegionSuccessor(getResults()));
24272425
}
24282426

@@ -2561,7 +2559,7 @@ bool AffineForOp::matchingBoundOperandList() {
25612559
return true;
25622560
}
25632561

2564-
Region &AffineForOp::getLoopBody() { return getRegion(); }
2562+
SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
25652563

25662564
std::optional<Value> AffineForOp::getSingleInductionVar() {
25672565
return getInductionVar();
@@ -2758,9 +2756,9 @@ AffineForOp mlir::affine::replaceForOpWithNewYields(OpBuilder &b,
27582756
b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
27592757
loop.getStep(), operands);
27602758
// Take the body of the original parent loop.
2761-
newLoop.getLoopBody().takeBody(loop.getLoopBody());
2759+
newLoop.getRegion().takeBody(loop.getRegion());
27622760
for (Value val : newIterArgs)
2763-
newLoop.getLoopBody().addArgument(val.getType(), val.getLoc());
2761+
newLoop.getRegion().addArgument(val.getType(), val.getLoc());
27642762

27652763
// Update yield operation with new values to be added.
27662764
if (!newYieldedValues.empty()) {
@@ -3848,7 +3846,9 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
38483846
ensureTerminator(*bodyRegion, builder, result.location);
38493847
}
38503848

3851-
Region &AffineParallelOp::getLoopBody() { return getRegion(); }
3849+
SmallVector<Region *> AffineParallelOp::getLoopRegions() {
3850+
return {&getRegion()};
3851+
}
38523852

38533853
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
38543854

mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,11 @@ static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
8585
opsToHoist))
8686
return false;
8787
} else if (auto forOp = dyn_cast<AffineForOp>(op)) {
88-
if (!areAllOpsInTheBlockListInvariant(forOp.getLoopBody(), indVar, iterArgs,
88+
if (!areAllOpsInTheBlockListInvariant(forOp.getRegion(), indVar, iterArgs,
8989
opsWithUsers, opsToHoist))
9090
return false;
9191
} else if (auto parOp = dyn_cast<AffineParallelOp>(op)) {
92-
if (!areAllOpsInTheBlockListInvariant(parOp.getLoopBody(), indVar, iterArgs,
92+
if (!areAllOpsInTheBlockListInvariant(parOp.getRegion(), indVar, iterArgs,
9393
opsWithUsers, opsToHoist))
9494
return false;
9595
} else if (!isMemoryEffectFree(&op) &&

mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ static ParallelComputeFunction createParallelComputeFunction(
429429
mapping.map(op.getInductionVars(), computeBlockInductionVars);
430430
mapping.map(computeFuncType.captures, captures);
431431

432-
for (auto &bodyOp : op.getLoopBody().getOps())
432+
for (auto &bodyOp : op.getRegion().getOps())
433433
b.clone(bodyOp, mapping);
434434
};
435435
};
@@ -732,7 +732,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
732732

733733
// Make sure that all constants will be inside the parallel operation body to
734734
// reduce the number of parallel compute function arguments.
735-
cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
735+
cloneConstantsIntoTheRegion(op.getRegion(), rewriter);
736736

737737
// Compute trip count for each loop induction variable:
738738
// tripCount = ceil_div(upperBound - lowerBound, step);

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
219219
// Replace all uses of the `transferRead` with the corresponding
220220
// basic block argument.
221221
transferRead.getVector().replaceUsesWithIf(
222-
newForOp.getLoopBody().getArguments().back(),
223-
[&](OpOperand &use) {
222+
newForOp.getBody()->getArguments().back(), [&](OpOperand &use) {
224223
Operation *user = use.getOwner();
225224
return newForOp->isProperAncestor(user);
226225
});

mlir/lib/Dialect/Linalg/Transforms/Loops.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
199199
// Replace the index operations in the body of the innermost loop op.
200200
if (!loopOps.empty()) {
201201
auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
202-
for (IndexOp indexOp :
203-
llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
204-
rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
202+
for (Region *r : loopOp.getLoopRegions())
203+
for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
204+
rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
205205
}
206206
}
207207

mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ static Operation *isTensorChunkAccessedByUnknownOp(Operation *writeOp,
303303
// pass-through tensor arguments left from previous level of
304304
// hoisting.
305305
if (auto forUser = dyn_cast<scf::ForOp>(user)) {
306-
Value arg = forUser.getLoopBody().getArgument(
306+
Value arg = forUser.getBody()->getArgument(
307307
use.getOperandNumber() - forUser.getNumControlOperands() +
308308
/*iv value*/ 1);
309309
uses.push_back(arg.getUses());

mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,9 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
152152
std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
153153
std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
154154
std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
155-
if (!inductionVar || !lowerBound || !singleStep) {
156-
LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb or step\n");
155+
if (!inductionVar || !lowerBound || !singleStep ||
156+
!llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
157+
LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
157158
return failure();
158159
}
159160

@@ -184,7 +185,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
184185

185186
// 3. Within the loop, build the modular leading index (i.e. each loop
186187
// iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
187-
rewriter.setInsertionPointToStart(&candidateLoop.getLoopBody().front());
188+
rewriter.setInsertionPointToStart(
189+
&candidateLoop.getLoopRegions().front()->front());
188190
Value ivVal = *inductionVar;
189191
Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
190192
Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
530530
return success();
531531
}
532532

533-
Region &ForOp::getLoopBody() { return getRegion(); }
533+
SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
534534

535535
ForOp mlir::scf::getForInductionVarOwner(Value val) {
536536
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
@@ -558,11 +558,11 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
558558
// Both the operation itself and the region may be branching into the body or
559559
// back into the operation itself. It is possible for loop not to enter the
560560
// body.
561-
regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
561+
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
562562
regions.push_back(RegionSuccessor(getResults()));
563563
}
564564

565-
Region &ForallOp::getLoopBody() { return getRegion(); }
565+
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
566566

567567
/// Promotes the loop body of a forallOp to its containing block if it can be
568568
/// determined that the loop has a single iteration.
@@ -894,7 +894,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
894894
blockArgs.reserve(op.getInitArgs().size() + 1);
895895
blockArgs.push_back(op.getLowerBound());
896896
llvm::append_range(blockArgs, op.getInitArgs());
897-
replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
897+
replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
898898
return success();
899899
}
900900

@@ -2872,7 +2872,7 @@ void ParallelOp::print(OpAsmPrinter &p) {
28722872
/*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
28732873
}
28742874

2875-
Region &ParallelOp::getLoopBody() { return getRegion(); }
2875+
SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
28762876

28772877
ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
28782878
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
@@ -2926,7 +2926,7 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
29262926
// loop body and nested ReduceOp's
29272927
SmallVector<Value> results;
29282928
results.reserve(op.getInitVals().size());
2929-
for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
2929+
for (auto &bodyOp : op.getBody()->without_terminator()) {
29302930
auto reduce = dyn_cast<ReduceOp>(bodyOp);
29312931
if (!reduce) {
29322932
rewriter.clone(bodyOp, mapping);
@@ -2965,7 +2965,7 @@ struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
29652965

29662966
LogicalResult matchAndRewrite(ParallelOp op,
29672967
PatternRewriter &rewriter) const override {
2968-
Block &outerBody = op.getLoopBody().front();
2968+
Block &outerBody = *op.getBody();
29692969
if (!llvm::hasSingleElement(outerBody.without_terminator()))
29702970
return failure();
29712971

@@ -2985,7 +2985,7 @@ struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
29852985

29862986
auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
29872987
ValueRange iterVals, ValueRange) {
2988-
Block &innerBody = innerOp.getLoopBody().front();
2988+
Block &innerBody = *innerOp.getBody();
29892989
assert(iterVals.size() ==
29902990
(outerBody.getNumArguments() + innerBody.getNumArguments()));
29912991
IRMapping mapping;
@@ -3203,6 +3203,10 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,
32033203
regions.emplace_back(&getAfter(), getAfter().getArguments());
32043204
}
32053205

3206+
SmallVector<Region *> WhileOp::getLoopRegions() {
3207+
return {&getBefore(), &getAfter()};
3208+
}
3209+
32063210
/// Parses a `while` op.
32073211
///
32083212
/// op ::= `scf.while` assignments `:` function-type region `do` region

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,8 @@ struct ForOpInterface
3535

3636
// An EQ constraint can be added if the yielded value (dimension size)
3737
// equals the corresponding block argument (dimension size).
38-
assert(forOp.getLoopBody().hasOneBlock() &&
39-
"multiple blocks not supported");
40-
Value yieldedValue =
41-
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator())
42-
.getOperand(iterArgIdx);
38+
Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
39+
.getOperand(iterArgIdx);
4340
Value iterArg = forOp.getRegionIterArg(iterArgIdx);
4441
Value initArg = forOp.getInitArgs()[iterArgIdx];
4542

@@ -68,7 +65,7 @@ struct ForOpInterface
6865
// Stop when reaching a value that is defined outside of the loop. It
6966
// is impossible to reach an iter_arg from there.
7067
Operation *op = v.getDefiningOp();
71-
return forOp.getLoopBody().findAncestorOpInRegion(*op) == nullptr;
68+
return forOp.getRegion().findAncestorOpInRegion(*op) == nullptr;
7269
});
7370
if (failed(status))
7471
return;

0 commit comments

Comments
 (0)