-
Notifications
You must be signed in to change notification settings - Fork 14k
[IR2Vec] Simplifying creation of Embedder #143999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/svkeerthy/06-12-_ir2vec_scale_vocab
Are you sure you want to change the base?
[IR2Vec] Simplifying creation of Embedder #143999
Conversation
Warning This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.
This stack of pull requests is managed by Graphite. Learn more about stacking. |
@llvm/pr-subscribers-mlgo @llvm/pr-subscribers-llvm-analysis Author: S. VenkataKeerthy (svkeerthy) Changes: This change simplifies the API by removing the error handling complexity.
(Tracking issue - #141817) Full diff: https://github.com/llvm/llvm-project/pull/143999.diff 6 Files Affected:
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 4f8fb3f59ca19..e7bba9995b75b 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -479,14 +479,9 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
// Assuming F is an llvm::Function&
// For example, using IR2VecKind::Symbolic:
- Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
+ std::unique_ptr<ir2vec::Embedder> Emb =
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
- if (auto Err = EmbOrErr.takeError()) {
- // Handle error in embedder creation
- return;
- }
- std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
3. **Compute and Access Embeddings**:
Call ``getFunctionVector()`` to get the embedding for the function.
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index f1aaf4cd2e013..6efa6eac56af9 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -170,8 +170,8 @@ class Embedder {
virtual ~Embedder() = default;
/// Factory method to create an Embedder object.
- static Expected<std::unique_ptr<Embedder>>
- create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
+ static std::unique_ptr<Embedder> create(IR2VecKind Mode, const Function &F,
+ const Vocab &Vocabulary);
/// Returns a map containing instructions and the corresponding embeddings for
/// the function F if it has been computed. If not, it computes the embeddings
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index 29d3aaf46dc06..dd4eb7f0df053 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -204,16 +204,12 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
// We instantiate the IR2Vec embedder each time, as having an unique
// pointer to the embedder as member of the class would make it
// non-copyable. Instantiating the embedder in itself is not costly.
- auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
+ auto Embedder = ir2vec::Embedder::create(IR2VecKind::Symbolic,
*BB.getParent(), *IR2VecVocab);
- if (Error Err = EmbOrErr.takeError()) {
- handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
- BB.getContext().emitError("Error creating IR2Vec embeddings: " +
- EI.message());
- });
+ if (!Embedder) {
+ BB.getContext().emitError("Error creating IR2Vec embeddings");
return;
}
- auto Embedder = std::move(*EmbOrErr);
const auto &BBEmbedding = Embedder->getBBVector(BB);
// Subtract BBEmbedding from Function embedding if the direction is -1,
// and add it if the direction is +1.
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index f51d3252d6606..68026618449d8 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -123,13 +123,14 @@ Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
-Expected<std::unique_ptr<Embedder>>
-Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
+std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
+ const Vocab &Vocabulary) {
switch (Mode) {
case IR2VecKind::Symbolic:
return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
}
- return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
+ llvm_unreachable("Unknown IR2Vec kind");
+ return nullptr;
}
// FIXME: Currently lookups are string based. Use numeric Keys
@@ -389,17 +390,13 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
auto Vocab = IR2VecVocabResult.getVocabulary();
for (Function &F : M) {
- Expected<std::unique_ptr<Embedder>> EmbOrErr =
+ std::unique_ptr<Embedder> Emb =
Embedder::create(IR2VecKind::Symbolic, F, Vocab);
- if (auto Err = EmbOrErr.takeError()) {
- handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
- OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
- });
+ if (!Emb) {
+ OS << "Error creating IR2Vec embeddings \n";
continue;
}
- std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
-
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
OS << "Function vector: ";
Emb->getFunctionVector().print(OS);
@@ -442,4 +439,4 @@ PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
}
return PreservedAnalyses::all();
-}
\ No newline at end of file
+}
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index e50486bcbcb27..ca4f5d0f63026 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -127,10 +127,9 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
}
std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
- auto EmbResult =
- ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
- EXPECT_TRUE(static_cast<bool>(EmbResult));
- return std::move(*EmbResult);
+ auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+ EXPECT_TRUE(static_cast<bool>(Emb));
+ return std::move(Emb);
}
};
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index c3ed6e90cd8fc..05af55b59323b 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -216,10 +216,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- EXPECT_TRUE(static_cast<bool>(Result));
-
- auto *Emb = Result->get();
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_NE(Emb, nullptr);
}
@@ -231,15 +228,16 @@ TEST(IR2VecTest, CreateInvalidMode) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
- // static_cast an invalid int to IR2VecKind
+// static_cast an invalid int to IR2VecKind
+#ifndef NDEBUG
+#if GTEST_HAS_DEATH_TEST
+ EXPECT_DEATH(Embedder::create(static_cast<IR2VecKind>(-1), *F, V),
+ "Unknown IR2Vec kind");
+#endif // GTEST_HAS_DEATH_TEST
+#else
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
EXPECT_FALSE(static_cast<bool>(Result));
-
- std::string ErrMsg;
- llvm::handleAllErrors(
- Result.takeError(),
- [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
- EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
+#endif // NDEBUG
}
TEST(IR2VecTest, LookupVocab) {
@@ -298,10 +296,6 @@ class IR2VecTestFixture : public ::testing::Test {
Instruction *AddInst = nullptr;
Instruction *RetInst = nullptr;
- float OriginalOpcWeight = ::OpcWeight;
- float OriginalTypeWeight = ::TypeWeight;
- float OriginalArgWeight = ::ArgWeight;
-
void SetUp() override {
V = {{"add", {1.0, 2.0}},
{"integerTy", {0.25, 0.25}},
@@ -325,9 +319,8 @@ class IR2VecTestFixture : public ::testing::Test {
};
TEST_F(IR2VecTestFixture, GetInstVecMap) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &InstMap = Emb->getInstVecMap();
@@ -348,9 +341,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
}
TEST_F(IR2VecTestFixture, GetBBVecMap) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBMap = Emb->getBBVecMap();
@@ -365,9 +357,8 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
}
TEST_F(IR2VecTestFixture, GetBBVector) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBVec = Emb->getBBVector(*BB);
@@ -377,9 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
}
TEST_F(IR2VecTestFixture, GetFunctionVector) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &FuncVec = Emb->getFunctionVector();
|
ac378a9
to
2657262
Compare
cc133a1
to
1a051f1
Compare
2657262
to
730ab91
Compare
0d92141
to
d71dd50
Compare
d31d756
to
32d16aa
Compare
d71dd50
to
ea224df
Compare
This change simplifies the API by removing the error handling complexity.
It changes Embedder::create() to return std::unique_ptr<Embedder> directly instead of Expected<std::unique_ptr<Embedder>>.
(Tracking issue - #141817)