Skip to content

Commit 32d16aa

Browse files
committed
[IR2Vec] Scale vocab
1 parent 9124e83 commit 32d16aa

File tree

17 files changed

+374
-152
lines changed

17 files changed

+374
-152
lines changed

llvm/docs/MLGO.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,10 @@ downstream tasks, including ML-guided compiler optimizations.
448448

449449
The core components are:
450450
- **Vocabulary**: A mapping from IR entities (opcodes, types, etc.) to their
451-
vector representations. This is managed by ``IR2VecVocabAnalysis``.
451+
vector representations. This is managed by ``IR2VecVocabAnalysis``. The
452+
vocabulary (.json file) contains three sections -- Opcodes, Types, and
453+
Arguments, each containing the representations of the corresponding
454+
entities.
452455
- **Embedder**: A class (``ir2vec::Embedder``) that uses the vocabulary to
453456
compute embeddings for instructions, basic blocks, and functions.
454457

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ struct Embedding {
108108
/// Arithmetic operators
109109
Embedding &operator+=(const Embedding &RHS);
110110
Embedding &operator-=(const Embedding &RHS);
111+
Embedding &operator*=(double Factor);
111112

112113
/// Adds Src Embedding scaled by Factor with the called Embedding.
113114
/// Called_Embedding += Src * Factor
@@ -116,6 +117,8 @@ struct Embedding {
116117
/// Returns true if the embedding is approximately equal to the RHS embedding
117118
/// within the specified tolerance.
118119
bool approximatelyEquals(const Embedding &RHS, double Tolerance = 1e-6) const;
120+
121+
void print(raw_ostream &OS) const;
119122
};
120123

121124
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
@@ -234,6 +237,8 @@ class IR2VecVocabResult {
234237
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
235238
ir2vec::Vocab Vocabulary;
236239
Error readVocabulary();
240+
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
241+
ir2vec::Vocab &TargetVocab, unsigned &Dim);
237242
void emitError(Error Err, LLVMContext &Ctx);
238243

239244
public:
@@ -249,14 +254,23 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
249254
/// functions.
250255
class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> {
251256
raw_ostream &OS;
252-
void printVector(const ir2vec::Embedding &Vec) const;
253257

254258
public:
255259
explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
256260
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
257261
static bool isRequired() { return true; }
258262
};
259263

264+
/// This pass prints the embeddings in the vocabulary
265+
class IR2VecVocabPrinterPass : public PassInfoMixin<IR2VecVocabPrinterPass> {
266+
raw_ostream &OS;
267+
268+
public:
269+
explicit IR2VecVocabPrinterPass(raw_ostream &OS) : OS(OS) {}
270+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
271+
static bool isRequired() { return true; }
272+
};
273+
260274
} // namespace llvm
261275

262276
#endif // LLVM_ANALYSIS_IR2VEC_H

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 97 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
8585
return *this;
8686
}
8787

88+
Embedding &Embedding::operator*=(double Factor) {
89+
std::transform(this->begin(), this->end(), this->begin(),
90+
[Factor](double Elem) { return Elem * Factor; });
91+
return *this;
92+
}
93+
8894
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
8995
assert(this->size() == Src.size() && "Vectors must have the same dimension");
9096
for (size_t Itr = 0; Itr < this->size(); ++Itr)
@@ -101,6 +107,13 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
101107
return true;
102108
}
103109

110+
void Embedding::print(raw_ostream &OS) const {
111+
OS << " [";
112+
for (const auto &Elem : Data)
113+
OS << " " << format("%.2f", Elem) << " ";
114+
OS << "]\n";
115+
}
116+
104117
// ==----------------------------------------------------------------------===//
105118
// Embedder and its subclasses
106119
//===----------------------------------------------------------------------===//
@@ -196,18 +209,12 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
196209
for (const auto &I : BB.instructionsWithoutDebug()) {
197210
Embedding InstVector(Dimension, 0);
198211

199-
const auto OpcVec = lookupVocab(I.getOpcodeName());
200-
InstVector.scaleAndAdd(OpcVec, OpcWeight);
201-
202212
// FIXME: Currently lookups are string based. Use numeric Keys
203213
// for efficiency.
204-
const auto Type = I.getType();
205-
const auto TypeVec = getTypeEmbedding(Type);
206-
InstVector.scaleAndAdd(TypeVec, TypeWeight);
207-
214+
InstVector += lookupVocab(I.getOpcodeName());
215+
InstVector += getTypeEmbedding(I.getType());
208216
for (const auto &Op : I.operands()) {
209-
const auto OperandVec = getOperandEmbedding(Op.get());
210-
InstVector.scaleAndAdd(OperandVec, ArgWeight);
217+
InstVector += getOperandEmbedding(Op.get());
211218
}
212219
InstVecMap[&I] = InstVector;
213220
BBVector += InstVector;
@@ -251,6 +258,43 @@ bool IR2VecVocabResult::invalidate(
251258
return !(PAC.preservedWhenStateless());
252259
}
253260

261+
Error IR2VecVocabAnalysis::parseVocabSection(
262+
StringRef Key, const json::Value &ParsedVocabValue,
263+
ir2vec::Vocab &TargetVocab, unsigned &Dim) {
264+
json::Path::Root Path("");
265+
const json::Object *RootObj = ParsedVocabValue.getAsObject();
266+
if (!RootObj)
267+
return createStringError(errc::invalid_argument,
268+
"JSON root is not an object");
269+
270+
const json::Value *SectionValue = RootObj->get(Key);
271+
if (!SectionValue)
272+
return createStringError(errc::invalid_argument,
273+
"Missing '" + std::string(Key) +
274+
"' section in vocabulary file");
275+
if (!json::fromJSON(*SectionValue, TargetVocab, Path))
276+
return createStringError(errc::illegal_byte_sequence,
277+
"Unable to parse '" + std::string(Key) +
278+
"' section from vocabulary");
279+
280+
Dim = TargetVocab.begin()->second.size();
281+
if (Dim == 0)
282+
return createStringError(errc::illegal_byte_sequence,
283+
"Dimension of '" + std::string(Key) +
284+
"' section of the vocabulary is zero");
285+
286+
if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
287+
[Dim](const std::pair<StringRef, Embedding> &Entry) {
288+
return Entry.second.size() == Dim;
289+
}))
290+
return createStringError(
291+
errc::illegal_byte_sequence,
292+
"All vectors in the '" + std::string(Key) +
293+
"' section of the vocabulary are not of the same dimension");
294+
295+
return Error::success();
296+
};
297+
254298
// FIXME: Make this optional. We can avoid file reads
255299
// by auto-generating a default vocabulary during the build time.
256300
Error IR2VecVocabAnalysis::readVocabulary() {
@@ -259,32 +303,40 @@ Error IR2VecVocabAnalysis::readVocabulary() {
259303
return createFileError(VocabFile, BufOrError.getError());
260304

261305
auto Content = BufOrError.get()->getBuffer();
262-
json::Path::Root Path("");
306+
263307
Expected<json::Value> ParsedVocabValue = json::parse(Content);
264308
if (!ParsedVocabValue)
265309
return ParsedVocabValue.takeError();
266310

267-
bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
268-
if (!Res)
269-
return createStringError(errc::illegal_byte_sequence,
270-
"Unable to parse the vocabulary");
311+
ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab;
312+
unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
313+
if (auto Err = parseVocabSection("Opcodes", *ParsedVocabValue, OpcodeVocab,
314+
OpcodeDim))
315+
return Err;
271316

272-
if (Vocabulary.empty())
273-
return createStringError(errc::illegal_byte_sequence,
274-
"Vocabulary is empty");
317+
if (auto Err =
318+
parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
319+
return Err;
275320

276-
unsigned Dim = Vocabulary.begin()->second.size();
277-
if (Dim == 0)
321+
if (auto Err =
322+
parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
323+
return Err;
324+
325+
if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
278326
return createStringError(errc::illegal_byte_sequence,
279-
"Dimension of vocabulary is zero");
327+
"Vocabulary sections have different dimensions");
280328

281-
if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
282-
[Dim](const std::pair<StringRef, Embedding> &Entry) {
283-
return Entry.second.size() == Dim;
284-
}))
285-
return createStringError(
286-
errc::illegal_byte_sequence,
287-
"All vectors in the vocabulary are not of the same dimension");
329+
auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) {
330+
for (auto &Entry : Vocab)
331+
Entry.second *= Weight;
332+
};
333+
scaleVocabSection(OpcodeVocab, OpcWeight);
334+
scaleVocabSection(TypeVocab, TypeWeight);
335+
scaleVocabSection(ArgVocab, ArgWeight);
336+
337+
Vocabulary.insert(OpcodeVocab.begin(), OpcodeVocab.end());
338+
Vocabulary.insert(TypeVocab.begin(), TypeVocab.end());
339+
Vocabulary.insert(ArgVocab.begin(), ArgVocab.end());
288340

289341
return Error::success();
290342
}
@@ -304,7 +356,6 @@ void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
304356
IR2VecVocabAnalysis::Result
305357
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
306358
auto Ctx = &M.getContext();
307-
// FIXME: Scale the vocabulary once. This would avoid scaling per use later.
308359
// If vocabulary is already populated by the constructor, use it.
309360
if (!Vocabulary.empty())
310361
return IR2VecVocabResult(std::move(Vocabulary));
@@ -323,16 +374,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
323374
}
324375

325376
// ==----------------------------------------------------------------------===//
326-
// IR2VecPrinterPass
377+
// Printer Passes
327378
//===----------------------------------------------------------------------===//
328379

329-
void IR2VecPrinterPass::printVector(const Embedding &Vec) const {
330-
OS << " [";
331-
for (const auto &Elem : Vec)
332-
OS << " " << format("%.2f", Elem) << " ";
333-
OS << "]\n";
334-
}
335-
336380
PreservedAnalyses IR2VecPrinterPass::run(Module &M,
337381
ModuleAnalysisManager &MAM) {
338382
auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
@@ -353,15 +397,15 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
353397

354398
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
355399
OS << "Function vector: ";
356-
printVector(Emb->getFunctionVector());
400+
Emb->getFunctionVector().print(OS);
357401

358402
OS << "Basic block vectors:\n";
359403
const auto &BBMap = Emb->getBBVecMap();
360404
for (const BasicBlock &BB : F) {
361405
auto It = BBMap.find(&BB);
362406
if (It != BBMap.end()) {
363407
OS << "Basic block: " << BB.getName() << ":\n";
364-
printVector(It->second);
408+
It->second.print(OS);
365409
}
366410
}
367411

@@ -373,10 +417,24 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
373417
if (It != InstMap.end()) {
374418
OS << "Instruction: ";
375419
I.print(OS);
376-
printVector(It->second);
420+
It->second.print(OS);
377421
}
378422
}
379423
}
380424
}
381425
return PreservedAnalyses::all();
382426
}
427+
428+
PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
429+
ModuleAnalysisManager &MAM) {
430+
auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
431+
assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid");
432+
433+
auto Vocab = IR2VecVocabResult.getVocabulary();
434+
for (const auto &Entry : Vocab) {
435+
OS << "Key: " << Entry.first << ": ";
436+
Entry.second.print(OS);
437+
}
438+
439+
return PreservedAnalyses::all();
440+
}

0 commit comments

Comments
 (0)