[mlir] add tensor_static.extract/insert to take only static indices. #110550
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor

Author: Yi Zhang (cathyzhyi)

Changes

Patch is 23.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110550.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..8fcc413edf2725 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -344,6 +344,43 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// ExtractStaticOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ExtractStaticOp : Tensor_Op<"extract_static", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ Pure,
+ TypesMatchWith<"result type matches element type of tensor",
+ "tensor", "result",
+ "::llvm::cast<TensorType>($_self).getElementType()">]> {
+ let summary = "element extraction operation with static indices";
+ let description = [{
+ The same as `tensor.extract` op except that `tensor.extract_static` op only
+ takes static indices.
+
+ Example:
+
+ ```mlir
+ %4 = tensor.extract_static %t[1, 2] : tensor<4x4xi32>
+ %5 = tensor.extract_static %rt[1, 2] : tensor<?x?xi32>
+ ```
+ }];
+
+ let arguments = (ins
+ AnyRankedTensor:$tensor,
+ DenseI64ArrayAttr:$static_indices
+ );
+
+ let results = (outs AnyType:$result);
+ let assemblyFormat = [{$tensor `` $static_indices attr-dict `:` type($tensor)}];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+}
+
+
//===----------------------------------------------------------------------===//
// ExtractSliceOp
@@ -822,6 +859,50 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// InsertStaticOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_InsertStaticOp : Tensor_Op<"insert_static", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DestinationStyleOpInterface,
+ Pure,
+ TypesMatchWith<"result type matches type of dest",
+ "dest", "result",
+ "$_self">,
+ TypesMatchWith<"scalar type matches element type of dest",
+ "dest", "scalar",
+ "::llvm::cast<TensorType>($_self).getElementType()">]> {
+ let summary = "element insertion operation with static indices";
+ let description = [{
+ The same as `tensor.insert` op except that `tensor.insert_static` op only
+ takes static indices.
+
+ Example:
+
+ ```mlir
+ %4 = tensor.insert_static %t into %dest[1, 2] : tensor<4x4xi32>
+ %5 = tensor.insert_static %rt into %dest[1, 2] : tensor<?x?xi32>
+ ```
+ }];
+
+ let arguments = (ins AnyType:$scalar,
+ AnyRankedTensor:$dest,
+ DenseI64ArrayAttr:$static_indices);
+ let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ $scalar `into` $dest `` $static_indices attr-dict `:` type($dest)
+ }];
+
+ let extraClassDeclaration = [{
+ MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+ }];
+
+ let hasFolder = 1;
+ let hasVerifier = 1;
+}
+
+
//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1ac96756e22b5e..26d4434a484d61 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -39,6 +39,59 @@ using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
using llvm::mod;
+namespace {
+template <typename ExtractOpTy>
+OpFoldResult foldExtractFromElementsHelper(ExtractOpTy op,
+ FromElementsOp fromElementsOp,
+ ArrayRef<uint64_t> indices) {
+ // Fold extract(from_elements(...)).
+ auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
+ auto rank = tensorType.getRank();
+ assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
+ "rank mismatch");
+ int flatIndex = 0;
+ int stride = 1;
+ for (int i = rank - 1; i >= 0; --i) {
+ flatIndex += indices[i] * stride;
+ stride *= tensorType.getDimSize(i);
+ }
+ // Prevent out of bounds accesses. This can happen in invalid code that
+ // will never execute.
+ if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
+ flatIndex < 0)
+ return {};
+ return fromElementsOp.getElements()[flatIndex];
+}
+
+LogicalResult verifyStaticIndicesInBound(RankedTensorType type,
+ ArrayRef<int64_t> indices) {
+ ArrayRef<int64_t> shape = type.getShape();
+ for (auto [dim, index] : llvm::zip(shape, indices)) {
+ if (index < 0)
+ return failure();
+ if (ShapedType::isDynamic(dim))
+ continue;
+ if (index >= dim)
+ return failure();
+ }
+ return success();
+}
+
+template <typename InsertOpTy, typename AdapterTy>
+OpFoldResult insertOpFoldHelper(InsertOpTy insert, AdapterTy adaptor) {
+ Attribute scalar = adaptor.getScalar();
+ Attribute dest = adaptor.getDest();
+ if (scalar && dest) {
+ if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest)) {
+ if (scalar == splatDest.getSplatValue<Attribute>())
+ return dest;
+ }
+ }
+ return {};
+}
+
+} // namespace
+
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
@@ -1097,18 +1150,28 @@ namespace {
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
-struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
- using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+template <typename ExtractOpTy>
+struct ExtractFromTensorCast : public OpRewritePattern<ExtractOpTy> {
+ using OpRewritePattern<ExtractOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+ LogicalResult matchAndRewrite(ExtractOpTy extract,
PatternRewriter &rewriter) const final {
- auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
+ auto tensorCast =
+ extract.getTensor().template getDefiningOp<tensor::CastOp>();
if (!tensorCast)
return failure();
if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
return failure();
- rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
- extract, tensorCast.getSource(), extract.getIndices());
+ Operation *op = extract;
+ if (auto extractOp = llvm::dyn_cast<tensor::ExtractOp>(op)) {
+ rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+ extractOp, tensorCast.getSource(), extractOp.getIndices());
+ } else if (auto extractStaticOp =
+ llvm::dyn_cast<tensor::ExtractStaticOp>(op)) {
+ rewriter.replaceOpWithNewOp<tensor::ExtractStaticOp>(
+ extractStaticOp, tensorCast.getSource(),
+ extractStaticOp.getStaticIndices());
+ }
return success();
}
};
@@ -1145,22 +1208,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold extract(from_elements(...)).
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
- auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
- auto rank = tensorType.getRank();
- assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
- "rank mismatch");
- int flatIndex = 0;
- int stride = 1;
- for (int i = rank - 1; i >= 0; --i) {
- flatIndex += indices[i] * stride;
- stride *= tensorType.getDimSize(i);
- }
- // Prevent out of bounds accesses. This can happen in invalid code that
- // will never execute.
- if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
- flatIndex < 0)
- return {};
- return fromElementsOp.getElements()[flatIndex];
+ return foldExtractFromElementsHelper<ExtractOp>(*this, fromElementsOp,
+ indices);
}
// If this is an elements attribute, query the value at the given indices.
@@ -1175,7 +1224,56 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractFromTensorCast>(context);
+ results.add<ExtractFromTensorCast<tensor::ExtractOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractStaticOp
+//===----------------------------------------------------------------------===//
+
+void ExtractStaticOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "extracted");
+}
+
+LogicalResult ExtractStaticOp::verify() {
+ // Verify the # indices match if we have a ranked type.
+ auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
+ if (tensorType.getRank() != static_cast<int64_t>(getStaticIndices().size()))
+ return emitOpError("incorrect number of indices for extract_static");
+ if (failed(verifyStaticIndicesInBound(tensorType, getStaticIndices())))
+ return emitOpError("static index out of bound for extract_static");
+ return success();
+}
+
+OpFoldResult ExtractStaticOp::fold(FoldAdaptor adaptor) {
+ // If this is a splat elements attribute, simply return the value. All of
+ // the elements of a splat attribute are the same.
+ if (Attribute tensor = adaptor.getTensor()) {
+ if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
+ return splatTensor.getSplatValue<Attribute>();
+ }
+
+ SmallVector<uint64_t, 8> indices(getStaticIndices());
+ // Fold extract(from_elements(...)).
+ if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
+ return foldExtractFromElementsHelper<ExtractStaticOp>(*this, fromElementsOp,
+ indices);
+ }
+
+ // If this is an elements attribute, query the value at the given indices.
+ if (Attribute tensor = adaptor.getTensor()) {
+ auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
+ if (elementsAttr && elementsAttr.isValidIndex(indices))
+ return elementsAttr.getValues<Attribute>()[indices];
+ }
+
+ return {};
+}
+
+void ExtractStaticOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ExtractFromTensorCast<tensor::ExtractStaticOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1368,13 +1466,34 @@ LogicalResult InsertOp::verify() {
}
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
- Attribute scalar = adaptor.getScalar();
- Attribute dest = adaptor.getDest();
- if (scalar && dest)
- if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
- if (scalar == splatDest.getSplatValue<Attribute>())
- return dest;
- return {};
+ return insertOpFoldHelper<InsertOp,
+ InsertOpGenericAdaptor<ArrayRef<Attribute>>>(
+ *this, adaptor);
+}
+
+//===----------------------------------------------------------------------===//
+// InsertStaticOp
+//===----------------------------------------------------------------------===//
+
+void InsertStaticOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "inserted");
+}
+
+LogicalResult InsertStaticOp::verify() {
+ // Verify the # indices match if we have a ranked type.
+ auto destType = llvm::cast<RankedTensorType>(getDest().getType());
+ if (destType.getRank() != static_cast<int64_t>(getStaticIndices().size()))
+ return emitOpError("incorrect number of indices for insert_static");
+ if (failed(verifyStaticIndicesInBound(destType, getStaticIndices())))
+ return emitOpError("static index out of bound for insert_static");
+ return success();
+}
+
+OpFoldResult InsertStaticOp::fold(FoldAdaptor adaptor) {
+ return insertOpFoldHelper<InsertStaticOp,
+ InsertStaticOpGenericAdaptor<ArrayRef<Attribute>>>(
+ *this, adaptor);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 86754c1c37536d..25b46e7877cbaa 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -173,6 +173,40 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
// -----
+// CHECK-LABEL: func @fold_extract_static
+func.func @fold_extract_static() -> (f32, f16, f16, i32, complex<f32>) {
+ // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
+ // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
+ // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16
+
+ // Fold an extract into a splat.
+ // CHECK-DAG: [[C4:%.+]] = arith.constant 4.{{0*}}e+00 : f32
+ %0 = arith.constant dense<4.0> : tensor<4xf32>
+ %ext_1 = tensor.extract_static %0[1] : tensor<4xf32>
+
+ // Fold an extract into a sparse with a sparse index.
+ %1 = arith.constant sparse<[[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : tensor<4x4x4xf16>
+ %ext_2 = tensor.extract_static %1[1, 1, 1] : tensor<4x4x4xf16>
+
+ // Fold an extract into a sparse with a non sparse index.
+ %2 = arith.constant sparse<[[1, 1, 1]], [-2.0]> : tensor<2x2x2xf16>
+ %ext_3 = tensor.extract_static %2[0, 0, 0] : tensor<2x2x2xf16>
+
+ // Fold an extract into a dense tensor.
+ %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
+ %ext_4 = tensor.extract_static %3[1, 0, 3] : tensor<2x1x4xi32>
+
+ // Fold an extract into a complex constant.
+ // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
+ %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
+ %ext_5 = tensor.extract_static %4[] : tensor<complex<f32>>
+
+ // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
+ return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_insert
func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
// Fold an insert into a splat.
@@ -186,6 +220,19 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
// -----
+// CHECK-LABEL: func @fold_insert_static
+func.func @fold_insert_static() -> (tensor<4xf32>) {
+ // Fold an insert into a splat.
+ // CHECK-DAG: %[[C4:.+]] = arith.constant dense<4.{{0*}}e+00> : tensor<4xf32>
+ %0 = arith.constant dense<4.0> : tensor<4xf32>
+ %1 = arith.constant 4.0 : f32
+ %ins_1 = tensor.insert_static %1 into %0[3] : tensor<4xf32>
+ // CHECK-NEXT: return %[[C4]]
+ return %ins_1 : tensor<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
@@ -200,6 +247,18 @@ func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
// -----
+// CHECK-LABEL: func @extract_static_from_tensor.cast
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
+func.func @extract_static_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
+ // CHECK-NOT: tensor.cast
+ %casted = tensor.cast %tensor : tensor<9xf32> to tensor<?xf32>
+ // CHECK-NEXT: tensor.extract_static %[[TENSOR]][0]
+ %result = tensor.extract_static %casted[0] : tensor<?xf32>
+ return %result : f32
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.from_elements
func.func @extract_from_tensor.from_elements(%element : index) -> index {
// CHECK-SAME: ([[ARG:%.*]]: index)
@@ -212,6 +271,17 @@ func.func @extract_from_tensor.from_elements(%element : index) -> index {
// -----
+// CHECK-LABEL: func @extract_static_from_tensor.from_elements
+func.func @extract_static_from_tensor.from_elements(%element : index) -> index {
+ // CHECK-SAME: ([[ARG:%.*]]: index)
+ %tensor = tensor.from_elements %element : tensor<1xindex>
+ %extracted_element = tensor.extract_static %tensor[0] : tensor<1xindex>
+ // CHECK: [[ARG]] : index
+ return %extracted_element : index
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.from_elements_0d
func.func @extract_from_tensor.from_elements_0d(%element : index) -> index {
// CHECK-SAME: ([[ARG:%.*]]: index)
@@ -224,6 +294,17 @@ func.func @extract_from_tensor.from_elements_0d(%element : index) -> index {
// -----
+// CHECK-LABEL: func @extract_static_from_tensor.from_elements_0d
+func.func @extract_static_from_tensor.from_elements_0d(%element : index) -> index {
+ // CHECK-SAME: ([[ARG:%.*]]: index)
+ %tensor = tensor.from_elements %element : tensor<index>
+ %extracted_element = tensor.extract_static %tensor[] : tensor<index>
+ // CHECK: [[ARG]] : index
+ return %extracted_element : index
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.from_elements_3d
func.func @extract_from_tensor.from_elements_3d()
-> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
@@ -261,6 +342,61 @@ func.func @extract_from_tensor.from_elements_3d()
return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
: f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
}
+
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
+// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
+// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
+// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
+// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
+// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
+// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
+// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
+// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
+// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
+// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
+// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
+
+// CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]],
+// CHECK-SAME: %[[F6]], %[[F7]], %[[F8]], %[[F9]], %[[F10]], %[[F11]]
+
+
+// -----
+
+// CHECK-LABEL: func @extract_static_from_tensor.from_elements_3d
+func.func @extract_static_from_tensor.from_elements_3d()
+ -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+ %f0 = arith.constant 0.0 : f32
+ %f1 = arith.constant 1.0 : f32
+ %f2 = arith.constant 2.0 : f32
+ %f3 = arith.constant 3.0 : f32
+ %f4 = arith.constant 4.0 : f32
+ %f5 = arith.constant 5.0 : f32
+ %f6 = arith.constant 6.0 : f32
+ %f7 = arith.constant 7.0 : f32
+ %f8 = arith.constant 8.0 : f32
+ %f9 = arith.constant 9.0 : f32
+ %f10 = arith.constant 10.0 : f32
+ %f11 = arith.constant 11.0 : f32
+
+ %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
+ : tensor<3x2x2xf32>
+
+ %r0 = tensor.extract_static %tensor[0, 0, 0] : tensor<3x2x2xf32>
+ %r1 = tensor.extract_static %tensor[0, 0, 1] : tensor<3x2x2xf32>
+ %r2 = tensor.extract_static %tensor[0, 1, 0] : tensor<3x2x2xf32>
+ %r3 = tensor.extract_static %tensor[0, 1, 1] : tensor<3x2x2xf32>
+ %r4 = tensor.extract_static %tensor[1, 0, 0] : tensor<3x2x2xf32>
+ %r5 = tensor.extract_static %tensor[1, 0, 1] : tensor<3x2x2xf32>
+ %r6 = tensor.extract_static %tensor[1, 1, 0] : tensor<3x2x2xf32>
+ %r7 = tensor.extract_static %tensor[1, 1, 1] : tensor<3x2x2xf32>
+ %r8 = tensor.extract_static %tensor[2, 0, 0] : tensor<3x2x2xf32>
+ %r9 = tensor.extract_static %tensor[2, 0, 1] : tensor<3x2x2xf32>
+ %r10 = tensor.extract_static %tensor[2, 1, 0] : tensor<3x2x2xf32>
+ %r11 = tensor.extract_static %tensor[2, 1, 1] : tensor<3x2x2xf32>
+ return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
+ : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
+}
+
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 84e6c59e403dde..4be9b6a9c87183 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -72,6 +72,22 @@ f...
[truncated]
Thanks for the patch! Can you provide an extensive description to motivate why these are good additions here? It'd be nice to have some information on the tradeoffs and such.
This seems like a very big change to me, because it opens up multiple ways of expressing operations on tensor elements and therefore changes the way patterns matching on
@joker-eph @ubfx Thanks for the comments. Here is some more of the motivation behind this. We are experiencing inputs where there are O(40k) scalar constant operations whose only reason for existing is indexing. It's inefficient, more difficult to test, and cumbersome to write (using a little lazy cache per isolated-from-above context to avoid ending up at O(400k) scalar constant operations that get CSE'd later). It's not special to indexing ops that one could capture constant values, but it is very wasteful for them (I forget the exact figure; it's roughly 10x higher storage than just the parameter). We've also discussed and explored other options, like using mixed values for indices, in this PR and RFC, but decided this is a smaller change to the uniformity of the API.
After reviewing the previous discussions, I agree with @joker-eph's point that when choosing between SSA values and attributes as operands for an op, the decision should be based on whether it is necessary for lowering. Given this conclusion, splitting ops into dynamic and static variants won't solve our current problem: when selecting an op for building, we should choose based on whether the lowering pattern supports dynamic behavior, not on whether the parameters are static. Furthermore, if we were to adopt this approach, should we also refactor other existing ops like memref.subview?
Thank you for the in-depth discussion on this issue. Overall, the implementation of this patch looks good to me. However, before merging it, I'd like us to have a clearer discussion about how we should approach this kind of problem in general. Whether it's a set of definite rules or some temporary guidelines, it would be helpful for others to know how to proceed in similar situations.
On these lines in TensorOps.cpp:

LogicalResult verifyStaticIndicesInBound(RankedTensorType type,
                                         ArrayRef<int64_t> indices) {
This function should be marked as a static function.
On these lines in TensorOps.cpp:

void InsertStaticOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
should we use "insert_static"?
Thanks for the patch. The implementation looks OK, but I think this needs more discussion about how to handle it with respect to the "non-static" variants. Otherwise, this becomes an island, and all optimizations/lowerings will end up being duplicated.
We have other places that have been optimized in this way to allow mixed SSA and attribute indexing. It feels like that is where this needs to go vs having a static variant of all of the ops.
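For illustration, here is a minimal sketch of the "mixed SSA and attribute indexing" style already used elsewhere in the tensor dialect (the shapes and value names below are made up for the example, not taken from this patch): `tensor.extract_slice` accepts each offset either as a static integer attribute or as an SSA value.

```mlir
// The first offset is a dynamic SSA value (%i); the second is the static
// constant 0. Sizes and strides are fully static here.
%slice = tensor.extract_slice %t[%i, 0] [2, 4] [1, 1]
    : tensor<8x8xf32> to tensor<2x4xf32>
```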
The `tensor.extract_static`/`tensor.insert_static` operations are introduced to handle cases where the indices are constant. Currently, the `tensor.extract` op requires dynamic handling of indices even when they are constant, leading to unnecessary complexity in code writing, testing, and verification. These new ops help with readability, making the code more concise and easier to verify.
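As a rough sketch of the intended difference (based on the examples in the patch, not text from the PR description itself):

```mlir
// Today: constant indices must still be materialized as SSA values.
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%a = tensor.extract %t[%c1, %c2] : tensor<4x4xi32>

// With this patch: the indices are carried as a static attribute on the op.
%b = tensor.extract_static %t[1, 2] : tensor<4x4xi32>
```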