Skip to content

Commit 64fc42a

Browse files
Added custom bulk scorer for hybrid query
Signed-off-by: Martin Gaievski <[email protected]>
1 parent d42efb1 commit 64fc42a

16 files changed

+864
-355
lines changed
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.query;
6+
7+
import lombok.Getter;
8+
import org.apache.lucene.search.BulkScorer;
9+
import org.apache.lucene.search.DocIdSetIterator;
10+
import org.apache.lucene.search.LeafCollector;
11+
import org.apache.lucene.search.Scorer;
12+
import org.apache.lucene.util.Bits;
13+
import org.apache.lucene.util.FixedBitSet;
14+
15+
import java.io.IOException;
16+
import java.util.Arrays;
17+
import java.util.List;
18+
import java.util.Objects;
19+
20+
/**
21+
* Bulk scorer for hybrid query
22+
*/
23+
public class HybridBulkScorer extends BulkScorer {
24+
private static final int SHIFT = 10;
25+
private static final int WINDOW_SIZE = 1 << SHIFT;
26+
private static final int MASK = WINDOW_SIZE - 1;
27+
28+
private final long cost;
29+
private final Scorer[] disiWrappers;
30+
@Getter
31+
private final HybridSubQueryScorer hybridSubQueryScorer;
32+
private final boolean needsScores;
33+
@Getter
34+
private final FixedBitSet matching;
35+
@Getter
36+
private final float[][] windowScores;
37+
private final HybridQueryDocIdStream hybridQueryDocIdStream;
38+
private final int maxDoc;
39+
40+
public HybridBulkScorer(List<Scorer> scorers, boolean needsScores, int maxDoc) {
41+
long cost = 0;
42+
this.disiWrappers = new Scorer[scorers.size()];
43+
for (int subQueryIndex = 0; subQueryIndex < scorers.size(); subQueryIndex++) {
44+
Scorer scorer = scorers.get(subQueryIndex);
45+
if (Objects.isNull(scorer)) {
46+
continue;
47+
}
48+
cost += scorer.iterator().cost();
49+
disiWrappers[subQueryIndex] = scorer;
50+
}
51+
this.cost = cost;
52+
this.hybridSubQueryScorer = new HybridSubQueryScorer(scorers.size());
53+
this.needsScores = needsScores;
54+
this.matching = new FixedBitSet(WINDOW_SIZE);
55+
this.windowScores = new float[disiWrappers.length][WINDOW_SIZE];
56+
this.maxDoc = maxDoc;
57+
this.hybridQueryDocIdStream = new HybridQueryDocIdStream(this);
58+
}
59+
60+
@Override
61+
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
62+
collector.setScorer(hybridSubQueryScorer);
63+
// making sure we are not going over the global limit defined by maxDoc
64+
max = Math.min(max, maxDoc);
65+
int[] docsIds = advance(min, disiWrappers);
66+
while (allDocIdsUsed(docsIds, max) == false) {
67+
scoreWindow(collector, acceptDocs, min, max, docsIds);
68+
}
69+
return getNextDocIdCandidate(docsIds);
70+
}
71+
72+
private void scoreWindow(LeafCollector collector, Bits acceptDocs, int min, int max, int[] docsIds) throws IOException {
73+
// pick the lowest out of all not yet used doc ids
74+
int topDoc = -1;
75+
for (int docsId : docsIds) {
76+
if (docsId < max) {
77+
topDoc = docsId;
78+
break;
79+
}
80+
}
81+
82+
final int windowBase = topDoc & ~MASK; // take the next match (at random) and find the window where it belongs
83+
final int windowMin = Math.max(min, windowBase);
84+
final int windowMax = Math.min(max, windowBase + WINDOW_SIZE);
85+
86+
scoreWindowIntoBitSetWithSubqueryScorers(collector, acceptDocs, max, docsIds, windowMin, windowMax, windowBase);
87+
}
88+
89+
private void scoreWindowIntoBitSetWithSubqueryScorers(
90+
LeafCollector collector,
91+
Bits acceptDocs,
92+
int max,
93+
int[] docsIds,
94+
int windowMin,
95+
int windowMax,
96+
int windowBase
97+
) throws IOException {
98+
for (int i = 0; i < disiWrappers.length; i++) {
99+
if (disiWrappers[i] == null || docsIds[i] >= max) {
100+
continue;
101+
}
102+
DocIdSetIterator it = disiWrappers[i].iterator();
103+
int doc = docsIds[i];
104+
if (doc < windowMin) {
105+
doc = it.advance(windowMin);
106+
}
107+
for (; doc < windowMax; doc = it.nextDoc()) {
108+
if (Objects.isNull(acceptDocs) || acceptDocs.get(doc)) {
109+
int d = doc & MASK;
110+
if (needsScores) {
111+
float score = disiWrappers[i].score();
112+
if (score > hybridSubQueryScorer.getMinScores()[i]) {
113+
matching.set(d);
114+
windowScores[i][d] = score;
115+
}
116+
} else {
117+
matching.set(d);
118+
}
119+
}
120+
}
121+
docsIds[i] = doc;
122+
}
123+
124+
hybridQueryDocIdStream.setBase(windowBase);
125+
collector.collect(hybridQueryDocIdStream);
126+
127+
matching.clear();
128+
129+
for (float[] windowScore : windowScores) {
130+
Arrays.fill(windowScore, 0.0f);
131+
}
132+
}
133+
134+
private int[] advance(int min, Scorer[] scorers) throws IOException {
135+
int[] docIds = new int[scorers.length];
136+
for (int i = 0; i < scorers.length; i++) {
137+
if (scorers[i] == null) {
138+
docIds[i] = DocIdSetIterator.NO_MORE_DOCS;
139+
continue;
140+
}
141+
DocIdSetIterator it = scorers[i].iterator();
142+
int doc = it.docID();
143+
if (doc < min) {
144+
doc = it.advance(min);
145+
}
146+
docIds[i] = doc;
147+
}
148+
return docIds;
149+
}
150+
151+
private boolean allDocIdsUsed(int[] docsIds, int max) {
152+
for (int docId : docsIds) {
153+
if (docId < max) {
154+
return false;
155+
}
156+
}
157+
return true;
158+
}
159+
160+
private int getNextDocIdCandidate(final int[] docsIds) {
161+
int nextDoc = -1;
162+
for (int doc : docsIds) {
163+
if (doc != DocIdSetIterator.NO_MORE_DOCS) {
164+
nextDoc = Math.max(nextDoc, doc);
165+
}
166+
}
167+
return nextDoc == -1 ? DocIdSetIterator.NO_MORE_DOCS : nextDoc;
168+
}
169+
170+
@Override
171+
public long cost() {
172+
return cost;
173+
}
174+
175+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.query;
6+
7+
import lombok.RequiredArgsConstructor;
8+
import lombok.Setter;
9+
import org.apache.lucene.search.CheckedIntConsumer;
10+
import org.apache.lucene.search.DocIdStream;
11+
import org.apache.lucene.util.FixedBitSet;
12+
13+
import java.io.IOException;
14+
import java.util.Objects;
15+
16+
/**
17+
* This class is used to create a DocIdStream for HybridQuery
18+
*/
19+
@RequiredArgsConstructor
20+
public class HybridQueryDocIdStream extends DocIdStream {
21+
private final HybridBulkScorer hybridBulkScorer;
22+
@Setter
23+
private int base;
24+
25+
/**
26+
* Iterate over all doc ids and collect each doc id with leaf collector
27+
* @param consumer consumer that is called for each accepted doc id
28+
* @throws IOException in case of IO exception
29+
*/
30+
@Override
31+
public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException {
32+
// bitset that represents matching documents, bit is set (1) if doc id is a match
33+
FixedBitSet matchingBitSet = hybridBulkScorer.getMatching();
34+
long[] bitArray = matchingBitSet.getBits();
35+
// iterate through each block of 64 documents (since each long contains 64 bits)
36+
for (int idx = 0; idx < bitArray.length; idx++) {
37+
long bits = bitArray[idx];
38+
while (bits != 0L) {
39+
// find position of the rightmost set bit (1)
40+
int numberOfTrailingZeros = Long.numberOfTrailingZeros(bits);
41+
// calculate actual document ID within the window
42+
// idx << 6 is equivalent to idx * 64 (block offset)
43+
// numberOfTrailingZeros gives position within the block
44+
final int docIndexInWindow = (idx << 6) | numberOfTrailingZeros;
45+
float[][] windowScores = hybridBulkScorer.getWindowScores();
46+
for (int subQueryIndex = 0; subQueryIndex < windowScores.length; subQueryIndex++) {
47+
if (Objects.isNull(windowScores[subQueryIndex])) {
48+
continue;
49+
}
50+
float scoreOfDocIdForSubQuery = windowScores[subQueryIndex][docIndexInWindow];
51+
hybridBulkScorer.getHybridSubQueryScorer().getSubQueryScores()[subQueryIndex] = scoreOfDocIdForSubQuery;
52+
}
53+
// process the document with its base offset
54+
consumer.accept(base | docIndexInWindow);
55+
// reset scores after processing of one doc, this is required because scorer object is re-used
56+
hybridBulkScorer.getHybridSubQueryScorer().resetScores();
57+
// reset bit for this doc id to indicate that it has been consumed
58+
bits ^= 1L << numberOfTrailingZeros;
59+
}
60+
}
61+
}
62+
}

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public HybridQueryScorer(final Weight weight, final List<Scorer> subScorers) thr
7575
sumMatchCost += w.matchCost * costWeight;
7676
}
7777
}
78-
if (!hasApproximation) { // no sub scorer supports approximations
78+
if (hasApproximation == false) { // no sub scorer supports approximations
7979
twoPhase = null;
8080
} else {
8181
final float matchCost = sumMatchCost / sumApproxCost;
@@ -284,7 +284,7 @@ public boolean matches() throws IOException {
284284
wrapper.next = verifiedMatches;
285285
verifiedMatches = wrapper;
286286

287-
if (!needsScores) {
287+
if (needsScores == false) {
288288
// we can stop here
289289
return true;
290290
}

src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import lombok.Getter;
1616
import lombok.RequiredArgsConstructor;
1717
import org.apache.lucene.index.LeafReaderContext;
18+
import org.apache.lucene.search.BulkScorer;
1819
import org.apache.lucene.search.Explanation;
1920
import org.apache.lucene.search.IndexSearcher;
2021
import org.apache.lucene.search.Matches;
@@ -37,7 +38,6 @@ public final class HybridQueryWeight extends Weight {
3738
// The Weights for our subqueries, in 1-1 correspondence
3839
@Getter(AccessLevel.PACKAGE)
3940
private final List<Weight> weights;
40-
4141
private final ScoreMode scoreMode;
4242

4343
/**
@@ -95,7 +95,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
9595
if (scorerSuppliers.isEmpty()) {
9696
return null;
9797
}
98-
return new HybridScorerSupplier(scorerSuppliers, this, scoreMode);
98+
return new HybridScorerSupplier(scorerSuppliers, this, scoreMode, context);
9999
}
100100

101101
private Void addScoreSupplier(Weight weight, HybridQueryExecutorCollector<LeafReaderContext, ScorerSupplier> collector) {
@@ -145,7 +145,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
145145
max = Math.max(max, score);
146146
subsOnMatch.add(e);
147147
} else {
148-
if (!match) {
148+
if (match == false) {
149149
subsOnNoMatch.add(e);
150150
}
151151
subsOnMatch.add(e);
@@ -161,10 +161,23 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
161161

162162
@RequiredArgsConstructor
163163
static class HybridScorerSupplier extends ScorerSupplier {
164+
165+
@Override
166+
public BulkScorer bulkScorer() throws IOException {
167+
List<Scorer> scorers = new ArrayList<>();
168+
for (Weight weight : weight.getWeights()) {
169+
Scorer scorer = weight.scorer(context);
170+
scorers.add(scorer);
171+
}
172+
return new HybridBulkScorer(scorers, scoreMode.needsScores(), context.reader().maxDoc());
173+
}
174+
164175
private long cost = -1;
176+
@Getter
165177
private final List<ScorerSupplier> scorerSuppliers;
166-
private final Weight weight;
178+
private final HybridQueryWeight weight;
167179
private final ScoreMode scoreMode;
180+
private final LeafReaderContext context;
168181

169182
@Override
170183
public Scorer get(long leadCost) throws IOException {
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.query;
6+
7+
import lombok.Data;
8+
import org.apache.lucene.search.Scorable;
9+
10+
import java.io.IOException;
11+
import java.util.Arrays;
12+
13+
/**
14+
* Scorer implementation for Hybrid Query. This object is light and expected to be re-used between different doc ids
15+
*/
16+
@Data
17+
public class HybridSubQueryScorer extends Scorable {
18+
// array of scores from all sub-queries for a single doc id
19+
private final float[] subQueryScores;
20+
// array of min competitive scores, score is shard level
21+
private final float[] minScores;
22+
23+
public HybridSubQueryScorer(int numOfSubQueries) {
24+
this.minScores = new float[numOfSubQueries];
25+
this.subQueryScores = new float[numOfSubQueries];
26+
}
27+
28+
@Override
29+
public float score() throws IOException {
30+
// for scenarios when scorer is needed (like in aggregations) for one doc id return sum of sub-query scores
31+
float totalScore = 0.0f;
32+
for (float score : subQueryScores) {
33+
totalScore += score;
34+
}
35+
return totalScore;
36+
}
37+
38+
/**
39+
* Reset sub-query scores to 0.0f so this scorer can be reused for next doc id
40+
*/
41+
public void resetScores() {
42+
Arrays.fill(subQueryScores, 0.0f);
43+
}
44+
45+
public int getNumOfSubQueries() {
46+
return subQueryScores.length;
47+
}
48+
}

0 commit comments

Comments
 (0)