Skip to content

[IR2Vec] Simplifying creation of Embedder #143999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/svkeerthy/06-12-_ir2vec_scale_vocab
Choose a base branch
from

Conversation

svkeerthy
Copy link
Contributor

@svkeerthy svkeerthy commented Jun 12, 2025

This change simplifies the API by removing the error handling complexity.

  • Changed Embedder::create() to return std::unique_ptr<Embedder> directly instead of Expected<std::unique_ptr<Embedder>>
  • Updated documentation and tests to reflect the new API
  • Added death test for invalid IR2Vec kind in debug mode
  • In release mode, simply returns nullptr for invalid kinds instead of creating an error

(Tracking issue - #141817)

Copy link
Contributor Author

svkeerthy commented Jun 12, 2025

@svkeerthy svkeerthy changed the title Simplifying creation of Embedder [IR2Vec] Simplifying creation of Embedder Jun 12, 2025
Copy link
Contributor Author

@albertcohen

@svkeerthy svkeerthy marked this pull request as ready for review June 12, 2025 23:57
@llvmbot
Copy link
Member

llvmbot commented Jun 12, 2025

@llvm/pr-subscribers-mlgo

@llvm/pr-subscribers-llvm-analysis

Author: S. VenkataKeerthy (svkeerthy)

Changes

This change simplifies the API by removing the error handling complexity.

  • Changed Embedder::create() to return std::unique_ptr&lt;Embedder&gt; directly instead of Expected&lt;std::unique_ptr&lt;Embedder&gt;&gt;
  • Updated documentation and tests to reflect the new API
  • Added death test for invalid IR2Vec kind in debug mode
  • In release mode, simply returns nullptr for invalid kinds instead of creating an error

(Tracking issue - #141817)


Full diff: https://github.com/llvm/llvm-project/pull/143999.diff

6 Files Affected:

  • (modified) llvm/docs/MLGO.rst (+1-6)
  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+2-2)
  • (modified) llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp (+3-7)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+8-11)
  • (modified) llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp (+3-4)
  • (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+17-27)
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 4f8fb3f59ca19..e7bba9995b75b 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -479,14 +479,9 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
 
       // Assuming F is an llvm::Function&
       // For example, using IR2VecKind::Symbolic:
-      Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
+      std::unique_ptr<ir2vec::Embedder> Emb =
           ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
 
-      if (auto Err = EmbOrErr.takeError()) {
-        // Handle error in embedder creation
-        return;
-      }
-      std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
 
 3. **Compute and Access Embeddings**:
    Call ``getFunctionVector()`` to get the embedding for the function. 
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index f1aaf4cd2e013..6efa6eac56af9 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -170,8 +170,8 @@ class Embedder {
   virtual ~Embedder() = default;
 
   /// Factory method to create an Embedder object.
-  static Expected<std::unique_ptr<Embedder>>
-  create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
+  static std::unique_ptr<Embedder> create(IR2VecKind Mode, const Function &F,
+                                          const Vocab &Vocabulary);
 
   /// Returns a map containing instructions and the corresponding embeddings for
   /// the function F if it has been computed. If not, it computes the embeddings
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index 29d3aaf46dc06..dd4eb7f0df053 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -204,16 +204,12 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
     // We instantiate the IR2Vec embedder each time, as having an unique
     // pointer to the embedder as member of the class would make it
     // non-copyable. Instantiating the embedder in itself is not costly.
-    auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
+    auto Embedder = ir2vec::Embedder::create(IR2VecKind::Symbolic,
                                              *BB.getParent(), *IR2VecVocab);
-    if (Error Err = EmbOrErr.takeError()) {
-      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
-        BB.getContext().emitError("Error creating IR2Vec embeddings: " +
-                                  EI.message());
-      });
+    if (!Embedder) {
+      BB.getContext().emitError("Error creating IR2Vec embeddings");
       return;
     }
-    auto Embedder = std::move(*EmbOrErr);
     const auto &BBEmbedding = Embedder->getBBVector(BB);
     // Subtract BBEmbedding from Function embedding if the direction is -1,
     // and add it if the direction is +1.
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index f51d3252d6606..68026618449d8 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -123,13 +123,14 @@ Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
       Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
       TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
 
-Expected<std::unique_ptr<Embedder>>
-Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
+std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
+                                           const Vocab &Vocabulary) {
   switch (Mode) {
   case IR2VecKind::Symbolic:
     return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
   }
-  return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
+  llvm_unreachable("Unknown IR2Vec kind");
+  return nullptr;
 }
 
 // FIXME: Currently lookups are string based. Use numeric Keys
@@ -389,17 +390,13 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
 
   auto Vocab = IR2VecVocabResult.getVocabulary();
   for (Function &F : M) {
-    Expected<std::unique_ptr<Embedder>> EmbOrErr =
+    std::unique_ptr<Embedder> Emb =
         Embedder::create(IR2VecKind::Symbolic, F, Vocab);
-    if (auto Err = EmbOrErr.takeError()) {
-      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
-        OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
-      });
+    if (!Emb) {
+      OS << "Error creating IR2Vec embeddings \n";
       continue;
     }
 
-    std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
-
     OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
     OS << "Function vector: ";
     Emb->getFunctionVector().print(OS);
@@ -442,4 +439,4 @@ PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
   }
 
   return PreservedAnalyses::all();
-}
\ No newline at end of file
+}
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index e50486bcbcb27..ca4f5d0f63026 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -127,10 +127,9 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
   }
 
   std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
-    auto EmbResult =
-        ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
-    EXPECT_TRUE(static_cast<bool>(EmbResult));
-    return std::move(*EmbResult);
+    auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+    EXPECT_TRUE(static_cast<bool>(Emb));
+    return std::move(Emb);
   }
 };
 
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index c3ed6e90cd8fc..05af55b59323b 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -216,10 +216,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  EXPECT_TRUE(static_cast<bool>(Result));
-
-  auto *Emb = Result->get();
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
   EXPECT_NE(Emb, nullptr);
 }
 
@@ -231,15 +228,16 @@ TEST(IR2VecTest, CreateInvalidMode) {
   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
-  // static_cast an invalid int to IR2VecKind
+// static_cast an invalid int to IR2VecKind
+#ifndef NDEBUG
+#if GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(Embedder::create(static_cast<IR2VecKind>(-1), *F, V),
+               "Unknown IR2Vec kind");
+#endif // GTEST_HAS_DEATH_TEST
+#else
   auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
   EXPECT_FALSE(static_cast<bool>(Result));
-
-  std::string ErrMsg;
-  llvm::handleAllErrors(
-      Result.takeError(),
-      [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
-  EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
+#endif // NDEBUG
 }
 
 TEST(IR2VecTest, LookupVocab) {
@@ -298,10 +296,6 @@ class IR2VecTestFixture : public ::testing::Test {
   Instruction *AddInst = nullptr;
   Instruction *RetInst = nullptr;
 
-  float OriginalOpcWeight = ::OpcWeight;
-  float OriginalTypeWeight = ::TypeWeight;
-  float OriginalArgWeight = ::ArgWeight;
-
   void SetUp() override {
     V = {{"add", {1.0, 2.0}},
          {"integerTy", {0.25, 0.25}},
@@ -325,9 +319,8 @@ class IR2VecTestFixture : public ::testing::Test {
 };
 
 TEST_F(IR2VecTestFixture, GetInstVecMap) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &InstMap = Emb->getInstVecMap();
 
@@ -348,9 +341,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
 }
 
 TEST_F(IR2VecTestFixture, GetBBVecMap) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &BBMap = Emb->getBBVecMap();
 
@@ -365,9 +357,8 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
 }
 
 TEST_F(IR2VecTestFixture, GetBBVector) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &BBVec = Emb->getBBVector(*BB);
 
@@ -377,9 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
 }
 
 TEST_F(IR2VecTestFixture, GetFunctionVector) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &FuncVec = Emb->getFunctionVector();
 

@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-_ir2vec_scale_vocab branch from ac378a9 to 2657262 Compare June 13, 2025 00:01
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-simplifying_creation_of_embedder branch from cc133a1 to 1a051f1 Compare June 13, 2025 00:01
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-_ir2vec_scale_vocab branch from 2657262 to 730ab91 Compare June 13, 2025 17:46
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-simplifying_creation_of_embedder branch 2 times, most recently from 0d92141 to d71dd50 Compare June 13, 2025 18:18
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-_ir2vec_scale_vocab branch 2 times, most recently from d31d756 to 32d16aa Compare June 17, 2025 18:01
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-simplifying_creation_of_embedder branch from d71dd50 to ea224df Compare June 17, 2025 18:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants