Commit c644cde

martin-gaievski authored and YeonghyeonKO committed
[Performance Improvement] Add custom bulk scorer for hybrid query (2-3x faster) (opensearch-project#1289)
Signed-off-by: Martin Gaievski <[email protected]>
Signed-off-by: yeonghyeonKo <[email protected]>
1 parent 6c5bef7 commit c644cde

23 files changed: +1402 −455 lines
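The commit title refers to Lucene's `BulkScorer` extension point: instead of the default path that advances and scores one document at a time through generic abstractions, a query can supply its own bulk scorer that processes a whole window of documents in a tight loop. As a rough sketch of the idea only — `WindowBulkScorer` is a hypothetical name, and this is not the implementation the PR adds for hybrid queries — a minimal custom bulk scorer over a single sub-scorer looks like this:

```java
import java.io.IOException;

import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.util.Bits;

// Hypothetical sketch: scores the [min, max) window of one sub-scorer in a
// single tight loop. The PR's version presumably coordinates all of a hybrid
// query's sub-scorers per window, which is where the batching savings come from.
final class WindowBulkScorer extends BulkScorer {
    private final Scorer scorer;
    private final DocIdSetIterator iterator;

    WindowBulkScorer(Scorer scorer) {
        this.scorer = scorer;
        this.iterator = scorer.iterator();
    }

    @Override
    public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
        collector.setScorer(scorer);
        int doc = iterator.docID();
        if (doc < min) {
            doc = iterator.advance(min);
        }
        // Tight loop over the window: no per-document re-dispatch through
        // generic scoring machinery.
        while (doc < max) {
            if (acceptDocs == null || acceptDocs.get(doc)) {
                collector.collect(doc);
            }
            doc = iterator.nextDoc();
        }
        // BulkScorer contract: return the next candidate doc, or NO_MORE_DOCS.
        return doc;
    }

    @Override
    public long cost() {
        return iterator.cost();
    }
}
```

Lucene drives collection through `score(collector, acceptDocs, min, max)` in windows, so amortizing per-window setup across all sub-queries of a hybrid query, rather than paying it per document, is the kind of change consistent with the 2-3x speedup claimed in the title.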

CHANGELOG.md — 1 addition, 0 deletions

```diff
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ### Features

 ### Enhancements
+- [Performance Improvement] Add custom bulk scorer for hybrid query (2-3x faster) ([#1289](https://github.com/opensearch-project/neural-search/pull/1289))

 ### Bug Fixes
```

HybridSearchRelevancyIT.java (new file) — 229 additions, 0 deletions
```java
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.bwc.rolling;

import org.opensearch.common.Randomness;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;

import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;

public class HybridSearchRelevancyIT extends AbstractRollingUpgradeTestCase {
    private static final String PIPELINE_NAME = "neural-pipeline";
    private static final String SEARCH_PIPELINE_NAME = "hybrid-search-pipeline";
    private static final String TEST_FIELD = "passage_text";
    private static final int NUM_DOCS = 100;
    private static final String VECTOR_EMBEDDING_FIELD = "passage_embedding";
    private static final String QUERY_TEXT = "machine learning patterns";
    private String modelId;

    // Arrays of words to generate random meaningful content
    private static final String[] SUBJECTS = {
        "Machine learning",
        "Deep learning",
        "Neural networks",
        "Artificial intelligence",
        "Data science",
        "Natural language processing",
        "Computer vision",
        "Robotics",
        "Big data",
        "Cloud computing",
        "Edge computing",
        "Internet of Things" };

    private static final String[] VERBS = {
        "analyzes",
        "processes",
        "transforms",
        "improves",
        "optimizes",
        "enhances",
        "revolutionizes",
        "accelerates",
        "streamlines",
        "powers",
        "enables",
        "drives" };

    private static final String[] OBJECTS = {
        "data processing",
        "pattern recognition",
        "decision making",
        "business operations",
        "computational tasks",
        "system performance",
        "automation processes",
        "data analysis",
        "resource utilization",
        "technological innovation",
        "software development",
        "cloud infrastructure" };

    private static final String[] MODIFIERS = {
        "efficiently",
        "rapidly",
        "intelligently",
        "automatically",
        "significantly",
        "dramatically",
        "consistently",
        "reliably",
        "effectively",
        "seamlessly" };

    public void testSearchHitsAfterNormalization_whenIndexWithMultipleShards_E2EFlow() throws Exception {
        waitForClusterHealthGreen(NODES_BWC_CLUSTER);
        String indexName = getIndexNameForTest();
        String[] testDocuments = generateTestDocuments(NUM_DOCS);
        switch (getClusterType()) {
            case OLD:
                modelId = uploadTextEmbeddingModel();
                loadModel(modelId);
                createPipelineProcessor(modelId, PIPELINE_NAME);
                createIndexWithConfiguration(
                    indexName,
                    Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())),
                    PIPELINE_NAME
                );
                // ingest test documents
                for (int i = 0; i < testDocuments.length; i++) {
                    addDocument(indexName, String.valueOf(i), TEST_FIELD, testDocuments[i], null, null);
                }
                createSearchPipeline(
                    SEARCH_PIPELINE_NAME,
                    "l2",
                    "arithmetic_mean",
                    Map.of("weights", Arrays.toString(new float[] { 0.5f, 0.5f }))
                );

                // execute hybrid query and store results
                HybridQueryBuilder hybridQueryBuilder = createHybridQuery(modelId, QUERY_TEXT);
                getAndAssertQueryResults(hybridQueryBuilder, modelId, NUM_DOCS);
                break;
            case MIXED:
                modelId = getModelId(getIngestionPipeline(PIPELINE_NAME), TEXT_EMBEDDING_PROCESSOR);
                HybridQueryBuilder mixedClusterQuery = createHybridQuery(modelId, QUERY_TEXT);
                if (isFirstMixedRound()) {
                    getAndAssertQueryResults(mixedClusterQuery, modelId, NUM_DOCS);
                    String[] testDocumentsAfterMixedUpgrade = generateTestDocuments(NUM_DOCS);
                    for (int i = 0; i < testDocumentsAfterMixedUpgrade.length; i++) {
                        addDocument(indexName, String.valueOf(NUM_DOCS + i), TEST_FIELD, testDocumentsAfterMixedUpgrade[i], null, null);
                    }
                } else {
                    getAndAssertQueryResults(mixedClusterQuery, modelId, 2 * NUM_DOCS);
                }
                break;
            case UPGRADED:
                try {
                    modelId = getModelId(getIngestionPipeline(PIPELINE_NAME), TEXT_EMBEDDING_PROCESSOR);
                    loadModel(modelId);
                    String[] testDocumentsAfterFullUpgrade = generateTestDocuments(NUM_DOCS);
                    for (int i = 0; i < testDocumentsAfterFullUpgrade.length; i++) {
                        addDocument(indexName, String.valueOf(2 * NUM_DOCS + i), TEST_FIELD, testDocumentsAfterFullUpgrade[i], null, null);
                    }
                    HybridQueryBuilder upgradedClusterQuery = createHybridQuery(modelId, QUERY_TEXT);
                    getAndAssertQueryResults(upgradedClusterQuery, modelId, 3 * NUM_DOCS);
                } finally {
                    wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
                }
                break;
            default:
                throw new IllegalStateException(String.format(Locale.ROOT, "Unexpected value: %s", getClusterType()));
        }
    }

    private String[] generateTestDocuments(int count) {
        String[] documents = new String[count];
        Random random = Randomness.get();

        for (int i = 0; i < count; i++) {
            String subject = SUBJECTS[random.nextInt(SUBJECTS.length)];
            String verb = VERBS[random.nextInt(VERBS.length)];
            String object = OBJECTS[random.nextInt(OBJECTS.length)];
            String modifier = MODIFIERS[random.nextInt(MODIFIERS.length)];

            // randomly decide whether to add a modifier (70% chance)
            boolean includeModifier = random.nextDouble() < 0.7;

            documents[i] = includeModifier
                ? String.format(Locale.ROOT, "%s %s %s %s", subject, verb, object, modifier)
                : String.format(Locale.ROOT, "%s %s %s", subject, verb, object);
        }
        return documents;
    }

    private HybridQueryBuilder createHybridQuery(String modelId, String queryText) {
        NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
            .fieldName(VECTOR_EMBEDDING_FIELD)
            .modelId(modelId)
            .queryText(queryText)
            .k(10 * NUM_DOCS)
            .build();

        MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", queryText);

        HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
        hybridQueryBuilder.add(matchQueryBuilder);
        hybridQueryBuilder.add(neuralQueryBuilder);

        return hybridQueryBuilder;
    }

    private void getAndAssertQueryResults(HybridQueryBuilder queryBuilder, String modelId, int queryResultSize) throws Exception {
        loadModel(modelId);
        Map<String, Object> searchResponseAsMap = search(
            getIndexNameForTest(),
            queryBuilder,
            null,
            queryResultSize,
            Map.of("search_pipeline", SEARCH_PIPELINE_NAME)
        );
        int hits = getHitCount(searchResponseAsMap);
        assertEquals(queryResultSize, hits);

        List<Double> normalizedScores = getNormalizationScoreList(searchResponseAsMap);
        assertQueryScores(normalizedScores, queryResultSize);
        List<String> normalizedDocIds = getNormalizationDocIdList(searchResponseAsMap);
        assertQueryDocIds(normalizedDocIds, queryResultSize);
    }

    private void assertQueryScores(List<Double> queryScores, int queryResultSize) {
        assertNotNull(queryScores);
        assertEquals(queryResultSize, queryScores.size());

        // check scores are in descending order
        for (int i = 0; i < queryScores.size() - 1; i++) {
            double currentScore = queryScores.get(i);
            double nextScore = queryScores.get(i + 1);
            assertTrue("scores not in descending order", currentScore >= nextScore);
        }
    }

    private void assertQueryDocIds(List<String> queryDocIds, int queryResultSize) {
        assertNotNull(queryDocIds);
        assertEquals(queryResultSize, queryDocIds.size());

        // check document IDs are unique
        Set<String> uniqueDocIds = new HashSet<>();
        for (String docId : queryDocIds) {
            assertTrue("duplicate document ID found", uniqueDocIds.add(docId));
        }
        assertEquals("number of unique document IDs doesn't match expected count", queryResultSize, uniqueDocIds.size());
    }
}
```
