[mlir] add tensor_static.extract/insert to take only static indices. #110550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
81 changes: 81 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -344,6 +344,43 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// ExtractStaticOp
//===----------------------------------------------------------------------===//

def Tensor_ExtractStaticOp : Tensor_Op<"extract_static", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure,
TypesMatchWith<"result type matches element type of tensor",
"tensor", "result",
"::llvm::cast<TensorType>($_self).getElementType()">]> {
let summary = "element extraction operation with static indices";
let description = [{
The same as the `tensor.extract` op, except that `tensor.extract_static` takes
only static indices.

Example:

```mlir
%4 = tensor.extract_static %t[1, 2] : tensor<4x4xi32>
%5 = tensor.extract_static %rt[1, 2] : tensor<?x?xi32>
```
}];

let arguments = (ins
AnyRankedTensor:$tensor,
DenseI64ArrayAttr:$static_indices
);

let results = (outs AnyType:$result);
let assemblyFormat = [{$tensor `` $static_indices attr-dict `:` type($tensor)}];

let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}



//===----------------------------------------------------------------------===//
// ExtractSliceOp
@@ -822,6 +859,50 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// InsertStaticOp
//===----------------------------------------------------------------------===//

def Tensor_InsertStaticOp : Tensor_Op<"insert_static", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DestinationStyleOpInterface,
Pure,
TypesMatchWith<"result type matches type of dest",
"dest", "result",
"$_self">,
TypesMatchWith<"scalar type matches element type of dest",
"dest", "scalar",
"::llvm::cast<TensorType>($_self).getElementType()">]> {
let summary = "element insertion operation with static indices";
let description = [{
The same as the `tensor.insert` op, except that `tensor.insert_static` takes
only static indices.

Example:

```mlir
%4 = tensor.insert_static %t into %dest[1, 2] : tensor<4x4xi32>
%5 = tensor.insert_static %rt into %dest[1, 2] : tensor<?x?xi32>
```
}];

let arguments = (ins AnyType:$scalar,
AnyRankedTensor:$dest,
DenseI64ArrayAttr:$static_indices);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$scalar `into` $dest `` $static_indices attr-dict `:` type($dest)
}];

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
}];

let hasFolder = 1;
let hasVerifier = 1;
}


//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//
179 changes: 149 additions & 30 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -39,6 +39,59 @@ using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
using llvm::mod;

namespace {
template <typename ExtractOpTy>
OpFoldResult foldExtractFromElementsHelper(ExtractOpTy op,
FromElementsOp fromElementsOp,
ArrayRef<uint64_t> indices) {
// Fold extract(from_elements(...)).
auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
auto rank = tensorType.getRank();
assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
"rank mismatch");
int flatIndex = 0;
int stride = 1;
for (int i = rank - 1; i >= 0; --i) {
flatIndex += indices[i] * stride;
stride *= tensorType.getDimSize(i);
}
// Prevent out of bounds accesses. This can happen in invalid code that
// will never execute.
if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
flatIndex < 0)
return {};
return fromElementsOp.getElements()[flatIndex];
}
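
[Editor's note] For illustration, a minimal sketch of the fold this helper performs at the IR level, assuming the `extract_static` syntax proposed in this patch and hypothetical values %a through %d:

```mlir
%t = tensor.from_elements %a, %b, %c, %d : tensor<2x2xi32>
%e = tensor.extract_static %t[1, 0] : tensor<2x2xi32>
// Row-major flat index = 1 * 2 + 0 = 2, so %e folds to %c.
```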

LogicalResult verifyStaticIndicesInBound(RankedTensorType type,
ArrayRef<int64_t> indices) {
[Review comment from a Contributor] This function should be marked as a static function.
ArrayRef<int64_t> shape = type.getShape();
for (auto [dim, index] : llvm::zip(shape, indices)) {
if (index < 0)
return failure();
if (ShapedType::isDynamic(dim))
continue;
if (index >= dim)
return failure();
}
return success();
}
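
[Editor's note] In IR terms, the bounds check behaves roughly as sketched below (a hand-written illustration, not a test from this patch): negative indices are always rejected, indices into static dimensions must be strictly smaller than the dimension size, and dynamic dimensions are skipped.

```mlir
// Rejected by the verifier: index 4 is out of bounds for dimension 0 of size 4.
%bad = tensor.extract_static %t[4, 0] : tensor<4x4xi32>
// Accepted: dynamic dimensions are not bounds-checked at verification time.
%ok = tensor.extract_static %rt[100, 0] : tensor<?x?xi32>
```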

template <typename InsertOpTy, typename AdapterTy>
OpFoldResult insertOpFoldHelper(InsertOpTy insert, AdapterTy adaptor) {
Attribute scalar = adaptor.getScalar();
Attribute dest = adaptor.getDest();
if (scalar && dest) {
if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest)) {
if (scalar == splatDest.getSplatValue<Attribute>())
return dest;
}
}
return {};
}
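
[Editor's note] A sketch of the splat fold this helper implements, under the assumption that both operands have been constant-folded (the values here are hypothetical):

```mlir
%splat = arith.constant dense<5> : tensor<4x4xi32>
%c5 = arith.constant 5 : i32
// Inserting a splat tensor's own value leaves it unchanged,
// so this folds to %splat.
%r = tensor.insert_static %c5 into %splat[1, 2] : tensor<4x4xi32>
```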

} // namespace

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
@@ -1097,18 +1150,28 @@ namespace {
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
template <typename ExtractOpTy>
struct ExtractFromTensorCast : public OpRewritePattern<ExtractOpTy> {
using OpRewritePattern<ExtractOpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::ExtractOp extract,
LogicalResult matchAndRewrite(ExtractOpTy extract,
PatternRewriter &rewriter) const final {
auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
auto tensorCast =
extract.getTensor().template getDefiningOp<tensor::CastOp>();
if (!tensorCast)
return failure();
if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
return failure();
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extract, tensorCast.getSource(), extract.getIndices());
Operation *op = extract;
if (auto extractOp = llvm::dyn_cast<tensor::ExtractOp>(op)) {
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extractOp, tensorCast.getSource(), extractOp.getIndices());
} else if (auto extractStaticOp =
llvm::dyn_cast<tensor::ExtractStaticOp>(op)) {
rewriter.replaceOpWithNewOp<tensor::ExtractStaticOp>(
extractStaticOp, tensorCast.getSource(),
extractStaticOp.getStaticIndices());
}
return success();
}
};
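
[Editor's note] For the static variant, the rewrite amounts to the following (a hand-written sketch using the syntax proposed in this patch):

```mlir
%cast = tensor.cast %src : tensor<4xi32> to tensor<?xi32>
%e = tensor.extract_static %cast[0] : tensor<?xi32>
// canonicalizes to
%e = tensor.extract_static %src[0] : tensor<4xi32>
```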
@@ -1145,22 +1208,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {

// Fold extract(from_elements(...)).
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
auto rank = tensorType.getRank();
assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
"rank mismatch");
int flatIndex = 0;
int stride = 1;
for (int i = rank - 1; i >= 0; --i) {
flatIndex += indices[i] * stride;
stride *= tensorType.getDimSize(i);
}
// Prevent out of bounds accesses. This can happen in invalid code that
// will never execute.
if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
flatIndex < 0)
return {};
return fromElementsOp.getElements()[flatIndex];
return foldExtractFromElementsHelper<ExtractOp>(*this, fromElementsOp,
indices);
}

// If this is an elements attribute, query the value at the given indices.
@@ -1175,7 +1224,56 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractFromTensorCast>(context);
results.add<ExtractFromTensorCast<tensor::ExtractOp>>(context);
}

//===----------------------------------------------------------------------===//
// ExtractStaticOp
//===----------------------------------------------------------------------===//

void ExtractStaticOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "extracted");
}

LogicalResult ExtractStaticOp::verify() {
// Verify that the number of indices matches the rank of the tensor.
auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
if (tensorType.getRank() != static_cast<int64_t>(getStaticIndices().size()))
return emitOpError("incorrect number of indices for extract_static");
if (failed(verifyStaticIndicesInBound(tensorType, getStaticIndices())))
return emitOpError("static index out of bound for extract_static");
return success();
}

OpFoldResult ExtractStaticOp::fold(FoldAdaptor adaptor) {
// If this is a splat elements attribute, simply return the value. All of
// the elements of a splat attribute are the same.
if (Attribute tensor = adaptor.getTensor()) {
if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
return splatTensor.getSplatValue<Attribute>();
}

SmallVector<uint64_t, 8> indices(getStaticIndices());
// Fold extract(from_elements(...)).
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
return foldExtractFromElementsHelper<ExtractStaticOp>(*this, fromElementsOp,
indices);
}

// If this is an elements attribute, query the value at the given indices.
if (Attribute tensor = adaptor.getTensor()) {
auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
if (elementsAttr && elementsAttr.isValidIndex(indices))
return elementsAttr.getValues<Attribute>()[indices];
}

return {};
}
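
[Editor's note] As an illustration of the elements-attribute case (hypothetical constants, not from this patch's tests):

```mlir
%cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%e = tensor.extract_static %cst[1, 0] : tensor<2x2xi32>
// Folds to the constant at row 1, column 0:
%c3 = arith.constant 3 : i32
```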

void ExtractStaticOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractFromTensorCast<tensor::ExtractStaticOp>>(context);
}

//===----------------------------------------------------------------------===//
@@ -1368,13 +1466,34 @@ LogicalResult InsertOp::verify() {
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
Attribute scalar = adaptor.getScalar();
Attribute dest = adaptor.getDest();
if (scalar && dest)
if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
if (scalar == splatDest.getSplatValue<Attribute>())
return dest;
return {};
return insertOpFoldHelper<InsertOp,
InsertOpGenericAdaptor<ArrayRef<Attribute>>>(
*this, adaptor);
}

//===----------------------------------------------------------------------===//
// InsertStaticOp
//===----------------------------------------------------------------------===//

void InsertStaticOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "inserted");
[Review comment from a Contributor] should we use "insert_static"?
}

LogicalResult InsertStaticOp::verify() {
// Verify that the number of indices matches the rank of the destination tensor.
auto destType = llvm::cast<RankedTensorType>(getDest().getType());
if (destType.getRank() != static_cast<int64_t>(getStaticIndices().size()))
return emitOpError("incorrect number of indices for insert_static");
if (failed(verifyStaticIndicesInBound(destType, getStaticIndices())))
return emitOpError("static index out of bound for insert_static");
return success();
}

OpFoldResult InsertStaticOp::fold(FoldAdaptor adaptor) {
return insertOpFoldHelper<InsertStaticOp,
InsertStaticOpGenericAdaptor<ArrayRef<Attribute>>>(
*this, adaptor);
}

//===----------------------------------------------------------------------===//