Skip to content

[IR2Vec] Exposing Embedding as an data type wrapped around std::vector<double> #143197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 10, 2025

Conversation

svkeerthy
Copy link
Contributor

@svkeerthy svkeerthy commented Jun 6, 2025

Currently Embedding is std::vector<double>. This PR makes it a data type wrapped around std::vector<double> to overload basic arithmetic operators and expose comparison operations. It simplifies the usage here and in the passes where operations on Embedding would be performed.

(Tracking issue - #141817)

Copy link
Contributor Author

svkeerthy commented Jun 6, 2025

@svkeerthy svkeerthy changed the title Embedding [IR2Vec] Exposing Embedding as ADT wrapped around std::vector<double> Jun 6, 2025
@svkeerthy svkeerthy marked this pull request as ready for review June 6, 2025 20:15
@llvmbot
Copy link
Member

llvmbot commented Jun 6, 2025

@llvm/pr-subscribers-llvm-analysis

Author: S. VenkataKeerthy (svkeerthy)

Changes

Currently Embedding is std::vector&lt;double&gt;. This PR makes it an ADT wrapped around std::vector&lt;double&gt; to overload basic arithmetic operators and expose comparison operations. It simplifies the usage here and in the passes where operations on Embedding would be performed.


Full diff: https://github.com/llvm/llvm-project/pull/143197.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+18-8)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+39-19)
  • (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+54-14)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 9fd1b0ae8e248..930b13f079796 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -53,7 +53,24 @@ class raw_ostream;
 enum class IR2VecKind { Symbolic };
 
 namespace ir2vec {
-using Embedding = std::vector<double>;
+/// Embedding is a ADT that wraps std::vector<double>. It provides
+/// additional functionality for arithmetic and comparison operations.
+struct Embedding : public std::vector<double> {
+  using std::vector<double>::vector;
+
+  /// Arithmetic operators
+  Embedding &operator+=(const Embedding &RHS);
+  Embedding &operator-=(const Embedding &RHS);
+
+  /// Adds Src Embedding scaled by Factor with the called Embedding.
+  /// Called_Embedding += Src * Factor
+  void scaleAndAdd(const Embedding &Src, float Factor);
+
+  /// Returns true if the embedding is approximately equal to the RHS embedding
+  /// within the specified tolerance.
+  bool approximatelyEquals(const Embedding &RHS, double Tolerance = 1e-6) const;
+};
+
 using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
 using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 // FIXME: Current the keys are strings. This can be changed to
@@ -99,13 +116,6 @@ class Embedder {
   /// zero vector.
   Embedding lookupVocab(const std::string &Key) const;
 
-  /// Adds two vectors: Dst += Src
-  static void addVectors(Embedding &Dst, const Embedding &Src);
-
-  /// Adds Src vector scaled by Factor to Dst vector: Dst += Src * Factor
-  static void addScaledVector(Embedding &Dst, const Embedding &Src,
-                              float Factor);
-
 public:
   virtual ~Embedder() = default;
 
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 490db5fdcdf99..8ee8e5b0ff74e 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -55,6 +55,40 @@ static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
 
 AnalysisKey IR2VecVocabAnalysis::Key;
 
+// ==----------------------------------------------------------------------===//
+// Embedding
+//===----------------------------------------------------------------------===//
+
+Embedding &Embedding::operator+=(const Embedding &RHS) {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
+                 std::plus<double>());
+  return *this;
+}
+
+Embedding &Embedding::operator-=(const Embedding &RHS) {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
+                 std::minus<double>());
+  return *this;
+}
+
+void Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
+  assert(this->size() == Src.size() && "Vectors must have the same dimension");
+  for (size_t i = 0; i < this->size(); ++i) {
+    (*this)[i] += Src[i] * Factor;
+  }
+}
+
+bool Embedding::approximatelyEquals(const Embedding &RHS,
+                                    double Tolerance) const {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  for (size_t i = 0; i < this->size(); ++i)
+    if (std::abs((*this)[i] - RHS[i]) > Tolerance)
+      return false;
+  return true;
+}
+
 // ==----------------------------------------------------------------------===//
 // Embedder and its subclasses
 //===----------------------------------------------------------------------===//
@@ -73,20 +107,6 @@ Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
   return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
 }
 
-void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
-  assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
-  std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
-                 std::plus<double>());
-}
-
-void Embedder::addScaledVector(Embedding &Dst, const Embedding &Src,
-                               float Factor) {
-  assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
-  for (size_t i = 0; i < Dst.size(); ++i) {
-    Dst[i] += Src[i] * Factor;
-  }
-}
-
 // FIXME: Currently lookups are string based. Use numeric Keys
 // for efficiency
 Embedding Embedder::lookupVocab(const std::string &Key) const {
@@ -164,20 +184,20 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
     Embedding InstVector(Dimension, 0);
 
     const auto OpcVec = lookupVocab(I.getOpcodeName());
-    addScaledVector(InstVector, OpcVec, OpcWeight);
+    InstVector.scaleAndAdd(OpcVec, OpcWeight);
 
     // FIXME: Currently lookups are string based. Use numeric Keys
     // for efficiency.
     const auto Type = I.getType();
     const auto TypeVec = getTypeEmbedding(Type);
-    addScaledVector(InstVector, TypeVec, TypeWeight);
+    InstVector.scaleAndAdd(TypeVec, TypeWeight);
 
     for (const auto &Op : I.operands()) {
       const auto OperandVec = getOperandEmbedding(Op.get());
-      addScaledVector(InstVector, OperandVec, ArgWeight);
+      InstVector.scaleAndAdd(OperandVec, ArgWeight);
     }
     InstVecMap[&I] = InstVector;
-    addVectors(BBVector, InstVector);
+    BBVector += InstVector;
   }
   BBVecMap[&BB] = BBVector;
 }
@@ -187,7 +207,7 @@ void SymbolicEmbedder::computeEmbeddings() const {
     return;
   for (const auto &BB : F) {
     computeEmbeddings(BB);
-    addVectors(FuncVector, BBVecMap[&BB]);
+    FuncVector += BBVecMap[&BB];
   }
 }
 
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 9e47b2cd8bedd..7259a8a2fe20a 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -32,13 +32,6 @@ class TestableEmbedder : public Embedder {
   void computeEmbeddings() const override {}
   void computeEmbeddings(const BasicBlock &BB) const override {}
   using Embedder::lookupVocab;
-  static void addVectors(Embedding &Dst, const Embedding &Src) {
-    Embedder::addVectors(Dst, Src);
-  }
-  static void addScaledVector(Embedding &Dst, const Embedding &Src,
-                              float Factor) {
-    Embedder::addScaledVector(Dst, Src, Factor);
-  }
 };
 
 TEST(IR2VecTest, CreateSymbolicEmbedder) {
@@ -79,37 +72,83 @@ TEST(IR2VecTest, AddVectors) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {0.5, 1.5, -1.0};
 
-  TestableEmbedder::addVectors(E1, E2);
+  E1 += E2;
   EXPECT_THAT(E1, ElementsAre(1.5, 3.5, 2.0));
 
   // Check that E2 is unchanged
   EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
 }
 
+TEST(IR2VecTest, SubtractVectors) {
+  Embedding E1 = {1.0, 2.0, 3.0};
+  Embedding E2 = {0.5, 1.5, -1.0};
+
+  E1 -= E2;
+  EXPECT_THAT(E1, ElementsAre(0.5, 0.5, 4.0));
+
+  // Check that E2 is unchanged
+  EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
+}
+
 TEST(IR2VecTest, AddScaledVector) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {2.0, 0.5, -1.0};
 
-  TestableEmbedder::addScaledVector(E1, E2, 0.5f);
+  E1.scaleAndAdd(E2, 0.5f);
   EXPECT_THAT(E1, ElementsAre(2.0, 2.25, 2.5));
 
   // Check that E2 is unchanged
   EXPECT_THAT(E2, ElementsAre(2.0, 0.5, -1.0));
 }
 
+TEST(IR2VecTest, ApproximatelyEqual) {
+  Embedding E1 = {1.0, 2.0, 3.0};
+  Embedding E2 = {1.0000001, 2.0000001, 3.0000001};
+  EXPECT_TRUE(E1.approximatelyEquals(E2)); // Diff = 1e-7
+
+  Embedding E3 = {1.00002, 2.00002, 3.00002}; // Diff = 2e-5
+  EXPECT_FALSE(E1.approximatelyEquals(E3));
+  EXPECT_TRUE(E1.approximatelyEquals(E3, 3e-5));
+
+  Embedding E_clearly_within = {1.0000005, 2.0000005, 3.0000005}; // Diff = 5e-7
+  EXPECT_TRUE(E1.approximatelyEquals(E_clearly_within));
+
+  Embedding E_clearly_outside = {1.00001, 2.00001, 3.00001}; // Diff = 1e-5
+  EXPECT_FALSE(E1.approximatelyEquals(E_clearly_outside));
+
+  Embedding E4 = {1.0, 2.0, 3.5}; // Large diff
+  EXPECT_FALSE(E1.approximatelyEquals(E4, 0.01));
+
+  Embedding E5 = {1.0, 2.0, 3.0};
+  EXPECT_TRUE(E1.approximatelyEquals(E5, 0.0));
+  EXPECT_TRUE(E1.approximatelyEquals(E5));
+}
+
 #if GTEST_HAS_DEATH_TEST
 #ifndef NDEBUG
 TEST(IR2VecTest, MismatchedDimensionsAddVectors) {
   Embedding E1 = {1.0, 2.0};
   Embedding E2 = {1.0};
-  EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2),
-               "Vectors must have the same dimension");
+  EXPECT_DEATH(E1 += E2, "Vectors must have the same dimension");
+}
+
+TEST(IR2VecTest, MismatchedDimensionsSubtractVectors) {
+  Embedding E1 = {1.0, 2.0};
+  Embedding E2 = {1.0};
+  EXPECT_DEATH(E1 -= E2, "Vectors must have the same dimension");
 }
 
 TEST(IR2VecTest, MismatchedDimensionsAddScaledVector) {
   Embedding E1 = {1.0, 2.0};
   Embedding E2 = {1.0};
-  EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f),
+  EXPECT_DEATH(E1.scaleAndAdd(E2, 1.0f),
+               "Vectors must have the same dimension");
+}
+
+TEST(IR2VecTest, MismatchedDimensionsApproximatelyEqual) {
+  Embedding E1 = {1.0, 2.0};
+  Embedding E2 = {1.010};
+  EXPECT_DEATH(E1.approximatelyEquals(E2),
                "Vectors must have the same dimension");
 }
 #endif // NDEBUG
@@ -136,8 +175,9 @@ TEST(IR2VecTest, ZeroDimensionEmbedding) {
   Embedding E1;
   Embedding E2;
   // Should be no-op, but not crash
-  TestableEmbedder::addVectors(E1, E2);
-  TestableEmbedder::addScaledVector(E1, E2, 1.0f);
+  E1 += E2;
+  E1 -= E2;
+  E1.scaleAndAdd(E2, 1.0f);
   EXPECT_TRUE(E1.empty());
 }
 

@llvmbot
Copy link
Member

llvmbot commented Jun 6, 2025

@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

Changes

Currently Embedding is std::vector&lt;double&gt;. This PR makes it an ADT wrapped around std::vector&lt;double&gt; to overload basic arithmetic operators and expose comparison operations. It simplifies the usage here and in the passes where operations on Embedding would be performed.


Full diff: https://github.com/llvm/llvm-project/pull/143197.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+18-8)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+39-19)
  • (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+54-14)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 9fd1b0ae8e248..930b13f079796 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -53,7 +53,24 @@ class raw_ostream;
 enum class IR2VecKind { Symbolic };
 
 namespace ir2vec {
-using Embedding = std::vector<double>;
+/// Embedding is a ADT that wraps std::vector<double>. It provides
+/// additional functionality for arithmetic and comparison operations.
+struct Embedding : public std::vector<double> {
+  using std::vector<double>::vector;
+
+  /// Arithmetic operators
+  Embedding &operator+=(const Embedding &RHS);
+  Embedding &operator-=(const Embedding &RHS);
+
+  /// Adds Src Embedding scaled by Factor with the called Embedding.
+  /// Called_Embedding += Src * Factor
+  void scaleAndAdd(const Embedding &Src, float Factor);
+
+  /// Returns true if the embedding is approximately equal to the RHS embedding
+  /// within the specified tolerance.
+  bool approximatelyEquals(const Embedding &RHS, double Tolerance = 1e-6) const;
+};
+
 using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
 using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 // FIXME: Current the keys are strings. This can be changed to
@@ -99,13 +116,6 @@ class Embedder {
   /// zero vector.
   Embedding lookupVocab(const std::string &Key) const;
 
-  /// Adds two vectors: Dst += Src
-  static void addVectors(Embedding &Dst, const Embedding &Src);
-
-  /// Adds Src vector scaled by Factor to Dst vector: Dst += Src * Factor
-  static void addScaledVector(Embedding &Dst, const Embedding &Src,
-                              float Factor);
-
 public:
   virtual ~Embedder() = default;
 
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 490db5fdcdf99..8ee8e5b0ff74e 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -55,6 +55,40 @@ static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
 
 AnalysisKey IR2VecVocabAnalysis::Key;
 
+// ==----------------------------------------------------------------------===//
+// Embedding
+//===----------------------------------------------------------------------===//
+
+Embedding &Embedding::operator+=(const Embedding &RHS) {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
+                 std::plus<double>());
+  return *this;
+}
+
+Embedding &Embedding::operator-=(const Embedding &RHS) {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
+                 std::minus<double>());
+  return *this;
+}
+
+void Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
+  assert(this->size() == Src.size() && "Vectors must have the same dimension");
+  for (size_t i = 0; i < this->size(); ++i) {
+    (*this)[i] += Src[i] * Factor;
+  }
+}
+
+bool Embedding::approximatelyEquals(const Embedding &RHS,
+                                    double Tolerance) const {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  for (size_t i = 0; i < this->size(); ++i)
+    if (std::abs((*this)[i] - RHS[i]) > Tolerance)
+      return false;
+  return true;
+}
+
 // ==----------------------------------------------------------------------===//
 // Embedder and its subclasses
 //===----------------------------------------------------------------------===//
@@ -73,20 +107,6 @@ Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
   return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
 }
 
-void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
-  assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
-  std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
-                 std::plus<double>());
-}
-
-void Embedder::addScaledVector(Embedding &Dst, const Embedding &Src,
-                               float Factor) {
-  assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
-  for (size_t i = 0; i < Dst.size(); ++i) {
-    Dst[i] += Src[i] * Factor;
-  }
-}
-
 // FIXME: Currently lookups are string based. Use numeric Keys
 // for efficiency
 Embedding Embedder::lookupVocab(const std::string &Key) const {
@@ -164,20 +184,20 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
     Embedding InstVector(Dimension, 0);
 
     const auto OpcVec = lookupVocab(I.getOpcodeName());
-    addScaledVector(InstVector, OpcVec, OpcWeight);
+    InstVector.scaleAndAdd(OpcVec, OpcWeight);
 
     // FIXME: Currently lookups are string based. Use numeric Keys
     // for efficiency.
     const auto Type = I.getType();
     const auto TypeVec = getTypeEmbedding(Type);
-    addScaledVector(InstVector, TypeVec, TypeWeight);
+    InstVector.scaleAndAdd(TypeVec, TypeWeight);
 
     for (const auto &Op : I.operands()) {
       const auto OperandVec = getOperandEmbedding(Op.get());
-      addScaledVector(InstVector, OperandVec, ArgWeight);
+      InstVector.scaleAndAdd(OperandVec, ArgWeight);
     }
     InstVecMap[&I] = InstVector;
-    addVectors(BBVector, InstVector);
+    BBVector += InstVector;
   }
   BBVecMap[&BB] = BBVector;
 }
@@ -187,7 +207,7 @@ void SymbolicEmbedder::computeEmbeddings() const {
     return;
   for (const auto &BB : F) {
     computeEmbeddings(BB);
-    addVectors(FuncVector, BBVecMap[&BB]);
+    FuncVector += BBVecMap[&BB];
   }
 }
 
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 9e47b2cd8bedd..7259a8a2fe20a 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -32,13 +32,6 @@ class TestableEmbedder : public Embedder {
   void computeEmbeddings() const override {}
   void computeEmbeddings(const BasicBlock &BB) const override {}
   using Embedder::lookupVocab;
-  static void addVectors(Embedding &Dst, const Embedding &Src) {
-    Embedder::addVectors(Dst, Src);
-  }
-  static void addScaledVector(Embedding &Dst, const Embedding &Src,
-                              float Factor) {
-    Embedder::addScaledVector(Dst, Src, Factor);
-  }
 };
 
 TEST(IR2VecTest, CreateSymbolicEmbedder) {
@@ -79,37 +72,83 @@ TEST(IR2VecTest, AddVectors) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {0.5, 1.5, -1.0};
 
-  TestableEmbedder::addVectors(E1, E2);
+  E1 += E2;
   EXPECT_THAT(E1, ElementsAre(1.5, 3.5, 2.0));
 
   // Check that E2 is unchanged
   EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
 }
 
+TEST(IR2VecTest, SubtractVectors) {
+  Embedding E1 = {1.0, 2.0, 3.0};
+  Embedding E2 = {0.5, 1.5, -1.0};
+
+  E1 -= E2;
+  EXPECT_THAT(E1, ElementsAre(0.5, 0.5, 4.0));
+
+  // Check that E2 is unchanged
+  EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
+}
+
 TEST(IR2VecTest, AddScaledVector) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {2.0, 0.5, -1.0};
 
-  TestableEmbedder::addScaledVector(E1, E2, 0.5f);
+  E1.scaleAndAdd(E2, 0.5f);
   EXPECT_THAT(E1, ElementsAre(2.0, 2.25, 2.5));
 
   // Check that E2 is unchanged
   EXPECT_THAT(E2, ElementsAre(2.0, 0.5, -1.0));
 }
 
+TEST(IR2VecTest, ApproximatelyEqual) {
+  Embedding E1 = {1.0, 2.0, 3.0};
+  Embedding E2 = {1.0000001, 2.0000001, 3.0000001};
+  EXPECT_TRUE(E1.approximatelyEquals(E2)); // Diff = 1e-7
+
+  Embedding E3 = {1.00002, 2.00002, 3.00002}; // Diff = 2e-5
+  EXPECT_FALSE(E1.approximatelyEquals(E3));
+  EXPECT_TRUE(E1.approximatelyEquals(E3, 3e-5));
+
+  Embedding E_clearly_within = {1.0000005, 2.0000005, 3.0000005}; // Diff = 5e-7
+  EXPECT_TRUE(E1.approximatelyEquals(E_clearly_within));
+
+  Embedding E_clearly_outside = {1.00001, 2.00001, 3.00001}; // Diff = 1e-5
+  EXPECT_FALSE(E1.approximatelyEquals(E_clearly_outside));
+
+  Embedding E4 = {1.0, 2.0, 3.5}; // Large diff
+  EXPECT_FALSE(E1.approximatelyEquals(E4, 0.01));
+
+  Embedding E5 = {1.0, 2.0, 3.0};
+  EXPECT_TRUE(E1.approximatelyEquals(E5, 0.0));
+  EXPECT_TRUE(E1.approximatelyEquals(E5));
+}
+
 #if GTEST_HAS_DEATH_TEST
 #ifndef NDEBUG
 TEST(IR2VecTest, MismatchedDimensionsAddVectors) {
   Embedding E1 = {1.0, 2.0};
   Embedding E2 = {1.0};
-  EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2),
-               "Vectors must have the same dimension");
+  EXPECT_DEATH(E1 += E2, "Vectors must have the same dimension");
+}
+
+TEST(IR2VecTest, MismatchedDimensionsSubtractVectors) {
+  Embedding E1 = {1.0, 2.0};
+  Embedding E2 = {1.0};
+  EXPECT_DEATH(E1 -= E2, "Vectors must have the same dimension");
 }
 
 TEST(IR2VecTest, MismatchedDimensionsAddScaledVector) {
   Embedding E1 = {1.0, 2.0};
   Embedding E2 = {1.0};
-  EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f),
+  EXPECT_DEATH(E1.scaleAndAdd(E2, 1.0f),
+               "Vectors must have the same dimension");
+}
+
+TEST(IR2VecTest, MismatchedDimensionsApproximatelyEqual) {
+  Embedding E1 = {1.0, 2.0};
+  Embedding E2 = {1.010};
+  EXPECT_DEATH(E1.approximatelyEquals(E2),
                "Vectors must have the same dimension");
 }
 #endif // NDEBUG
@@ -136,8 +175,9 @@ TEST(IR2VecTest, ZeroDimensionEmbedding) {
   Embedding E1;
   Embedding E2;
   // Should be no-op, but not crash
-  TestableEmbedder::addVectors(E1, E2);
-  TestableEmbedder::addScaledVector(E1, E2, 1.0f);
+  E1 += E2;
+  E1 -= E2;
+  E1.scaleAndAdd(E2, 1.0f);
   EXPECT_TRUE(E1.empty());
 }
 

@svkeerthy svkeerthy changed the title [IR2Vec] Exposing Embedding as ADT wrapped around std::vector<double> [IR2Vec] Exposing Embedding as an ADT wrapped around std::vector<double> Jun 6, 2025
@svkeerthy svkeerthy requested a review from kazutakahirata June 6, 2025 21:09
Copy link
Member

@mtrofin mtrofin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall nice cleanup, I think the inheritance is to be avoided though.

@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-embedding branch from 8e5c13a to 602c2d3 Compare June 9, 2025 18:10
@svkeerthy svkeerthy requested a review from mtrofin June 9, 2025 18:11
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-embedding branch from 602c2d3 to 6817aa9 Compare June 9, 2025 20:42
Copy link
Contributor

ADT has a well defined meaning in LLVM. On reading the title I expected something new to be added to llvm/ADT but this is local to IR2Vec. Can you drop the usage of ADT in the title and code?

@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-embedding branch 2 times, most recently from 9c05884 to dfe59f2 Compare June 10, 2025 04:30
@svkeerthy svkeerthy changed the title [IR2Vec] Exposing Embedding as an ADT wrapped around std::vector<double> [IR2Vec] Exposing Embedding as an data type wrapped around std::vector<double> Jun 10, 2025
Copy link
Contributor Author

ADT has a well defined meaning in LLVM. On reading the title I expected something new to be added to llvm/ADT but this is local to IR2Vec. Can you drop the usage of ADT in the title and code?

Okay

Copy link
Contributor

@boomanaiden154 boomanaiden154 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One nit,

Copy link
Contributor

@snehasish snehasish left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-embedding branch from dfe59f2 to 7a33ef8 Compare June 10, 2025 18:14
Copy link
Contributor Author

svkeerthy commented Jun 10, 2025

Merge activity

  • Jun 10, 10:11 PM UTC: A user started a stack merge that includes this pull request via Graphite.
  • Jun 10, 10:12 PM UTC: @svkeerthy merged this pull request with Graphite.

@svkeerthy svkeerthy merged commit 32649e0 into main Jun 10, 2025
7 checks passed
@svkeerthy svkeerthy deleted the users/svkeerthy/06-06-embedding branch June 10, 2025 22:12
rorth pushed a commit to rorth/llvm-project that referenced this pull request Jun 11, 2025
…r<double> (llvm#143197)

Currently `Embedding` is `std::vector<double>`. This PR makes it a data type wrapped around `std::vector<double>` to overload basic arithmetic operators and expose comparison operations. It _simplifies_ the usage here and in the passes where operations on `Embedding` would be performed.

(Tracking issue - llvm#141817)
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
…r<double> (llvm#143197)

Currently `Embedding` is `std::vector<double>`. This PR makes it a data type wrapped around `std::vector<double>` to overload basic arithmetic operators and expose comparison operations. It _simplifies_ the usage here and in the passes where operations on `Embedding` would be performed.

(Tracking issue - llvm#141817)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants