Commit e997779

maerhart authored and ZijunZhaoCCK committed
[mlir][bufferization] Add an ownership based buffer deallocation pass (llvm#66337)
Add a new Buffer Deallocation pass with the intent to replace the old one. For now, it is added as a separate pass alongside the old one so that downstream users can migrate gradually. The new pass aims to insert fewer clone operations and to support additional use cases. Please refer to the Buffer Deallocation section in the updated Bufferization.md file for more information on how this new pass works.
1 parent 58764b2 commit e997779

16 files changed: +4019 -0 lines changed

mlir/docs/Bufferization.md

Lines changed: 604 additions & 0 deletions
Large diffs are not rendered by default.

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h

Lines changed: 8 additions & 0 deletions
@@ -121,6 +121,14 @@ class BufferPlacementTransformationBase {
   Liveness liveness;
 };

+/// Compare two SSA values in a deterministic manner. Two block arguments are
+/// ordered by argument number, block arguments are always less than operation
+/// results, and operation results are ordered by the `isBeforeInBlock` order of
+/// their defining operation.
+struct ValueComparator {
+  bool operator()(const Value &lhs, const Value &rhs) const;
+};
+
 // Create a global op for the given tensor-valued constant in the program.
 // Globals are created lazily at the top of the enclosing ModuleOp with pretty
 // names. Duplicates are avoided.

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 9 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include "mlir/Pass/Pass.h"

 namespace mlir {
+class FunctionOpInterface;
 class ModuleOp;
 class RewritePatternSet;
 class OpBuilder;
@@ -27,6 +28,10 @@ struct OneShotBufferizationOptions;
 /// buffers.
 std::unique_ptr<Pass> createBufferDeallocationPass();

+/// Creates an instance of the OwnershipBasedBufferDeallocation pass to free all
+/// allocated buffers.
+std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass();
+
 /// Creates a pass that optimizes `bufferization.dealloc` operations. For
 /// example, it reduces the number of alias checks needed at runtime using
 /// static alias analysis.
@@ -127,6 +132,10 @@ func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc,
 /// Run buffer deallocation.
 LogicalResult deallocateBuffers(Operation *op);

+/// Run ownership based buffer deallocation.
+LogicalResult deallocateBuffersOwnershipBased(FunctionOpInterface op,
+                                              bool privateFuncDynamicOwnership);
+
 /// Creates a pass that moves allocations upwards to reduce the number of
 /// required copies that are inserted during the BufferDeallocation pass.
 std::unique_ptr<Pass> createBufferHoistingPass();

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 144 additions & 0 deletions
@@ -88,6 +88,150 @@ def BufferDeallocation : Pass<"buffer-deallocation", "func::FuncOp"> {
   let constructor = "mlir::bufferization::createBufferDeallocationPass()";
 }

+def OwnershipBasedBufferDeallocation : Pass<
+    "ownership-based-buffer-deallocation", "func::FuncOp"> {
+  let summary = "Adds all required dealloc operations for all allocations in "
+                "the input program";
+  let description = [{
+    This pass implements an algorithm to automatically introduce all required
+    deallocation operations for all buffers in the input program. This ensures
+    that the resulting program does not have any memory leaks.
+
+    The Buffer Deallocation pass operates on the level of operations
+    implementing the FunctionOpInterface. Such operations can take MemRefs as
+    arguments, but also return them. To ensure compatibility among all functions
+    (including external ones), some rules have to be enforced. They are just
+    assumed to hold for all external functions. Functions for which the
+    definition is available ideally also already adhere to the ABI.
+    Otherwise, all MemRef write operations in the input IR must dominate all
+    MemRef read operations in the input IR. Then, the pass may modify the input
+    IR by inserting `bufferization.clone` operations such that the output IR
+    adheres to the function boundary ABI:
+    * When a MemRef is passed as a function argument, ownership is never
+      acquired. It is always the caller's responsibility to deallocate such
+      MemRefs.
+    * Returning a MemRef from a function always passes ownership to the caller,
+      i.e., it is also the caller's responsibility to deallocate MemRefs
+      returned from a called function.
+    * A function must not return a MemRef with the same allocated base buffer as
+      one of its arguments (in this case a copy has to be created). Note that in
+      this context two subviews of the same buffer that don't overlap are also
+      considered an alias.
+
+    It is recommended to bufferize all operations first such that no tensor
+    values remain in the IR once this pass is applied. That way all allocated
+    MemRefs will be properly deallocated without any additional manual work.
+    Otherwise, the pass that bufferizes the remaining tensors is responsible to
+    add the corresponding deallocation operations. Note that this pass does not
+    consider any values of tensor type and assumes that MemRef values defined by
+    `bufferization.to_memref` do not return ownership and do not have to be
+    deallocated. `bufferization.to_tensor` operations are handled similarly to
+    `bufferization.clone` operations with the exception that the result value is
+    not handled because it's a tensor (not a MemRef).
+
+    Input
+
+    ```mlir
+    #map0 = affine_map<(d0) -> (d0)>
+    module {
+      func.func @condBranch(%arg0: i1,
+                            %arg1: memref<2xf32>,
+                            %arg2: memref<2xf32>) {
+        cf.cond_br %arg0, ^bb1, ^bb2
+      ^bb1:
+        cf.br ^bb3(%arg1 : memref<2xf32>)
+      ^bb2:
+        %0 = memref.alloc() : memref<2xf32>
+        linalg.generic {
+          args_in = 1 : i64,
+          args_out = 1 : i64,
+          indexing_maps = [#map0, #map0],
+          iterator_types = ["parallel"]}
+        outs(%arg1, %0 : memref<2xf32>, memref<2xf32>) {
+        ^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
+          %tmp1 = exp %gen1_arg0 : f32
+          linalg.yield %tmp1 : f32
+        }
+        cf.br ^bb3(%0 : memref<2xf32>)
+      ^bb3(%1: memref<2xf32>):
+        "memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+        return
+      }
+    }
+    ```
+
+    Output
+
+    ```mlir
+    #map = affine_map<(d0) -> (d0)>
+    module {
+      func.func @condBranch(%arg0: i1,
+                            %arg1: memref<2xf32>,
+                            %arg2: memref<2xf32>) {
+        %false = arith.constant false
+        %true = arith.constant true
+        cf.cond_br %arg0, ^bb1, ^bb2
+      ^bb1:  // pred: ^bb0
+        cf.br ^bb3(%arg1, %false : memref<2xf32>, i1)
+      ^bb2:  // pred: ^bb0
+        %alloc = memref.alloc() : memref<2xf32>
+        linalg.generic {
+          indexing_maps = [#map, #map],
+          iterator_types = ["parallel"]}
+        outs(%arg1, %alloc : memref<2xf32>, memref<2xf32>)
+        attrs = {args_in = 1 : i64, args_out = 1 : i64} {
+        ^bb0(%out: f32, %out_0: f32):
+          %2 = math.exp %out : f32
+          linalg.yield %2, %out_0 : f32, f32
+        }
+        cf.br ^bb3(%alloc, %true : memref<2xf32>, i1)
+      ^bb3(%0: memref<2xf32>, %1: i1):  // 2 preds: ^bb1, ^bb2
+        memref.copy %0, %arg2 : memref<2xf32> to memref<2xf32>
+        %base_buffer, %offset, %sizes, %strides =
+          memref.extract_strided_metadata %0 :
+          memref<2xf32> -> memref<f32>, index, index, index
+        bufferization.dealloc (%base_buffer : memref<f32>) if (%1)
+        return
+      }
+    }
+    ```
+
+    The `private-function-dynamic-ownership` pass option allows the pass to add
+    additional arguments to private functions to dynamically give ownership of
+    MemRefs to callees. This can enable earlier deallocations and allows the
+    pass to bypass the function boundary ABI, thus potentially leading to
+    fewer MemRef clones being inserted. For example, the private function
+    ```mlir
+    func.func private @passthrough(%memref: memref<2xi32>) -> memref<2xi32> {
+      return %memref : memref<2xi32>
+    }
+    ```
+    would be converted to
+    ```mlir
+    func.func private @passthrough(%memref: memref<2xi32>,
+                                   %ownership: i1) -> (memref<2xi32>, i1) {
+      return %memref, %ownership : memref<2xi32>, i1
+    }
+    ```
+    and thus allows the returned MemRef to alias with the MemRef passed as
+    argument (which would otherwise be forbidden according to the function
+    boundary ABI).
+  }];
+  let options = [
+    Option<"privateFuncDynamicOwnership", "private-function-dynamic-ownership",
+           "bool", /*default=*/"false",
+           "Allows to add additional arguments to private functions to "
+           "dynamically pass ownership of memrefs to callees. This can enable "
+           "earlier deallocations.">,
+  ];
+  let constructor = "mlir::bufferization::createOwnershipBasedBufferDeallocationPass()";
+
+  let dependentDialects = [
+    "mlir::bufferization::BufferizationDialect", "mlir::arith::ArithDialect",
+    "mlir::memref::MemRefDialect", "mlir::scf::SCFDialect"
+  ];
+}
+
 def BufferDeallocationSimplification :
     Pass<"buffer-deallocation-simplification", "func::FuncOp"> {
   let summary = "Optimizes `bufferization.dealloc` operation for more "
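
The `i1` ownership indicator that the Output example in the pass description threads into `^bb3` can be modeled outside of MLIR. The following is a minimal standalone C++ sketch and is not code from this commit: `ToyBuffer`, `Outcome`, `deallocIf`, and the C++ function `condBranch` are hypothetical names chosen for illustration, and the element-wise `exp` only loosely stands in for the `linalg.generic` body. It shows that the buffer reaching the join point is freed only when the ownership flag is true, so the caller-owned argument is never deallocated.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical stand-in for a MemRef: `live` flips to false once deallocated.
struct ToyBuffer {
  std::vector<float> data;
  bool live = true;
};

// Reports what happened to the buffer allocated inside the function.
struct Outcome {
  bool allocated; // true if the "^bb2" path allocated a local buffer
  bool freed;     // true if that local buffer was deallocated before return
};

// Models `bufferization.dealloc (%buffer) if (%owned)`: free only when owned.
inline void deallocIf(ToyBuffer &buffer, bool owned) {
  if (!owned)
    return;
  buffer.data.clear();
  buffer.live = false;
}

// Models the control flow of the rewritten @condBranch (illustrative only).
inline Outcome condBranch(bool cond, ToyBuffer &arg1, ToyBuffer &arg2) {
  ToyBuffer local;
  local.live = false;          // nothing allocated yet
  ToyBuffer *selected = &arg1; // ^bb1: forward the argument ...
  bool owned = false;          // ... without taking ownership (%false)
  if (!cond) {
    // ^bb2: allocate a fresh buffer and fill it; this function owns it (%true).
    local.live = true;
    local.data.reserve(arg1.data.size());
    for (float v : arg1.data)
      local.data.push_back(std::exp(v));
    selected = &local;
    owned = true;
  }
  // ^bb3: copy into the output argument, then deallocate conditionally.
  arg2.data = selected->data;
  deallocIf(*selected, owned);
  return Outcome{!cond, !cond && !local.live};
}
```

On the `cond == true` path the argument buffer stays live because ownership was never acquired; on the other path the locally allocated buffer is released before returning, which matches the first rule of the function boundary ABI described in the pass documentation.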

mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp

Lines changed: 59 additions & 0 deletions
@@ -202,3 +202,62 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
   global->moveBefore(&moduleOp.front());
   return global;
 }
+
+//===----------------------------------------------------------------------===//
+// ValueComparator
+//===----------------------------------------------------------------------===//
+
+bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
+  if (lhs == rhs)
+    return false;
+
+  // Block arguments are less than results.
+  bool lhsIsBBArg = lhs.isa<BlockArgument>();
+  if (lhsIsBBArg != rhs.isa<BlockArgument>()) {
+    return lhsIsBBArg;
+  }
+
+  Region *lhsRegion;
+  Region *rhsRegion;
+  if (lhsIsBBArg) {
+    auto lhsBBArg = llvm::cast<BlockArgument>(lhs);
+    auto rhsBBArg = llvm::cast<BlockArgument>(rhs);
+    if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
+      return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
+    }
+    lhsRegion = lhsBBArg.getParentRegion();
+    rhsRegion = rhsBBArg.getParentRegion();
+    assert(lhsRegion != rhsRegion &&
+           "lhsRegion == rhsRegion implies lhs == rhs");
+  } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
+    return llvm::cast<OpResult>(lhs).getResultNumber() <
+           llvm::cast<OpResult>(rhs).getResultNumber();
+  } else {
+    lhsRegion = lhs.getDefiningOp()->getParentRegion();
+    rhsRegion = rhs.getDefiningOp()->getParentRegion();
+    if (lhsRegion == rhsRegion) {
+      return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
+    }
+  }
+
+  // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
+  // - have different heights
+  // - or there's a spot where their region numbers differ
+  // - or their parent regions are the same and their parent ops are
+  //   different.
+  while (lhsRegion && rhsRegion) {
+    if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
+      return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
+    }
+    if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
+      return lhsRegion->getParentOp()->isBeforeInBlock(
+          rhsRegion->getParentOp());
+    }
+    lhsRegion = lhsRegion->getParentRegion();
+    rhsRegion = rhsRegion->getParentRegion();
+  }
+  if (rhsRegion)
+    return true;
+  assert(lhsRegion && "this should only happen if lhs == rhs");
+  return false;
+}
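
The ordering rules implemented by `ValueComparator::operator()` can be illustrated without MLIR. The following is a simplified standalone C++ sketch, not code from this commit: `ToyValue`, `ToyValueComparator`, and `sortedCopy` are hypothetical names, and the model assumes that all values live in a single block, so the region ancestor walk of the real comparator is deliberately left out. It covers only the same-block cases: block arguments sort before operation results, block arguments are ordered by argument number, and results are ordered by the position of their defining operation and then by result number.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for mlir::Value, restricted to a single block.
struct ToyValue {
  bool isBlockArg;       // true for a block argument, false for an op result
  unsigned argNumber;    // index in the block argument list (block args only)
  unsigned opPosition;   // position of the defining op in the block (results)
  unsigned resultNumber; // index among the defining op's results (results)
};

// Strict weak ordering that mirrors the documented same-block rules.
struct ToyValueComparator {
  bool operator()(const ToyValue &lhs, const ToyValue &rhs) const {
    // Block arguments are always less than operation results.
    if (lhs.isBlockArg != rhs.isBlockArg)
      return lhs.isBlockArg;
    // Two block arguments: order by argument number.
    if (lhs.isBlockArg)
      return lhs.argNumber < rhs.argNumber;
    // Two results: order by defining op position, then by result number.
    if (lhs.opPosition != rhs.opPosition)
      return lhs.opPosition < rhs.opPosition;
    return lhs.resultNumber < rhs.resultNumber;
  }
};

// Returns a sorted copy so callers get a deterministic iteration order.
inline std::vector<ToyValue> sortedCopy(std::vector<ToyValue> values) {
  std::sort(values.begin(), values.end(), ToyValueComparator{});
  return values;
}
```

A comparator of this shape can also serve as the ordering of a `std::set` or `std::map` keyed by values, which yields the same iteration order on every run, unlike an order derived from pointer addresses.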

mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   LowerDeallocations.cpp
   OneShotAnalysis.cpp
   OneShotModuleBufferize.cpp
+  OwnershipBasedBufferDeallocation.cpp
   TensorCopyInsertion.cpp

   ADDITIONAL_HEADER_DIRS
@@ -34,6 +35,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   MLIRPass
   MLIRTensorDialect
   MLIRSCFDialect
+  MLIRControlFlowDialect
   MLIRSideEffectInterfaces
   MLIRTransforms
   MLIRViewLikeInterface
