Skip to content

Commit 37aab93

Browse files
will-hwang authored and YeonghyeonKO committed
add validation for invalid nested hybrid query (opensearch-project#1305)
* add validation for nested hybrid query Signed-off-by: will-hwang <[email protected]> Signed-off-by: yeonghyeonKo <[email protected]>
1 parent c9cc3b7 commit 37aab93

File tree

3 files changed

+107
-1
lines changed

3 files changed

+107
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1919
### Bug Fixes
2020
- Fix score value as null for single shard when sorting is not done on score field ([#1277](https://github.com/opensearch-project/neural-search/pull/1277))
2121
- Return bad request for stats API calls with invalid stat names instead of ignoring them ([#1291](https://github.com/opensearch-project/neural-search/pull/1291))
22-
22+
- Add validation for invalid nested hybrid query ([#1305](https://github.com/opensearch-project/neural-search/pull/1305))
2323
### Infrastructure
2424

2525
### Documentation

src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ public boolean searchWith(
6161
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
6262
} else {
6363
Query hybridQuery = extractHybridQuery(searchContext, query);
64+
validateHybridQuery((HybridQuery) hybridQuery);
6465
QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext);
6566
queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
6667
// we decide on rescore later in collector manager
@@ -155,6 +156,14 @@ private void validateNestedDisJunctionQuery(final Query query, final int level)
155156
}
156157
}
157158

159+
private void validateHybridQuery(final HybridQuery query) {
160+
for (Query innerQuery : query) {
161+
if (innerQuery instanceof HybridQuery) {
162+
throw new IllegalArgumentException("hybrid query cannot be nested in another hybrid query");
163+
}
164+
}
165+
}
166+
158167
private int getMaxDepthLimit(final SearchContext searchContext) {
159168
Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
160169
return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();

src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,103 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() {
518518
releaseResources(directory, w, reader);
519519
}
520520

521+
@SneakyThrows
522+
public void testWrappedHybridQuery_whenHybridNestedInHybrid_thenFail() {
523+
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
524+
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
525+
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
526+
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
527+
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
528+
MapperService mapperService = mock(MapperService.class);
529+
when(mapperService.hasNested()).thenReturn(false);
530+
531+
Directory directory = newDirectory();
532+
IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
533+
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
534+
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
535+
ft.setOmitNorms(random().nextBoolean());
536+
ft.freeze();
537+
int docId1 = RandomizedTest.randomInt();
538+
w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
539+
w.commit();
540+
541+
IndexReader reader = DirectoryReader.open(w);
542+
SearchContext searchContext = mock(SearchContext.class);
543+
544+
ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
545+
reader,
546+
IndexSearcher.getDefaultSimilarity(),
547+
IndexSearcher.getDefaultQueryCache(),
548+
IndexSearcher.getDefaultQueryCachingPolicy(),
549+
true,
550+
null,
551+
searchContext
552+
);
553+
554+
ShardId shardId = new ShardId(dummyIndex, 1);
555+
SearchShardTarget shardTarget = new SearchShardTarget(
556+
randomAlphaOfLength(10),
557+
shardId,
558+
randomAlphaOfLength(10),
559+
OriginalIndices.NONE
560+
);
561+
when(searchContext.shardTarget()).thenReturn(shardTarget);
562+
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
563+
when(searchContext.size()).thenReturn(4);
564+
QuerySearchResult querySearchResult = new QuerySearchResult();
565+
when(searchContext.queryResult()).thenReturn(querySearchResult);
566+
when(searchContext.numberOfShards()).thenReturn(1);
567+
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
568+
IndexShard indexShard = mock(IndexShard.class);
569+
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
570+
when(indexShard.getSearchOperationListener()).thenReturn(mock(SearchOperationListener.class));
571+
when(searchContext.indexShard()).thenReturn(indexShard);
572+
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
573+
when(searchContext.mapperService()).thenReturn(mapperService);
574+
when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext);
575+
IndexMetadata indexMetadata = getIndexMetadata();
576+
Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build();
577+
IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
578+
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);
579+
580+
LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
581+
boolean hasFilterCollector = randomBoolean();
582+
boolean hasTimeout = randomBoolean();
583+
584+
HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
585+
queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
586+
queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));
587+
queryBuilder.paginationDepth(10);
588+
589+
HybridQueryBuilder nestedQueryBuilder = new HybridQueryBuilder();
590+
nestedQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
591+
nestedQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));
592+
nestedQueryBuilder.paginationDepth(10);
593+
queryBuilder.add(nestedQueryBuilder);
594+
595+
Query query = queryBuilder.toQuery(mockQueryShardContext);
596+
when(searchContext.query()).thenReturn(query);
597+
598+
IllegalArgumentException exception = expectThrows(
599+
IllegalArgumentException.class,
600+
() -> hybridQueryPhaseSearcher.searchWith(
601+
searchContext,
602+
contextIndexSearcher,
603+
query,
604+
collectors,
605+
hasFilterCollector,
606+
hasTimeout
607+
)
608+
);
609+
610+
org.hamcrest.MatcherAssert.assertThat(
611+
exception.getMessage(),
612+
containsString("hybrid query cannot be nested in another hybrid query")
613+
);
614+
615+
releaseResources(directory, w, reader);
616+
}
617+
521618
@SneakyThrows
522619
public void testWrappedHybridQuery_whenHybridNestedInDisjunctionQuery_thenFail() {
523620
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();

0 commit comments

Comments
 (0)