Skip to content

Commit 88a01fd

Browse files
committed
MLIR-QUERY: backwardSlice, forwardSlice & QueryOptions added
1 parent b5df0e7 commit 88a01fd

File tree

17 files changed

+565
-58
lines changed

17 files changed

+565
-58
lines changed

mlir/include/mlir/IR/Matchers.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,15 @@ struct NameOpMatcher {
5959
NameOpMatcher(StringRef name) : name(name) {}
6060
bool match(Operation *op) { return op->getName().getStringRef() == name; }
6161

62-
StringRef name;
62+
std::string name;
6363
};
6464

6565
/// The matcher that matches operations that have the specified attribute name.
6666
struct AttrOpMatcher {
6767
AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
6868
bool match(Operation *op) { return op->hasAttr(attrName); }
6969

70-
StringRef attrName;
70+
std::string attrName;
7171
};
7272

7373
/// The matcher that matches operations that have the `ConstantLike` trait, and
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
//===- ExtraMatchers.h - Various common matchers ---------------------*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file provides extra matchers that are very useful for mlir-query
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_IR_EXTRAMATCHERS_H
15+
#define MLIR_IR_EXTRAMATCHERS_H
16+
17+
#include "MatchFinder.h"
18+
#include "MatchersInternal.h"
19+
#include "mlir/IR/Region.h"
20+
#include "mlir/Query/Query.h"
21+
#include "llvm/Support/raw_ostream.h"
22+
23+
namespace mlir {
24+
25+
namespace query {
26+
27+
namespace extramatcher {
28+
29+
namespace detail {
30+
31+
class BackwardSliceMatcher {
32+
public:
33+
BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
34+
: innerMatcher(std::move(innerMatcher)), hops(hops) {}
35+
36+
private:
37+
bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
38+
QueryOptions &options, unsigned tempHops) {
39+
40+
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
41+
return false;
42+
}
43+
44+
auto processValue = [&](Value value) {
45+
if (tempHops == 0) {
46+
return;
47+
}
48+
if (auto *definingOp = value.getDefiningOp()) {
49+
if (backwardSlice.count(definingOp) == 0)
50+
matches(definingOp, backwardSlice, options, tempHops - 1);
51+
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
52+
if (options.omitBlockArguments)
53+
return;
54+
Block *block = blockArg.getOwner();
55+
56+
Operation *parentOp = block->getParentOp();
57+
58+
if (parentOp && backwardSlice.count(parentOp) == 0) {
59+
assert(parentOp->getNumRegions() == 1 &&
60+
parentOp->getRegion(0).getBlocks().size() == 1);
61+
matches(parentOp, backwardSlice, options, tempHops-1);
62+
}
63+
} else {
64+
llvm_unreachable("No definingOp and not a block argument.");
65+
}
66+
};
67+
68+
if (!options.omitUsesFromAbove) {
69+
llvm::for_each(op->getRegions(), [&](Region &region) {
70+
SmallPtrSet<Region *, 4> descendents;
71+
region.walk(
72+
[&](Region *childRegion) { descendents.insert(childRegion); });
73+
region.walk([&](Operation *op) {
74+
for (OpOperand &operand : op->getOpOperands()) {
75+
if (!descendents.contains(operand.get().getParentRegion()))
76+
processValue(operand.get());
77+
}
78+
});
79+
});
80+
}
81+
82+
llvm::for_each(op->getOperands(), processValue);
83+
backwardSlice.insert(op);
84+
return true;
85+
}
86+
87+
public:
88+
bool match(Operation *op, SetVector<Operation *> &backwardSlice,
89+
QueryOptions &options) {
90+
if (innerMatcher.match(op) && matches(op, backwardSlice, options, hops)) {
91+
if (!options.inclusive) {
92+
backwardSlice.remove(op);
93+
}
94+
return true;
95+
}
96+
return false;
97+
}
98+
99+
private:
100+
matcher::DynMatcher innerMatcher;
101+
unsigned hops;
102+
};
103+
104+
class ForwardSliceMatcher {
105+
public:
106+
ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
107+
: innerMatcher(std::move(innerMatcher)), hops(hops) {}
108+
109+
private:
110+
bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
111+
QueryOptions &options, unsigned tempHops) {
112+
113+
if (tempHops == 0) {
114+
forwardSlice.insert(op);
115+
return true;
116+
}
117+
118+
for (Region &region : op->getRegions())
119+
for (Block &block : region)
120+
for (Operation &blockOp : block)
121+
if (forwardSlice.count(&blockOp) == 0)
122+
matches(&blockOp, forwardSlice, options, tempHops - 1);
123+
for (Value result : op->getResults()) {
124+
for (Operation *userOp : result.getUsers())
125+
if (forwardSlice.count(userOp) == 0)
126+
matches(userOp, forwardSlice, options, tempHops - 1);
127+
}
128+
129+
forwardSlice.insert(op);
130+
return true;
131+
}
132+
133+
public:
134+
bool match(Operation *op, SetVector<Operation *> &forwardSlice,
135+
QueryOptions &options) {
136+
if (innerMatcher.match(op) && matches(op, forwardSlice, options, hops)) {
137+
if (!options.inclusive) {
138+
forwardSlice.remove(op);
139+
}
140+
SmallVector<Operation *, 0> v(forwardSlice.takeVector());
141+
forwardSlice.insert(v.rbegin(), v.rend());
142+
return true;
143+
}
144+
return false;
145+
}
146+
147+
private:
148+
matcher::DynMatcher innerMatcher;
149+
unsigned hops;
150+
};
151+
152+
} // namespace detail
153+
154+
inline detail::BackwardSliceMatcher
155+
definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
156+
return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
157+
}
158+
159+
inline detail::BackwardSliceMatcher
160+
getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
161+
return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
162+
}
163+
164+
inline detail::ForwardSliceMatcher
165+
usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
166+
return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
167+
}
168+
169+
inline detail::ForwardSliceMatcher
170+
getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
171+
return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
172+
}
173+
174+
} // namespace extramatcher
175+
176+
} // namespace query
177+
178+
} // namespace mlir
179+
180+
#endif // MLIR_IR_EXTRAMATCHERS_H

mlir/include/mlir/Query/Matcher/Marshallers.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
5050
}
5151
};
5252

53+
template <>
54+
struct ArgTypeTraits<unsigned> {
55+
static bool hasCorrectType(const VariantValue &value) {
56+
return value.isUnsigned();
57+
}
58+
59+
static unsigned get(const VariantValue &value) { return value.getUnsigned(); }
60+
61+
static ArgKind getKind() { return ArgKind::Unsigned; }
62+
63+
static std::optional<std::string> getBestGuess(const VariantValue &) {
64+
return std::nullopt;
65+
}
66+
};
67+
5368
template <>
5469
struct ArgTypeTraits<DynMatcher> {
5570

mlir/include/mlir/Query/Matcher/MatchFinder.h

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,60 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This file contains the MatchFinder class, which is used to find operations
10-
// that match a given matcher.
10+
// that match a given matcher and print them.
1111
//
1212
//===----------------------------------------------------------------------===//
1313

1414
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
1515
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
1616

1717
#include "MatchersInternal.h"
18+
#include "mlir/Query/QuerySession.h"
19+
#include "llvm/ADT/SetVector.h"
20+
#include "llvm/Support/SourceMgr.h"
21+
#include "llvm/Support/raw_ostream.h"
1822

1923
namespace mlir::query::matcher {
2024

21-
// MatchFinder is used to find all operations that match a given matcher.
2225
class MatchFinder {
26+
private:
27+
static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
28+
mlir::Operation *op, const std::string &binding) {
29+
auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
30+
auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
31+
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
32+
qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
33+
"\"" + binding + "\" binds here");
34+
};
35+
2336
public:
24-
// Returns all operations that match the given matcher.
25-
static std::vector<Operation *> getMatches(Operation *root,
26-
DynMatcher matcher) {
27-
std::vector<Operation *> matches;
37+
static std::vector<Operation *>
38+
getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
39+
llvm::raw_ostream &os, QuerySession &qs) {
40+
unsigned matchCount = 0;
41+
std::vector<Operation *> matchedOps;
42+
SetVector<Operation *> tempStorage;
2843

29-
// Simple match finding with walk.
3044
root->walk([&](Operation *subOp) {
31-
if (matcher.match(subOp))
32-
matches.push_back(subOp);
33-
});
45+
if (matcher.match(subOp)) {
46+
matchedOps.push_back(subOp);
47+
os << "Match #" << ++matchCount << ":\n\n";
48+
printMatch(os, qs, subOp, "root");
49+
} else {
50+
SmallVector<Operation *> printingOps;
3451

35-
return matches;
52+
if (matcher.match(subOp, tempStorage, options)) {
53+
os << "Match #" << ++matchCount << ":\n\n";
54+
SmallVector<Operation *> printingOps(tempStorage.takeVector());
55+
for (auto op : printingOps) {
56+
printMatch(os, qs, op, "root");
57+
matchedOps.push_back(op);
58+
}
59+
printingOps.clear();
60+
}
61+
}
62+
});
63+
return matchedOps;
3664
}
3765
};
3866

0 commit comments

Comments
 (0)