Commit 0b0e31a

Added custom bulk scorer for hybrid query
Signed-off-by: Martin Gaievski <[email protected]>
1 parent d42efb1 commit 0b0e31a

18 files changed: +908, -408 lines

CHANGELOG.md

Lines changed: 2 additions & 1 deletion

@@ -7,9 +7,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 ### Features
 - Implement analyzer based neural sparse query ([#1088](https://github.com/opensearch-project/neural-search/pull/1088) [#1279](https://github.com/opensearch-project/neural-search/pull/1279))
-- [Semantic Field] Add semantic field mapper. ([#1225](https://github.com/opensearch-project/neural-search/pull/1225)).
+- [Semantic Field] Add semantic field mapper. ([#1225](https://github.com/opensearch-project/neural-search/pull/1225))
 
 ### Enhancements
+- [Performance Improvement] Add custom bulk scorer for hybrid query (2-3x faster) ([#1289](https://github.com/opensearch-project/neural-search/pull/1289))
 
 ### Bug Fixes
 - Add validations to prevent empty input_text_field and input_image_field in TextImageEmbeddingProcessor ([#1257](https://github.com/opensearch-project/neural-search/pull/1257))

src/main/java/org/opensearch/neuralsearch/query/HybridBulkScorer.java (new file)

Lines changed: 176 additions & 0 deletions

@@ -0,0 +1,176 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.query;

import lombok.Getter;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Bulk scorer for hybrid query
 */
public class HybridBulkScorer extends BulkScorer {
    private static final int SHIFT = 10;
    private static final int WINDOW_SIZE = 1 << SHIFT;
    private static final int MASK = WINDOW_SIZE - 1;

    private final long cost;
    private final Scorer[] scorers;
    @Getter
    private final HybridSubQueryScorer hybridSubQueryScorer;
    private final boolean needsScores;
    @Getter
    private final FixedBitSet matching;
    @Getter
    private final float[][] windowScores;
    private final HybridQueryDocIdStream hybridQueryDocIdStream;
    private final int maxDoc;

    public HybridBulkScorer(List<Scorer> scorers, boolean needsScores, int maxDoc) {
        long cost = 0;
        this.scorers = new Scorer[scorers.size()];
        for (int subQueryIndex = 0; subQueryIndex < scorers.size(); subQueryIndex++) {
            Scorer scorer = scorers.get(subQueryIndex);
            if (Objects.isNull(scorer)) {
                continue;
            }
            cost += scorer.iterator().cost();
            this.scorers[subQueryIndex] = scorer;
        }
        this.cost = cost;
        this.hybridSubQueryScorer = new HybridSubQueryScorer(scorers.size());
        this.needsScores = needsScores;
        this.matching = new FixedBitSet(WINDOW_SIZE);
        this.windowScores = new float[this.scorers.length][WINDOW_SIZE];
        this.maxDoc = maxDoc;
        this.hybridQueryDocIdStream = new HybridQueryDocIdStream(this);
    }

    @Override
    public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
        collector.setScorer(hybridSubQueryScorer);
        // making sure we are not going over the global limit defined by maxDoc
        max = Math.min(max, maxDoc);
        int[] docsIds = advance(min, scorers);
        while (allDocIdsUsed(docsIds, max) == false) {
            scoreWindow(collector, acceptDocs, min, max, docsIds);
        }
        return getNextDocIdCandidate(docsIds);
    }

    private void scoreWindow(LeafCollector collector, Bits acceptDocs, int min, int max, int[] docIds) throws IOException {
        // pick the lowest out of all not yet used doc ids
        int topDoc = -1;
        for (int docId : docIds) {
            if (docId < max) {
                topDoc = docId;
                break;
            }
        }

        final int windowBase = topDoc & ~MASK; // take the next match (at random) and find the window where it belongs
        final int windowMin = Math.max(min, windowBase);
        final int windowMax = Math.min(max, windowBase + WINDOW_SIZE);

        scoreWindowIntoBitSetWithSubqueryScorers(collector, acceptDocs, max, docIds, windowMin, windowMax, windowBase);
    }

    private void scoreWindowIntoBitSetWithSubqueryScorers(
        LeafCollector collector,
        Bits acceptDocs,
        int max,
        int[] docIds,
        int windowMin,
        int windowMax,
        int windowBase
    ) throws IOException {
        for (int subQueryIndex = 0; subQueryIndex < scorers.length; subQueryIndex++) {
            if (Objects.isNull(scorers[subQueryIndex]) || docIds[subQueryIndex] >= max) {
                continue;
            }
            DocIdSetIterator it = scorers[subQueryIndex].iterator();
            int doc = docIds[subQueryIndex];
            if (doc < windowMin) {
                doc = it.advance(windowMin);
            }
            while (doc < windowMax) {
                if (Objects.isNull(acceptDocs) || acceptDocs.get(doc)) {
                    int d = doc & MASK;
                    if (needsScores) {
                        float score = scorers[subQueryIndex].score();
                        // collect score only in case it's gt competitive score
                        if (score > hybridSubQueryScorer.getMinScores()[subQueryIndex]) {
                            matching.set(d);
                            windowScores[subQueryIndex][d] = score;
                        }
                    } else {
                        matching.set(d);
                    }
                }
                doc = it.nextDoc();
            }
            docIds[subQueryIndex] = doc;
        }

        hybridQueryDocIdStream.setBase(windowBase);
        collector.collect(hybridQueryDocIdStream);

        matching.clear();

        for (float[] windowScore : windowScores) {
            Arrays.fill(windowScore, 0.0f);
        }
    }

    private int[] advance(int min, Scorer[] scorers) throws IOException {
        int[] docIds = new int[scorers.length];
        for (int subQueryIndex = 0; subQueryIndex < scorers.length; subQueryIndex++) {
            if (Objects.isNull(scorers[subQueryIndex])) {
                docIds[subQueryIndex] = DocIdSetIterator.NO_MORE_DOCS;
                continue;
            }
            DocIdSetIterator it = scorers[subQueryIndex].iterator();
            int doc = it.docID();
            if (doc < min) {
                doc = it.advance(min);
            }
            docIds[subQueryIndex] = doc;
        }
        return docIds;
    }

    private boolean allDocIdsUsed(int[] docsIds, int max) {
        for (int docId : docsIds) {
            if (docId < max) {
                return false;
            }
        }
        return true;
    }

    private int getNextDocIdCandidate(final int[] docsIds) {
        int nextDoc = -1;
        for (int doc : docsIds) {
            if (doc != DocIdSetIterator.NO_MORE_DOCS) {
                nextDoc = Math.max(nextDoc, doc);
            }
        }
        return nextDoc == -1 ? DocIdSetIterator.NO_MORE_DOCS : nextDoc;
    }

    @Override
    public long cost() {
        return cost;
    }
}
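
Note on the window arithmetic above: with SHIFT = 10 the scorer walks a segment in fixed windows of 1024 doc ids, and the two masks split a global doc id into a window base plus an offset inside that window. The snippet below is a standalone illustration of that masking (not part of the commit); the constants simply mirror the ones declared in HybridBulkScorer.

public class WindowMathSketch {
    private static final int SHIFT = 10;
    private static final int WINDOW_SIZE = 1 << SHIFT; // 1024 docs per window
    private static final int MASK = WINDOW_SIZE - 1;   // low 10 bits select the offset inside a window

    public static void main(String[] args) {
        int doc = 5000;                  // an arbitrary global doc id
        int windowBase = doc & ~MASK;    // start of its 1024-doc window -> 4096
        int indexInWindow = doc & MASK;  // offset inside the window     -> 904
        // matches are recorded at indexInWindow; the global id is later restored as base | offset
        System.out.println(windowBase + " | " + indexInWindow + " = " + (windowBase | indexInWindow)); // 4096 | 904 = 5000
    }
}

This windowed approach is similar in spirit to Lucene's own BooleanScorer, which also scores disjunctions one fixed-size block of documents at a time.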

src/main/java/org/opensearch/neuralsearch/query/HybridQueryDocIdStream.java (new file)

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.query;

import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.lucene.search.CheckedIntConsumer;
import org.apache.lucene.search.DocIdStream;
import org.apache.lucene.util.FixedBitSet;

import java.io.IOException;
import java.util.Objects;

/**
 * This class is used to create a DocIdStream for HybridQuery
 */
@RequiredArgsConstructor
public class HybridQueryDocIdStream extends DocIdStream {
    private final HybridBulkScorer hybridBulkScorer;
    @Setter
    private int base;

    /**
     * Iterate over all doc ids and collect each doc id with leaf collector
     * @param consumer consumer that is called for each accepted doc id
     * @throws IOException in case of IO exception
     */
    @Override
    public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException {
        // bitset that represents matching documents, bit is set (1) if doc id is a match
        FixedBitSet matchingBitSet = hybridBulkScorer.getMatching();
        long[] bitArray = matchingBitSet.getBits();
        // iterate through each block of 64 documents (since each long contains 64 bits)
        for (int idx = 0; idx < bitArray.length; idx++) {
            long bits = bitArray[idx];
            while (bits != 0L) {
                // find position of the rightmost set bit (1)
                int numberOfTrailingZeros = Long.numberOfTrailingZeros(bits);
                // calculate actual document ID within the window
                // idx << 6 is equivalent to idx * 64 (block offset)
                // numberOfTrailingZeros gives position within the block
                final int docIndexInWindow = (idx << 6) | numberOfTrailingZeros;
                float[][] windowScores = hybridBulkScorer.getWindowScores();
                for (int subQueryIndex = 0; subQueryIndex < windowScores.length; subQueryIndex++) {
                    if (Objects.isNull(windowScores[subQueryIndex])) {
                        continue;
                    }
                    float scoreOfDocIdForSubQuery = windowScores[subQueryIndex][docIndexInWindow];
                    hybridBulkScorer.getHybridSubQueryScorer().getSubQueryScores()[subQueryIndex] = scoreOfDocIdForSubQuery;
                }
                // process the document with its base offset
                consumer.accept(base | docIndexInWindow);
                // reset scores after processing of one doc, this is required because scorer object is re-used
                hybridBulkScorer.getHybridSubQueryScorer().resetScores();
                // reset bit for this doc id to indicate that it has been consumed
                bits ^= 1L << numberOfTrailingZeros;
            }
        }
    }
}
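
The forEach loop above relies on a standard trick for walking the set bits of a 64-bit word: Long.numberOfTrailingZeros finds the lowest set bit and an XOR clears it, so each matching offset in the window is visited exactly once. A standalone sketch of just that loop (not part of the commit):

public class SetBitIterationSketch {
    public static void main(String[] args) {
        long bits = 0b1010010L; // offsets 1, 4 and 6 are "matches" in this 64-doc block
        while (bits != 0L) {
            int offset = Long.numberOfTrailingZeros(bits); // position of the lowest set bit
            System.out.println("matching offset in block: " + offset); // prints 1, then 4, then 6
            bits ^= 1L << offset; // clear the bit that was just consumed
        }
    }
}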

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

Lines changed: 5 additions & 6 deletions

@@ -14,7 +14,6 @@
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.TwoPhaseIterator;
-import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.PriorityQueue;
 import org.opensearch.neuralsearch.search.HybridDisiWrapper;

@@ -44,11 +43,11 @@ public class HybridQueryScorer extends Scorer {
     private final TwoPhase twoPhase;
     private final int numSubqueries;
 
-    public HybridQueryScorer(final Weight weight, final List<Scorer> subScorers) throws IOException {
-        this(weight, subScorers, ScoreMode.TOP_SCORES);
+    public HybridQueryScorer(final List<Scorer> subScorers) throws IOException {
+        this(subScorers, ScoreMode.TOP_SCORES);
     }
 
-    HybridQueryScorer(final Weight weight, final List<Scorer> subScorers, final ScoreMode scoreMode) throws IOException {
+    HybridQueryScorer(final List<Scorer> subScorers, final ScoreMode scoreMode) throws IOException {
         super();
         this.subScorers = Collections.unmodifiableList(subScorers);
         this.numSubqueries = subScorers.size();

@@ -75,7 +74,7 @@ public HybridQueryScorer(final Weight weight, final List<Scorer> subScorers) thr
                 sumMatchCost += w.matchCost * costWeight;
             }
         }
-        if (!hasApproximation) { // no sub scorer supports approximations
+        if (hasApproximation == false) { // no sub scorer supports approximations
            twoPhase = null;
        } else {
            final float matchCost = sumMatchCost / sumApproxCost;

@@ -284,7 +283,7 @@ public boolean matches() throws IOException {
             wrapper.next = verifiedMatches;
             verifiedMatches = wrapper;
 
-            if (!needsScores) {
+            if (needsScores == false) {
                 // we can stop here
                 return true;
             }

src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java

Lines changed: 18 additions & 5 deletions

@@ -15,6 +15,7 @@
 import lombok.Getter;
 import lombok.RequiredArgsConstructor;
 import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.BulkScorer;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Matches;

@@ -37,7 +38,6 @@ public final class HybridQueryWeight extends Weight {
     // The Weights for our subqueries, in 1-1 correspondence
     @Getter(AccessLevel.PACKAGE)
     private final List<Weight> weights;
-
     private final ScoreMode scoreMode;
 
     /**

@@ -95,7 +95,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
         if (scorerSuppliers.isEmpty()) {
             return null;
         }
-        return new HybridScorerSupplier(scorerSuppliers, this, scoreMode);
+        return new HybridScorerSupplier(scorerSuppliers, this, scoreMode, context);
     }
 
     private Void addScoreSupplier(Weight weight, HybridQueryExecutorCollector<LeafReaderContext, ScorerSupplier> collector) {

@@ -145,7 +145,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
                 max = Math.max(max, score);
                 subsOnMatch.add(e);
             } else {
-                if (!match) {
+                if (match == false) {
                     subsOnNoMatch.add(e);
                 }
                 subsOnMatch.add(e);

@@ -161,10 +161,23 @@
 
     @RequiredArgsConstructor
     static class HybridScorerSupplier extends ScorerSupplier {
+
+        @Override
+        public BulkScorer bulkScorer() throws IOException {
+            List<Scorer> scorers = new ArrayList<>();
+            for (Weight weight : weight.getWeights()) {
+                Scorer scorer = weight.scorer(context);
+                scorers.add(scorer);
+            }
+            return new HybridBulkScorer(scorers, scoreMode.needsScores(), context.reader().maxDoc());
+        }
+
         private long cost = -1;
+        @Getter
         private final List<ScorerSupplier> scorerSuppliers;
-        private final Weight weight;
+        private final HybridQueryWeight weight;
         private final ScoreMode scoreMode;
+        private final LeafReaderContext context;
 
         @Override
         public Scorer get(long leadCost) throws IOException {

@@ -176,7 +189,7 @@ public Scorer get(long leadCost) throws IOException {
                     tScorers.add(null);
                 }
             }
-            return new HybridQueryScorer(weight, tScorers, scoreMode);
+            return new HybridQueryScorer(tScorers, scoreMode);
         }
 
         @Override
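
With bulkScorer() now overridden on the ScorerSupplier, Lucene can score an entire leaf for a hybrid query through HybridBulkScorer instead of advancing a doc-by-doc disjunction. The fragment below is a rough sketch of how a caller drives it, not the actual IndexSearcher code; leafCollector and liveDocs are assumed to come from the surrounding search context.

static void scoreLeaf(ScorerSupplier scorerSupplier, LeafCollector leafCollector, Bits liveDocs) throws IOException {
    // for a hybrid query this returns the HybridBulkScorer built in the diff above
    BulkScorer bulkScorer = scorerSupplier.bulkScorer();
    // score the whole doc-id range of the leaf; HybridBulkScorer splits it into 1024-doc windows,
    // records matches and per-sub-query scores, and pushes them to the collector
    // through HybridQueryDocIdStream
    bulkScorer.score(leafCollector, liveDocs, 0, DocIdSetIterator.NO_MORE_DOCS);
}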