diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h index 46bfe717533a8..d6d3aeeb9bd05 100644 --- a/mlir/include/mlir/IR/RegionKindInterface.h +++ b/mlir/include/mlir/IR/RegionKindInterface.h @@ -43,6 +43,12 @@ class HasOnlyGraphRegion : public TraitBase { /// not implement the RegionKindInterface. bool mayHaveSSADominance(Region ®ion); +/// Return "true" if the given region may be a graph region without SSA +/// dominance. This function returns "true" in case the owner op is an +/// unregistered op. It returns "false" if it is a registered op that does not +/// implement the RegionKindInterface. +bool mayBeGraphRegion(Region ®ion); + } // namespace mlir #include "mlir/IR/RegionKindInterface.h.inc" diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index cad78b3e65b23..c34f422292cb4 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -394,12 +394,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { protected: void notifyOperationRemoved(Operation *op) override { - // TODO: Walk can be removed when D144193 has landed. - op->walk([&](Operation *op) { - erasedOps.insert(op); - // Erase if present. - toMemrefOps.erase(op); - }); + erasedOps.insert(op); + // Erase if present. + toMemrefOps.erase(op); } void notifyOperationInserted(Operation *op) override { diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index db920c14ea08d..5e9b9b2a810a4 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -8,6 +8,8 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Iterators.h" +#include "mlir/IR/RegionKindInterface.h" using namespace mlir; @@ -275,7 +277,7 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { for (auto it : llvm::zip(op->getResults(), newValues)) replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); - // Erase the op. + // Erase op and notify listener. eraseOp(op); } @@ -295,7 +297,7 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) { for (auto it : llvm::zip(op->getResults(), newOp->getResults())) replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); - // Erase the old op. + // Erase op and notify listener. eraseOp(op); } @@ -303,9 +305,71 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) { /// the given operation *must* be known to be dead. void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); - if (auto *rewriteListener = dyn_cast_if_present(listener)) + auto *rewriteListener = dyn_cast_if_present(listener); + + // Fast path: If no listener is attached, the op can be dropped in one go. + if (!rewriteListener) { + op->erase(); + return; + } + + // Helper function that erases a single op. + auto eraseSingleOp = [&](Operation *op) { +#ifndef NDEBUG + // All nested ops should have been erased already. + assert( + llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) && + "expected empty regions"); + // All users should have been erased already if the op is in a region with + // SSA dominance. + if (!op->use_empty() && op->getParentOp()) + assert(mayBeGraphRegion(*op->getParentRegion()) && + "expected that op has no uses"); +#endif // NDEBUG rewriteListener->notifyOperationRemoved(op); - op->erase(); + + // Explicitly drop all uses in case the op is in a graph region. + op->dropAllUses(); + op->erase(); + }; + + // Nested ops must be erased one-by-one, so that listeners have a consistent + // view of the IR every time a notification is triggered. Users must be + // erased before definitions. I.e., post-order, reverse dominance. + std::function eraseTree = [&](Operation *op) { + // Erase nested ops. + for (Region &r : llvm::reverse(op->getRegions())) { + // Erase all blocks in the right order. Successors should be erased + // before predecessors because successor blocks may use values defined + // in predecessor blocks. A post-order traversal of blocks within a + // region visits successors before predecessors. Repeat the traversal + // until the region is empty. (The block graph could be disconnected.) + while (!r.empty()) { + SmallVector erasedBlocks; + for (Block *b : llvm::post_order(&r.front())) { + // Visit ops in reverse order. + for (Operation &op : + llvm::make_early_inc_range(ReverseIterator::makeIterable(*b))) + eraseTree(&op); + // Do not erase the block immediately. This is not supprted by the + // post_order iterator. + erasedBlocks.push_back(b); + } + for (Block *b : erasedBlocks) { + // Explicitly drop all uses in case there is a cycle in the block + // graph. + for (BlockArgument bbArg : b->getArguments()) + bbArg.dropAllUses(); + b->dropAllUses(); + b->erase(); + } + } + } + // Then erase the enclosing op. + eraseSingleOp(op); + }; + + eraseTree(op); } void RewriterBase::eraseBlock(Block *block) { diff --git a/mlir/lib/IR/RegionKindInterface.cpp b/mlir/lib/IR/RegionKindInterface.cpp index cbef3025a5dd6..007f4cf92dbc7 100644 --- a/mlir/lib/IR/RegionKindInterface.cpp +++ b/mlir/lib/IR/RegionKindInterface.cpp @@ -18,9 +18,17 @@ using namespace mlir; #include "mlir/IR/RegionKindInterface.cpp.inc" bool mlir::mayHaveSSADominance(Region ®ion) { - auto regionKindOp = - dyn_cast_if_present(region.getParentOp()); + auto regionKindOp = dyn_cast(region.getParentOp()); if (!regionKindOp) return true; return regionKindOp.hasSSADominance(region.getRegionNumber()); } + +bool mlir::mayBeGraphRegion(Region ®ion) { + if (!region.getParentOp()->isRegistered()) + return true; + auto regionKindOp = dyn_cast(region.getParentOp()); + if (!regionKindOp) + return false; + return !regionKindOp.hasSSADominance(region.getRegionNumber()); +} diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index fba4944f130c2..8e2bfe557c555 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -421,8 +421,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { // If the operation is trivially dead - remove it. if (isOpTriviallyDead(op)) { - notifyOperationRemoved(op); - op->erase(); + eraseOp(op); changed = true; LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); @@ -567,10 +566,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) { config.listener->notifyOperationRemoved(op); addOperandsToWorklist(op->getOperands()); - op->walk([this](Operation *operation) { - worklist.remove(operation); - folder.notifyRemoval(operation); - }); + worklist.remove(op); + folder.notifyRemoval(op); if (config.strictMode != GreedyRewriteStrictness::AnyOp) strictModeFilteredOps.erase(op); diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir index 5df2d6d1fdeeb..a5ab8f97c74ce 100644 --- a/mlir/test/Transforms/test-strict-pattern-driver.mlir +++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir @@ -12,9 +12,9 @@ // CHECK-EN-LABEL: func @test_erase // CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true} -// CHECK-EN: test.arg0 -// CHECK-EN: test.arg1 -// CHECK-EN-NOT: test.erase_op +// CHECK-EN: "test.arg0" +// CHECK-EN: "test.arg1" +// CHECK-EN-NOT: "test.erase_op" func.func @test_erase() { %0 = "test.arg0"() : () -> (i32) %1 = "test.arg1"() : () -> (i32) @@ -51,13 +51,13 @@ func.func @test_replace_with_new_op() { // CHECK-EN-LABEL: func @test_replace_with_erase_op // CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true} -// CHECK-EN-NOT: test.replace_with_new_op -// CHECK-EN-NOT: test.erase_op +// CHECK-EN-NOT: "test.replace_with_new_op" +// CHECK-EN-NOT: "test.erase_op" // CHECK-EX-LABEL: func @test_replace_with_erase_op // CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true} -// CHECK-EX-NOT: test.replace_with_new_op -// CHECK-EX: test.erase_op +// CHECK-EX-NOT: "test.replace_with_new_op" +// CHECK-EX: "test.erase_op" func.func @test_replace_with_erase_op() { "test.replace_with_new_op"() {create_erase_op} : () -> () return @@ -83,3 +83,149 @@ func.func @test_trigger_rewrite_through_block() { // in turn, replaces the successor with bb3. "test.implicit_change_op"() [^bb1] : () -> () } + +// ----- + +// CHECK-AN: notifyOperationRemoved: test.foo_b +// CHECK-AN: notifyOperationRemoved: test.foo_a +// CHECK-AN: notifyOperationRemoved: test.graph_region +// CHECK-AN: notifyOperationRemoved: test.erase_op +// CHECK-AN-LABEL: func @test_remove_graph_region() +// CHECK-AN-NEXT: return +func.func @test_remove_graph_region() { + "test.erase_op"() ({ + test.graph_region { + %0 = "test.foo_a"(%1) : (i1) -> (i1) + %1 = "test.foo_b"(%0) : (i1) -> (i1) + } + }) : () -> () + return +} + +// ----- + +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.bar +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.foo +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.dummy_op +// CHECK-AN: notifyOperationRemoved: test.erase_op +// CHECK-AN-LABEL: func @test_remove_cyclic_blocks() +// CHECK-AN-NEXT: return +func.func @test_remove_cyclic_blocks() { + "test.erase_op"() ({ + %x = "test.dummy_op"() : () -> (i1) + cf.br ^bb1(%x: i1) + ^bb1(%arg0: i1): + "test.foo"(%x) : (i1) -> () + cf.br ^bb2(%arg0: i1) + ^bb2(%arg1: i1): + "test.bar"(%x) : (i1) -> () + cf.br ^bb1(%arg1: i1) + }) : () -> () + return +} + +// ----- + +// CHECK-AN: notifyOperationRemoved: test.dummy_op +// CHECK-AN: notifyOperationRemoved: test.bar +// CHECK-AN: notifyOperationRemoved: test.qux +// CHECK-AN: notifyOperationRemoved: test.qux_unreachable +// CHECK-AN: notifyOperationRemoved: test.nested_dummy +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.foo +// CHECK-AN: notifyOperationRemoved: test.erase_op +// CHECK-AN-LABEL: func @test_remove_dead_blocks() +// CHECK-AN-NEXT: return +func.func @test_remove_dead_blocks() { + "test.erase_op"() ({ + "test.dummy_op"() : () -> (i1) + // The following blocks are not reachable. Still, ^bb2 should be deleted + // befire ^bb1. + ^bb1(%arg0: i1): + "test.foo"() : () -> () + cf.br ^bb2(%arg0: i1) + ^bb2(%arg1: i1): + "test.nested_dummy"() ({ + "test.qux"() : () -> () + // The following block is unreachable. + ^bb3: + "test.qux_unreachable"() : () -> () + }) : () -> () + "test.bar"() : () -> () + }) : () -> () + return +} + +// ----- + +// test.nested_* must be deleted before test.foo. +// test.bar must be deleted before test.foo. + +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.bar +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.nested_b +// CHECK-AN: notifyOperationRemoved: test.nested_a +// CHECK-AN: notifyOperationRemoved: test.nested_d +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.nested_e +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.nested_c +// CHECK-AN: notifyOperationRemoved: test.foo +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.dummy_op +// CHECK-AN: notifyOperationRemoved: test.erase_op +// CHECK-AN-LABEL: func @test_remove_nested_ops() +// CHECK-AN-NEXT: return +func.func @test_remove_nested_ops() { + "test.erase_op"() ({ + %x = "test.dummy_op"() : () -> (i1) + cf.br ^bb1(%x: i1) + ^bb1(%arg0: i1): + "test.foo"() ({ + "test.nested_a"() : () -> () + "test.nested_b"() : () -> () + ^dead1: + "test.nested_c"() : () -> () + cf.br ^dead3 + ^dead2: + "test.nested_d"() : () -> () + ^dead3: + "test.nested_e"() : () -> () + cf.br ^dead2 + }) : () -> () + cf.br ^bb2(%arg0: i1) + ^bb2(%arg1: i1): + "test.bar"(%x) : (i1) -> () + cf.br ^bb1(%arg1: i1) + }) : () -> () + return +} + +// ----- + +// CHECK-AN: notifyOperationRemoved: test.qux +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.foo +// CHECK-AN: notifyOperationRemoved: cf.br +// CHECK-AN: notifyOperationRemoved: test.bar +// CHECK-AN: notifyOperationRemoved: cf.cond_br +// CHECK-AN-LABEL: func @test_remove_diamond( +// CHECK-AN-NEXT: return +func.func @test_remove_diamond(%c: i1) { + "test.erase_op"() ({ + cf.cond_br %c, ^bb1, ^bb2 + ^bb1: + "test.foo"() : () -> () + cf.br ^bb3 + ^bb2: + "test.bar"() : () -> () + cf.br ^bb3 + ^bb3: + "test.qux"() : () -> () + }) : () -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index e23ed105e3833..2e3bc76009ca2 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -239,6 +239,12 @@ struct TestPatternDriver llvm::cl::init(GreedyRewriteConfig().maxIterations)}; }; +struct DumpNotifications : public RewriterBase::Listener { + void notifyOperationRemoved(Operation *op) override { + llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n"; + } +}; + struct TestStrictPatternDriver : public PassWrapper> { public: @@ -275,7 +281,9 @@ struct TestStrictPatternDriver } }); + DumpNotifications dumpNotifications; GreedyRewriteConfig config; + config.listener = &dumpNotifications; if (strictMode == "AnyOp") { config.strictMode = GreedyRewriteStrictness::AnyOp; } else if (strictMode == "ExistingAndNewOps") {