diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst index 28095447f6a5a..0b849f3382f63 100644 --- a/llvm/docs/MLGO.rst +++ b/llvm/docs/MLGO.rst @@ -482,14 +482,9 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance. // Assuming F is an llvm::Function& // For example, using IR2VecKind::Symbolic: - Expected> EmbOrErr = + std::unique_ptr Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary); - if (auto Err = EmbOrErr.takeError()) { - // Handle error in embedder creation - return; - } - std::unique_ptr Emb = std::move(*EmbOrErr); 3. **Compute and Access Embeddings**: Call ``getFunctionVector()`` to get the embedding for the function. diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 2a7a6edda70a8..06312562060aa 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -170,8 +170,8 @@ class Embedder { virtual ~Embedder() = default; /// Factory method to create an Embedder object. - static Expected> - create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary); + static std::unique_ptr create(IR2VecKind Mode, const Function &F, + const Vocab &Vocabulary); /// Returns a map containing instructions and the corresponding embeddings for /// the function F if it has been computed. If not, it computes the embeddings diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp index 29d3aaf46dc06..dd4eb7f0df053 100644 --- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp +++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp @@ -204,16 +204,12 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB, // We instantiate the IR2Vec embedder each time, as having an unique // pointer to the embedder as member of the class would make it // non-copyable. Instantiating the embedder in itself is not costly. - auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic, + auto Embedder = ir2vec::Embedder::create(IR2VecKind::Symbolic, *BB.getParent(), *IR2VecVocab); - if (Error Err = EmbOrErr.takeError()) { - handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { - BB.getContext().emitError("Error creating IR2Vec embeddings: " + - EI.message()); - }); + if (!Embedder) { + BB.getContext().emitError("Error creating IR2Vec embeddings"); return; } - auto Embedder = std::move(*EmbOrErr); const auto &BBEmbedding = Embedder->getBBVector(BB); // Subtract BBEmbedding from Function embedding if the direction is -1, // and add it if the direction is +1. diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 7ff7acebedf4e..27cc2a4109879 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -123,13 +123,14 @@ Embedder::Embedder(const Function &F, const Vocab &Vocabulary) Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {} -Expected> -Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) { +std::unique_ptr Embedder::create(IR2VecKind Mode, const Function &F, + const Vocab &Vocabulary) { switch (Mode) { case IR2VecKind::Symbolic: return std::make_unique(F, Vocabulary); } - return make_error("Unknown IR2VecKind", errc::invalid_argument); + llvm_unreachable("Unknown IR2Vec kind"); + return nullptr; } // FIXME: Currently lookups are string based. Use numeric Keys @@ -384,17 +385,13 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M, auto Vocab = IR2VecVocabResult.getVocabulary(); for (Function &F : M) { - Expected> EmbOrErr = + std::unique_ptr Emb = Embedder::create(IR2VecKind::Symbolic, F, Vocab); - if (auto Err = EmbOrErr.takeError()) { - handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { - OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n"; - }); + if (!Emb) { + OS << "Error creating IR2Vec embeddings \n"; continue; } - std::unique_ptr Emb = std::move(*EmbOrErr); - OS << "IR2Vec embeddings for function " << F.getName() << ":\n"; OS << "Function vector: "; Emb->getFunctionVector().print(OS); diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp index e50486bcbcb27..ca4f5d0f63026 100644 --- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp +++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp @@ -127,10 +127,9 @@ class FunctionPropertiesAnalysisTest : public testing::Test { } std::unique_ptr createEmbedder(const Function &F) { - auto EmbResult = - ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary); - EXPECT_TRUE(static_cast(EmbResult)); - return std::move(*EmbResult); + auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary); + EXPECT_TRUE(static_cast(Emb)); + return std::move(Emb); } }; diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index c3ed6e90cd8fc..05af55b59323b 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -216,10 +216,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) { FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false); Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); - auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); - EXPECT_TRUE(static_cast(Result)); - - auto *Emb = Result->get(); + auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V); EXPECT_NE(Emb, nullptr); } @@ -231,15 +228,16 @@ TEST(IR2VecTest, CreateInvalidMode) { FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false); Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); - // static_cast an invalid int to IR2VecKind +// static_cast an invalid int to IR2VecKind +#ifndef NDEBUG +#if GTEST_HAS_DEATH_TEST + EXPECT_DEATH(Embedder::create(static_cast(-1), *F, V), + "Unknown IR2Vec kind"); +#endif // GTEST_HAS_DEATH_TEST +#else auto Result = Embedder::create(static_cast(-1), *F, V); EXPECT_FALSE(static_cast(Result)); - - std::string ErrMsg; - llvm::handleAllErrors( - Result.takeError(), - [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); }); - EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos); +#endif // NDEBUG } TEST(IR2VecTest, LookupVocab) { @@ -298,10 +296,6 @@ class IR2VecTestFixture : public ::testing::Test { Instruction *AddInst = nullptr; Instruction *RetInst = nullptr; - float OriginalOpcWeight = ::OpcWeight; - float OriginalTypeWeight = ::TypeWeight; - float OriginalArgWeight = ::ArgWeight; - void SetUp() override { V = {{"add", {1.0, 2.0}}, {"integerTy", {0.25, 0.25}}, @@ -325,9 +319,8 @@ class IR2VecTestFixture : public ::testing::Test { }; TEST_F(IR2VecTestFixture, GetInstVecMap) { - auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); - ASSERT_TRUE(static_cast(Result)); - auto Emb = std::move(*Result); + auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast(Emb)); const auto &InstMap = Emb->getInstVecMap(); @@ -348,9 +341,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) { } TEST_F(IR2VecTestFixture, GetBBVecMap) { - auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); - ASSERT_TRUE(static_cast(Result)); - auto Emb = std::move(*Result); + auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast(Emb)); const auto &BBMap = Emb->getBBVecMap(); @@ -365,9 +357,8 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) { } TEST_F(IR2VecTestFixture, GetBBVector) { - auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); - ASSERT_TRUE(static_cast(Result)); - auto Emb = std::move(*Result); + auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast(Emb)); const auto &BBVec = Emb->getBBVector(*BB); @@ -377,9 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector) { } TEST_F(IR2VecTestFixture, GetFunctionVector) { - auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); - ASSERT_TRUE(static_cast(Result)); - auto Emb = std::move(*Result); + auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast(Emb)); const auto &FuncVec = Emb->getFunctionVector();