
Commit 65a18f3

Addressed review comments, mainly refactoring
Signed-off-by: Martin Gaievski <[email protected]>
1 parent 90aad85 commit 65a18f3

3 files changed, +24 -21 lines

src/main/java/org/opensearch/neuralsearch/query/HybridBulkScorer.java

Lines changed: 13 additions & 12 deletions

@@ -36,6 +36,7 @@ public class HybridBulkScorer extends BulkScorer {
     private final float[][] windowScores;
     private final HybridQueryDocIdStream hybridQueryDocIdStream;
     private final int maxDoc;
+    private int[] docIds;

     /**
      * Constructor for HybridBulkScorer
@@ -45,8 +46,9 @@ public class HybridBulkScorer extends BulkScorer {
      */
     public HybridBulkScorer(List<Scorer> scorers, boolean needsScores, int maxDoc) {
         long cost = 0;
-        this.scorers = new Scorer[scorers.size()];
-        for (int subQueryIndex = 0; subQueryIndex < scorers.size(); subQueryIndex++) {
+        int numOfQueries = scorers.size();
+        this.scorers = new Scorer[numOfQueries];
+        for (int subQueryIndex = 0; subQueryIndex < numOfQueries; subQueryIndex++) {
             Scorer scorer = scorers.get(subQueryIndex);
             if (Objects.isNull(scorer)) {
                 continue;
@@ -55,12 +57,14 @@ public HybridBulkScorer(List<Scorer> scorers, boolean needsScores, int maxDoc) {
             this.scorers[subQueryIndex] = scorer;
         }
         this.cost = cost;
-        this.hybridSubQueryScorer = new HybridSubQueryScorer(scorers.size());
+        this.hybridSubQueryScorer = new HybridSubQueryScorer(numOfQueries);
         this.needsScores = needsScores;
         this.matching = new FixedBitSet(WINDOW_SIZE);
         this.windowScores = new float[this.scorers.length][WINDOW_SIZE];
         this.maxDoc = maxDoc;
         this.hybridQueryDocIdStream = new HybridQueryDocIdStream(this);
+        this.docIds = new int[numOfQueries];
+        Arrays.fill(docIds, DocIdSetIterator.NO_MORE_DOCS);
     }

     @Override
@@ -69,15 +73,15 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
         // making sure we are not going over the global limit defined by maxDoc
         max = Math.min(max, maxDoc);
         // advance all scorers to the segment's minimum doc id
-        int[] docsIds = advance(min, scorers);
-        while (allDocIdsUsed(docsIds, max) == false) {
-            scoreWindow(collector, acceptDocs, min, max, docsIds);
+        advance(min, scorers);
+        while (allDocIdsUsed(docIds, max) == false) {
+            scoreWindow(collector, acceptDocs, min, max, docIds);
         }
-        return getNextDocIdCandidate(docsIds);
+        return getNextDocIdCandidate(docIds);
     }

     private void scoreWindow(LeafCollector collector, Bits acceptDocs, int min, int max, int[] docIds) throws IOException {
-        // pick the lowest out of all not yet used doc ids
+        // find the first document ID below the maximum threshold to establish the next scoring window boundary
         int topDoc = -1;
         for (int docId : docIds) {
             if (docId < max) {
@@ -150,11 +154,9 @@ private void scoreWindowIntoBitSetWithSubqueryScorers(
     /**
      * Advance all scorers to the next document that is >= min
      */
-    private int[] advance(int min, Scorer[] scorers) throws IOException {
-        int[] docIds = new int[scorers.length];
+    private void advance(int min, Scorer[] scorers) throws IOException {
         for (int subQueryIndex = 0; subQueryIndex < scorers.length; subQueryIndex++) {
             if (Objects.isNull(scorers[subQueryIndex])) {
-                docIds[subQueryIndex] = DocIdSetIterator.NO_MORE_DOCS;
                 continue;
             }
             DocIdSetIterator it = scorers[subQueryIndex].iterator();
@@ -164,7 +166,6 @@ private int[] advance(int min, Scorer[] scorers) throws IOException {
             }
             docIds[subQueryIndex] = doc;
         }
-        return docIds;
     }

     private boolean allDocIdsUsed(int[] docsIds, int max) {
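
For context on the change above, here is a minimal, self-contained sketch (plain Java, no Lucene types) of the loop structure the refactored score method follows: each sub-query's current doc id lives in a shared docIds array initialized to a NO_MORE_DOCS sentinel, advance moves every sub-query to its first doc >= min, and scoring repeats until every entry is at or past max. The postings arrays, method names, and printout are illustrative stand-ins for the real scorers, window bitset, and collector, not part of the actual class.

import java.util.Arrays;

public class WindowScoringSketch {
    // sentinel for exhausted sub-queries, mirrors DocIdSetIterator.NO_MORE_DOCS
    private static final int NO_MORE_DOCS = Integer.MAX_VALUE;

    private final int[] docIds;     // current doc id per sub-query, shared across calls
    private final int[][] postings; // toy "iterators": sorted doc ids per sub-query
    private final int[] positions;  // cursor into each postings list

    WindowScoringSketch(int[][] postings) {
        this.postings = postings;
        this.positions = new int[postings.length];
        this.docIds = new int[postings.length];
        Arrays.fill(docIds, NO_MORE_DOCS);
    }

    // analogue of advance(min, scorers): move every sub-query to its first doc >= min
    void advance(int min) {
        for (int i = 0; i < postings.length; i++) {
            while (positions[i] < postings[i].length && postings[i][positions[i]] < min) {
                positions[i]++;
            }
            docIds[i] = positions[i] < postings[i].length ? postings[i][positions[i]] : NO_MORE_DOCS;
        }
    }

    // analogue of allDocIdsUsed: true once every sub-query is at or past max
    private boolean allDocIdsUsed(int max) {
        for (int docId : docIds) {
            if (docId < max) {
                return false;
            }
        }
        return true;
    }

    // analogue of score(collector, acceptDocs, min, max), printing instead of collecting
    void score(int min, int max) {
        advance(min);
        while (allDocIdsUsed(max) == false) {
            for (int i = 0; i < docIds.length; i++) {
                if (docIds[i] < max) {
                    System.out.println("sub-query " + i + " matches doc " + docIds[i]);
                    positions[i]++;
                    docIds[i] = positions[i] < postings[i].length ? postings[i][positions[i]] : NO_MORE_DOCS;
                }
            }
        }
    }

    public static void main(String[] args) {
        int[][] postings = { { 1, 4, 9 }, { 4, 7 } };
        new WindowScoringSketch(postings).score(0, 10);
    }
}

The sketch deliberately skips the FixedBitSet windowing (WINDOW_SIZE, matching, windowScores); it only shows why the docIds array can live on the instance instead of being allocated and returned on every score call.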

src/main/java/org/opensearch/neuralsearch/query/HybridQueryDocIdStream.java

Lines changed: 2 additions & 1 deletion

@@ -18,6 +18,7 @@
  */
 @RequiredArgsConstructor
 public class HybridQueryDocIdStream extends DocIdStream {
+    private static final int BLOCK_SHIFT = 6;
     private final HybridBulkScorer hybridBulkScorer;
     @Setter
     private int base;
@@ -41,7 +42,7 @@ public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException
                 // calculate actual document ID within the window
                 // idx << 6 is equivalent to idx * 64 (block offset)
                 // numberOfTrailingZeros gives position within the block
-                final int docIndexInWindow = (idx << 6) | numberOfTrailingZeros;
+                final int docIndexInWindow = (idx << BLOCK_SHIFT) | numberOfTrailingZeros;
                 float[][] windowScores = hybridBulkScorer.getWindowScores();
                 for (int subQueryIndex = 0; subQueryIndex < windowScores.length; subQueryIndex++) {
                     if (Objects.isNull(windowScores[subQueryIndex])) {
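
The BLOCK_SHIFT change above only names a constant; the arithmetic it names is the standard word/bit decomposition of a FixedBitSet, whose backing long words each hold 64 bits. A small self-contained demonstration of how (idx << 6) | Long.numberOfTrailingZeros(word) recovers the absolute bit index; the doc ids and printout are illustrative only.

public class BlockShiftDemo {
    private static final int BLOCK_SHIFT = 6; // 2^6 = 64 bits per long word

    public static void main(String[] args) {
        long[] words = new long[2];           // two 64-bit blocks, like the long[] backing a FixedBitSet
        int[] docsInWindow = { 3, 64, 70 };
        for (int doc : docsInWindow) {
            words[doc >> BLOCK_SHIFT] |= 1L << (doc & 63); // set the bit for each doc
        }
        for (int idx = 0; idx < words.length; idx++) {
            long word = words[idx];
            while (word != 0) {
                // block offset (idx * 64) plus position of the lowest set bit within the block
                int docIndexInWindow = (idx << BLOCK_SHIFT) | Long.numberOfTrailingZeros(word);
                System.out.println("doc index in window: " + docIndexInWindow); // prints 3, 64, 70
                word &= word - 1; // clear the lowest set bit
            }
        }
    }
}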

src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java

Lines changed: 9 additions & 8 deletions

@@ -23,6 +23,7 @@
 import org.apache.lucene.util.PriorityQueue;

 import lombok.extern.log4j.Log4j2;
+import org.opensearch.neuralsearch.query.HybridSubQueryScorer;
 import org.opensearch.neuralsearch.search.HitsThresholdChecker;

 /**
@@ -132,22 +133,18 @@ public void setScorer(Scorable scorer) throws IOException {

     @Override
     public void collect(int doc) throws IOException {
-        if (Objects.isNull(getCompoundQueryScorer())) {
+        HybridSubQueryScorer compoundQueryScorer = getCompoundQueryScorer();
+        if (Objects.isNull(compoundQueryScorer)) {
             return;
         }
         ensureSubQueryScoreQueues();
         // Increment total hit count which represents unique doc found on the shard
         totalHits++;
-        float[] scores = getCompoundQueryScorer().getSubQueryScores();
+        float[] scores = compoundQueryScorer.getSubQueryScores();
         int docWithBase = doc + docBase;
         for (int subQueryIndex = 0; subQueryIndex < scores.length; subQueryIndex++) {
             float score = scores[subQueryIndex];
-            // if score is 0.0 there is no hits for that sub-query
-            if (score <= 0) {
-                continue;
-            }
-
-            if (score < minScoreThresholds[subQueryIndex]) {
+            if (isNonCompetitiveScore(score, subQueryIndex)) {
                 continue;
             }

@@ -172,6 +169,10 @@ public void collect(int doc) throws IOException {
         }
     }

+    private boolean isNonCompetitiveScore(float score, int subQueryIndex) {
+        return score <= 0 && score < minScoreThresholds[subQueryIndex];
+    }
+
     /**
      * Initialize compound score queues for each sub query if it's not initialized already
      */
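
A minimal, self-contained sketch of the refactored collect flow above: the compound scorer's scores are read once into a local, and each sub-query score passes through the extracted isNonCompetitiveScore check before any further work. The threshold values, doc id, and printout are placeholders; the predicate body is copied verbatim from the diff.

import java.util.Arrays;

public class CollectLoopSketch {
    private final float[] minScoreThresholds; // one minimum competitive score per sub-query

    CollectLoopSketch(float[] minScoreThresholds) {
        this.minScoreThresholds = minScoreThresholds;
    }

    // same predicate as the helper introduced above
    private boolean isNonCompetitiveScore(float score, int subQueryIndex) {
        return score <= 0 && score < minScoreThresholds[subQueryIndex];
    }

    void collect(int doc, float[] subQueryScores) {
        for (int subQueryIndex = 0; subQueryIndex < subQueryScores.length; subQueryIndex++) {
            float score = subQueryScores[subQueryIndex];
            if (isNonCompetitiveScore(score, subQueryIndex)) {
                continue; // skip scores that are non-positive and below the sub-query threshold
            }
            System.out.println("doc " + doc + ", sub-query " + subQueryIndex + ", score " + score);
        }
    }

    public static void main(String[] args) {
        float[] thresholds = new float[2];
        Arrays.fill(thresholds, 0.1f);
        new CollectLoopSketch(thresholds).collect(42, new float[] { 0.0f, 0.7f }); // only the second score is printed
    }
}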
