Skip to content

Commit df8f54d

Browse files
martin-gaievskiYeonghyeonKO
authored andcommitted
[Performance Improvement] Add custom bulk scorer for hybrid query (2-3x faster) (opensearch-project#1289)
Signed-off-by: Martin Gaievski <[email protected]> Signed-off-by: yeonghyeonKo <[email protected]>
1 parent c9d06bc commit df8f54d

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed

src/main/java/org/opensearch/neuralsearch/query/HybridQueryDocIdStream.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package org.opensearch.neuralsearch.query;
66

77
import lombok.RequiredArgsConstructor;
8+
import lombok.Setter;
89
import org.apache.lucene.search.CheckedIntConsumer;
910
import org.apache.lucene.search.DocIdStream;
1011
import org.apache.lucene.util.FixedBitSet;

src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,113 @@ public void testWithRandomDocuments_whenOneSubScorer_thenReturnSuccessfully() {
5555
testWithQuery(docs, scores, hybridQueryScorer);
5656
}
5757

58+
@SneakyThrows
59+
public void testWithRandomDocumentsAndHybridScores_whenMultipleScorers_thenReturnSuccessfully() {
60+
int maxDocId1 = TestUtil.nextInt(random(), 10, 10_000);
61+
Pair<int[], float[]> docsAndScores1 = generateDocuments(maxDocId1);
62+
int[] docs1 = docsAndScores1.getLeft();
63+
float[] scores1 = docsAndScores1.getRight();
64+
int maxDocId2 = TestUtil.nextInt(random(), 10, 10_000);
65+
Pair<int[], float[]> docsAndScores2 = generateDocuments(maxDocId2);
66+
int[] docs2 = docsAndScores2.getLeft();
67+
float[] scores2 = docsAndScores2.getRight();
68+
69+
Weight weight = mock(Weight.class);
70+
71+
HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
72+
Arrays.asList(
73+
scorer(docs1, scores1, fakeWeight(new MatchAllDocsQuery())),
74+
scorer(docs2, scores2, fakeWeight(new MatchNoDocsQuery()))
75+
)
76+
);
77+
int doc = -1;
78+
int numOfActualDocs = 0;
79+
Set<Integer> uniqueDocs1 = Arrays.stream(docs1).boxed().collect(Collectors.toSet());
80+
Set<Integer> uniqueDocs2 = Arrays.stream(docs2).boxed().collect(Collectors.toSet());
81+
while (doc != NO_MORE_DOCS) {
82+
doc = hybridQueryScorer.iterator().nextDoc();
83+
if (doc == DocIdSetIterator.NO_MORE_DOCS) {
84+
continue;
85+
}
86+
float[] actualTotalScores = hybridQueryScorer.hybridScores();
87+
float actualTotalScore = 0.0f;
88+
for (float score : actualTotalScores) {
89+
actualTotalScore += score;
90+
}
91+
float expectedScore = 0.0f;
92+
if (uniqueDocs1.contains(doc)) {
93+
int idx = Arrays.binarySearch(docs1, doc);
94+
expectedScore += scores1[idx];
95+
}
96+
if (uniqueDocs2.contains(doc)) {
97+
int idx = Arrays.binarySearch(docs2, doc);
98+
expectedScore += scores2[idx];
99+
}
100+
assertEquals(expectedScore, actualTotalScore, DELTA_FOR_SCORE_ASSERTION);
101+
numOfActualDocs++;
102+
}
103+
104+
int totalUniqueCount = uniqueDocs1.size();
105+
for (int n : uniqueDocs2) {
106+
if (!uniqueDocs1.contains(n)) {
107+
totalUniqueCount++;
108+
}
109+
}
110+
assertEquals(totalUniqueCount, numOfActualDocs);
111+
}
112+
113+
@SneakyThrows
114+
public void testWithRandomDocumentsAndCombinedScore_whenMultipleScorers_thenReturnSuccessfully() {
115+
int maxDocId1 = TestUtil.nextInt(random(), 10, 10_000);
116+
Pair<int[], float[]> docsAndScores1 = generateDocuments(maxDocId1);
117+
int[] docs1 = docsAndScores1.getLeft();
118+
float[] scores1 = docsAndScores1.getRight();
119+
int maxDocId2 = TestUtil.nextInt(random(), 10, 10_000);
120+
Pair<int[], float[]> docsAndScores2 = generateDocuments(maxDocId2);
121+
int[] docs2 = docsAndScores2.getLeft();
122+
float[] scores2 = docsAndScores2.getRight();
123+
124+
HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(
125+
Arrays.asList(
126+
scorer(docs1, scores1, fakeWeight(new MatchAllDocsQuery())),
127+
scorer(docs2, scores2, fakeWeight(new MatchNoDocsQuery()))
128+
)
129+
);
130+
int doc = -1;
131+
int numOfActualDocs = 0;
132+
Set<Integer> uniqueDocs1 = Arrays.stream(docs1).boxed().collect(Collectors.toSet());
133+
Set<Integer> uniqueDocs2 = Arrays.stream(docs2).boxed().collect(Collectors.toSet());
134+
while (doc != NO_MORE_DOCS) {
135+
doc = hybridQueryScorer.iterator().nextDoc();
136+
if (doc == DocIdSetIterator.NO_MORE_DOCS) {
137+
continue;
138+
}
139+
float expectedScore = 0.0f;
140+
if (uniqueDocs1.contains(doc)) {
141+
int idx = Arrays.binarySearch(docs1, doc);
142+
expectedScore += scores1[idx];
143+
}
144+
if (uniqueDocs2.contains(doc)) {
145+
int idx = Arrays.binarySearch(docs2, doc);
146+
expectedScore += scores2[idx];
147+
}
148+
float hybridScore = 0.0f;
149+
for (float score : hybridQueryScorer.hybridScores()) {
150+
hybridScore += score;
151+
}
152+
assertEquals(expectedScore, hybridScore, DELTA_FOR_SCORE_ASSERTION);
153+
numOfActualDocs++;
154+
}
155+
156+
int totalUniqueCount = uniqueDocs1.size();
157+
for (int n : uniqueDocs2) {
158+
if (!uniqueDocs1.contains(n)) {
159+
totalUniqueCount++;
160+
}
161+
}
162+
assertEquals(totalUniqueCount, numOfActualDocs);
163+
}
164+
58165
@SneakyThrows
59166
public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenReturnSuccessfully() {
60167
int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
@@ -92,6 +199,11 @@ public void testMaxScore_whenMultipleScorers_thenSuccessful() {
92199

93200
maxScore = hybridQueryScorerWithSomeNullSubScorers.getMaxScore(Integer.MAX_VALUE);
94201
assertTrue(maxScore > 0.0f);
202+
203+
HybridQueryScorer hybridQueryScorerWithAllNullSubScorers = new HybridQueryScorer(Arrays.asList(null, null));
204+
205+
maxScore = hybridQueryScorerWithAllNullSubScorers.getMaxScore(Integer.MAX_VALUE);
206+
assertEquals(0.0f, maxScore, 0.0f);
95207
}
96208

97209
@SneakyThrows
@@ -402,6 +514,14 @@ public void testScore_whenMultipleQueries_thenCombineScores() {
402514
assertEquals("Combined score should be sum of bool and neural scores", 1.6f, combinedScore, DELTA_FOR_SCORE_ASSERTION);
403515
}
404516

517+
@SneakyThrows
518+
public void testScore_whenEmptySubScorers_thenReturnZero() {
519+
HybridQueryScorer hybridScorer = new HybridQueryScorer(Collections.emptyList());
520+
float score = hybridScorer.score(null);
521+
522+
assertEquals("Score should be 0.0 for null wrapper", 0.0f, score, DELTA_FOR_SCORE_ASSERTION);
523+
}
524+
405525
@SneakyThrows
406526
public void testInitialization_whenValidScorer_thenSuccessful() {
407527
// Create scorer with iterator
@@ -435,6 +555,46 @@ public void testInitialization_whenValidScorer_thenSuccessful() {
435555
assertEquals("Cost should be 1", 1L, wrapper.cost);
436556
}
437557

558+
@SneakyThrows
559+
public void testHybridScores_withTwoPhaseIterator() throws IOException {
560+
// Create weight and scorers
561+
Scorer scorer1 = mock(Scorer.class);
562+
TwoPhaseIterator twoPhaseIterator = mock(TwoPhaseIterator.class);
563+
DocIdSetIterator approximation = mock(DocIdSetIterator.class);
564+
565+
// Setup two-phase behavior
566+
when(scorer1.twoPhaseIterator()).thenReturn(twoPhaseIterator);
567+
when(twoPhaseIterator.approximation()).thenReturn(approximation);
568+
when(scorer1.iterator()).thenReturn(approximation);
569+
when(approximation.cost()).thenReturn(1L);
570+
571+
// Setup DocIdSetIterator behavior - use different docIDs
572+
when(approximation.docID()).thenReturn(5); // approximation at doc 5
573+
when(scorer1.docID()).thenReturn(5); // scorer at same doc
574+
when(scorer1.score()).thenReturn(2.0f);
575+
576+
// matches() always returns false - document should never match
577+
when(twoPhaseIterator.matches()).thenReturn(false);
578+
579+
// Create HybridQueryScorer with two-phase iterator
580+
List<Scorer> subScorers = Collections.singletonList(scorer1);
581+
HybridQueryScorer hybridScorer = new HybridQueryScorer(subScorers);
582+
583+
// Call matches() first to establish non-matching state
584+
TwoPhaseIterator hybridTwoPhase = hybridScorer.twoPhaseIterator();
585+
assertNotNull("Should have two phase iterator", hybridTwoPhase);
586+
assertFalse("Document should not match", hybridTwoPhase.matches());
587+
588+
// Get scores - should be zero since document doesn't match
589+
float[] scores = hybridScorer.hybridScores();
590+
assertEquals("Should have one score entry", 1, scores.length);
591+
assertEquals("Score should be 0 for non-matching document", 0.0f, scores[0], DELTA_FOR_SCORE_ASSERTION);
592+
593+
// Verify score() was never called since document didn't match
594+
verify(scorer1, never()).score();
595+
verify(twoPhaseIterator, times(1)).matches();
596+
}
597+
438598
@SneakyThrows
439599
public void testTwoPhaseIterator_withNestedTwoPhaseQuery() {
440600
// Create a scorer that uses two-phase iteration

0 commit comments

Comments
 (0)