Skip to content

[IR2Vec] Scale embeddings once in vocab analysis instead of repetitive scaling #143986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/svkeerthy/06-10-_mlininer_ir2vec_integrating_ir2vec_with_mlinliner
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion llvm/docs/MLGO.rst
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,10 @@ downstream tasks, including ML-guided compiler optimizations.

The core components are:
- **Vocabulary**: A mapping from IR entities (opcodes, types, etc.) to their
vector representations. This is managed by ``IR2VecVocabAnalysis``. The
vocabulary (a ``.json`` file) contains three sections -- ``Opcodes``,
``Types``, and ``Arguments`` -- each containing the representations of the
corresponding entities. All three sections must be present in the file, in
any order.
- **Embedder**: A class (``ir2vec::Embedder``) that uses the vocabulary to
compute embeddings for instructions, basic blocks, and functions.

Expand Down
16 changes: 15 additions & 1 deletion llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ struct Embedding {
/// Arithmetic operators
Embedding &operator+=(const Embedding &RHS);
Embedding &operator-=(const Embedding &RHS);
Embedding &operator*=(double Factor);

/// Adds Src Embedding scaled by Factor with the called Embedding.
/// Called_Embedding += Src * Factor
Expand All @@ -116,6 +117,8 @@ struct Embedding {
/// Returns true if the embedding is approximately equal to the RHS embedding
/// within the specified tolerance.
bool approximatelyEquals(const Embedding &RHS, double Tolerance = 1e-6) const;

void print(raw_ostream &OS) const;
};

using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
Expand Down Expand Up @@ -234,6 +237,8 @@ class IR2VecVocabResult {
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
ir2vec::Vocab Vocabulary;
Error readVocabulary();
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
ir2vec::Vocab &TargetVocab, unsigned &Dim);
void emitError(Error Err, LLVMContext &Ctx);

public:
Expand All @@ -249,14 +254,23 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
/// functions.
class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> {
raw_ostream &OS;
void printVector(const ir2vec::Embedding &Vec) const;

public:
explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
static bool isRequired() { return true; }
};

/// Module pass that prints the embeddings stored in the IR2Vec vocabulary to
/// the stream supplied at construction.
class IR2VecVocabPrinterPass : public PassInfoMixin<IR2VecVocabPrinterPass> {
  /// Destination stream; held by reference.
  raw_ostream &OS;

public:
  explicit IR2VecVocabPrinterPass(raw_ostream &OS) : OS(OS) {}

  /// Always returns true.
  static bool isRequired() { return true; }

  /// Entry point of the pass for module \p M; defined out of line.
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
};

} // namespace llvm

#endif // LLVM_ANALYSIS_IR2VEC_H
136 changes: 97 additions & 39 deletions llvm/lib/Analysis/IR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
return *this;
}

/// Multiplies every element of this embedding by \p Factor in place and
/// returns a reference to this embedding.
Embedding &Embedding::operator*=(double Factor) {
  for (auto &Elem : *this)
    Elem *= Factor;
  return *this;
}

Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
assert(this->size() == Src.size() && "Vectors must have the same dimension");
for (size_t Itr = 0; Itr < this->size(); ++Itr)
Expand All @@ -101,6 +107,13 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
return true;
}

/// Writes the embedding to \p OS as a bracketed list on a single line. Each
/// element is formatted with "%.2f" and surrounded by one space on either
/// side; the closing bracket is followed by a newline.
void Embedding::print(raw_ostream &OS) const {
  OS << " [";
  for (const auto &Value : Data) {
    OS << " ";
    OS << format("%.2f", Value);
    OS << " ";
  }
  OS << "]\n";
}

//===----------------------------------------------------------------------===//
// Embedder and its subclasses
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -196,18 +209,12 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
for (const auto &I : BB.instructionsWithoutDebug()) {
Embedding InstVector(Dimension, 0);

const auto OpcVec = lookupVocab(I.getOpcodeName());
InstVector.scaleAndAdd(OpcVec, OpcWeight);

// FIXME: Currently lookups are string based. Use numeric Keys
// for efficiency.
const auto Type = I.getType();
const auto TypeVec = getTypeEmbedding(Type);
InstVector.scaleAndAdd(TypeVec, TypeWeight);

InstVector += lookupVocab(I.getOpcodeName());
InstVector += getTypeEmbedding(I.getType());
for (const auto &Op : I.operands()) {
const auto OperandVec = getOperandEmbedding(Op.get());
InstVector.scaleAndAdd(OperandVec, ArgWeight);
InstVector += getOperandEmbedding(Op.get());
}
InstVecMap[&I] = InstVector;
BBVector += InstVector;
Expand Down Expand Up @@ -251,6 +258,43 @@ bool IR2VecVocabResult::invalidate(
return !(PAC.preservedWhenStateless());
}

/// Reads the section named \p Key from the parsed vocabulary JSON into
/// \p TargetVocab and reports the embedding dimension shared by its entries
/// through \p Dim.
///
/// \param Key              Name of the section to look up in the JSON root.
/// \param ParsedVocabValue Parsed vocabulary JSON; its root must be an object.
/// \param TargetVocab      Receives the entries parsed from the section.
/// \param Dim              Set to the dimension of the section's vectors. Left
///                         untouched when an error is returned before the
///                         dimension could be determined.
/// \returns Error::success() when the section was parsed and validated;
/// otherwise an error if the root is not an object, the section is missing,
/// cannot be parsed, is empty, has zero-dimensional vectors, or contains
/// vectors of differing dimensions.
Error IR2VecVocabAnalysis::parseVocabSection(
    StringRef Key, const json::Value &ParsedVocabValue,
    ir2vec::Vocab &TargetVocab, unsigned &Dim) {
  json::Path::Root Path("");
  const json::Object *RootObj = ParsedVocabValue.getAsObject();
  if (!RootObj)
    return createStringError(errc::invalid_argument,
                             "JSON root is not an object");

  const json::Value *SectionValue = RootObj->get(Key);
  if (!SectionValue)
    return createStringError(errc::invalid_argument,
                             "Missing '" + std::string(Key) +
                                 "' section in vocabulary file");
  if (!json::fromJSON(*SectionValue, TargetVocab, Path))
    return createStringError(errc::illegal_byte_sequence,
                             "Unable to parse '" + std::string(Key) +
                                 "' section from vocabulary");

  // An empty section has no first entry, so dereferencing begin() below would
  // be undefined behavior. Reject it explicitly instead.
  if (TargetVocab.empty())
    return createStringError(errc::illegal_byte_sequence,
                             "'" + std::string(Key) +
                                 "' section of the vocabulary is empty");

  // The first entry fixes the expected dimension for the whole section.
  Dim = TargetVocab.begin()->second.size();
  if (Dim == 0)
    return createStringError(errc::illegal_byte_sequence,
                             "Dimension of '" + std::string(Key) +
                                 "' section of the vocabulary is zero");

  // Take each entry by its actual element type so that no per-entry
  // conversion (and copy of the embedding) is needed to evaluate the check.
  if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
                   [Dim](const auto &Entry) {
                     return Entry.second.size() == Dim;
                   }))
    return createStringError(
        errc::illegal_byte_sequence,
        "All vectors in the '" + std::string(Key) +
            "' section of the vocabulary are not of the same dimension");

  return Error::success();
}

// FIXME: Make this optional. We can avoid file reads
// by auto-generating a default vocabulary during the build time.
Error IR2VecVocabAnalysis::readVocabulary() {
Expand All @@ -259,32 +303,40 @@ Error IR2VecVocabAnalysis::readVocabulary() {
return createFileError(VocabFile, BufOrError.getError());

auto Content = BufOrError.get()->getBuffer();
json::Path::Root Path("");

Expected<json::Value> ParsedVocabValue = json::parse(Content);
if (!ParsedVocabValue)
return ParsedVocabValue.takeError();

bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
if (!Res)
return createStringError(errc::illegal_byte_sequence,
"Unable to parse the vocabulary");
ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab;
unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
if (auto Err = parseVocabSection("Opcodes", *ParsedVocabValue, OpcodeVocab,
Copy link
Member

@mtrofin mtrofin Jun 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes the format, best to also update the doc.

Also, this means the sections must all be present (in any order), even if empty, correct? SGTM, just something worth spelling out.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. Will put it in the doc.

OpcodeDim))
return Err;

if (Vocabulary.empty())
return createStringError(errc::illegal_byte_sequence,
"Vocabulary is empty");
if (auto Err =
parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
return Err;

unsigned Dim = Vocabulary.begin()->second.size();
if (Dim == 0)
if (auto Err =
parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
return Err;

if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
return createStringError(errc::illegal_byte_sequence,
"Dimension of vocabulary is zero");
"Vocabulary sections have different dimensions");

if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
[Dim](const std::pair<StringRef, Embedding> &Entry) {
return Entry.second.size() == Dim;
}))
return createStringError(
errc::illegal_byte_sequence,
"All vectors in the vocabulary are not of the same dimension");
auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) {
for (auto &Entry : Vocab)
Entry.second *= Weight;
};
scaleVocabSection(OpcodeVocab, OpcWeight);
scaleVocabSection(TypeVocab, TypeWeight);
scaleVocabSection(ArgVocab, ArgWeight);

Vocabulary.insert(OpcodeVocab.begin(), OpcodeVocab.end());
Vocabulary.insert(TypeVocab.begin(), TypeVocab.end());
Vocabulary.insert(ArgVocab.begin(), ArgVocab.end());

return Error::success();
}
Expand All @@ -304,7 +356,6 @@ void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
auto Ctx = &M.getContext();
// FIXME: Scale the vocabulary once. This would avoid scaling per use later.
// If vocabulary is already populated by the constructor, use it.
if (!Vocabulary.empty())
return IR2VecVocabResult(std::move(Vocabulary));
Expand All @@ -323,16 +374,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
}

//===----------------------------------------------------------------------===//
// IR2VecPrinterPass
// Printer Passes
//===----------------------------------------------------------------------===//

/// Writes \p Vec to the pass's output stream as a bracketed list on a single
/// line. Each element is formatted with "%.2f" and surrounded by one space on
/// either side; the closing bracket is followed by a newline.
void IR2VecPrinterPass::printVector(const Embedding &Vec) const {
  OS << " [";
  for (const auto &Value : Vec) {
    OS << " ";
    OS << format("%.2f", Value);
    OS << " ";
  }
  OS << "]\n";
}

PreservedAnalyses IR2VecPrinterPass::run(Module &M,
ModuleAnalysisManager &MAM) {
auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
Expand All @@ -353,15 +397,15 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,

OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
OS << "Function vector: ";
printVector(Emb->getFunctionVector());
Emb->getFunctionVector().print(OS);

OS << "Basic block vectors:\n";
const auto &BBMap = Emb->getBBVecMap();
for (const BasicBlock &BB : F) {
auto It = BBMap.find(&BB);
if (It != BBMap.end()) {
OS << "Basic block: " << BB.getName() << ":\n";
printVector(It->second);
It->second.print(OS);
}
}

Expand All @@ -373,10 +417,24 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
if (It != InstMap.end()) {
OS << "Instruction: ";
I.print(OS);
printVector(It->second);
It->second.print(OS);
}
}
}
}
return PreservedAnalyses::all();
}

/// Prints every entry of the IR2Vec vocabulary obtained for \p M to the
/// pass's output stream. Each entry is emitted as "Key: <key>: " followed by
/// the embedding as printed by Embedding::print.
///
/// \param M   Module whose IR2VecVocabAnalysis result is queried.
/// \param MAM Analysis manager used to obtain the vocabulary result.
/// \returns PreservedAnalyses::all().
PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
  // Bind to the analysis result rather than copying it; a copy would
  // duplicate the vocabulary only to print it.
  auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
  assert(VocabResult.isValid() && "IR2Vec Vocabulary is invalid");

  // Likewise, refer to the vocabulary in place instead of copying it.
  const auto &Vocab = VocabResult.getVocabulary();
  for (const auto &Entry : Vocab) {
    OS << "Key: " << Entry.first << ": ";
    Entry.second.print(OS);
  }

  return PreservedAnalyses::all();
}
Loading
Loading