Skip to content

add validation for invalid nested hybrid query #1305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Bug Fixes
- Fix score value as null for single shard when sorting is not done on score field ([#1277](https://github.com/opensearch-project/neural-search/pull/1277))
- Return bad request for stats API calls with invalid stat names instead of ignoring them ([#1291](https://github.com/opensearch-project/neural-search/pull/1291))

- Add validation for invalid nested hybrid query ([#1305](https://github.com/opensearch-project/neural-search/pull/1305))
### Infrastructure

### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public boolean searchWith(
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
} else {
Query hybridQuery = extractHybridQuery(searchContext, query);
validateHybridQuery((HybridQuery) hybridQuery);
QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext);
queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
// we decide on rescore later in collector manager
Expand Down Expand Up @@ -155,6 +156,14 @@ private void validateNestedDisJunctionQuery(final Query query, final int level)
}
}

private void validateHybridQuery(final HybridQuery query) {
for (Query innerQuery : query) {
if (innerQuery instanceof HybridQuery) {
throw new IllegalArgumentException("hybrid query cannot be nested in another hybrid query");
}
}
}

private int getMaxDepthLimit(final SearchContext searchContext) {
Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,103 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() {
releaseResources(directory, w, reader);
}

@SneakyThrows
public void testWrappedHybridQuery_whenHybridNestedInHybrid_thenFail() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
MapperService mapperService = mock(MapperService.class);
when(mapperService.hasNested()).thenReturn(false);

Directory directory = newDirectory();
IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
ft.setOmitNorms(random().nextBoolean());
ft.freeze();
int docId1 = RandomizedTest.randomInt();
w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null,
searchContext
);

ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
shardId,
randomAlphaOfLength(10),
OriginalIndices.NONE
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.size()).thenReturn(4);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(indexShard.getSearchOperationListener()).thenReturn(mock(SearchOperationListener.class));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
when(searchContext.mapperService()).thenReturn(mapperService);
when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext);
IndexMetadata indexMetadata = getIndexMetadata();
Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build();
IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
boolean hasTimeout = randomBoolean();

HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));
queryBuilder.paginationDepth(10);

HybridQueryBuilder nestedQueryBuilder = new HybridQueryBuilder();
nestedQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
nestedQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));
nestedQueryBuilder.paginationDepth(10);
queryBuilder.add(nestedQueryBuilder);

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);

IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> hybridQueryPhaseSearcher.searchWith(
searchContext,
contextIndexSearcher,
query,
collectors,
hasFilterCollector,
hasTimeout
)
);

org.hamcrest.MatcherAssert.assertThat(
exception.getMessage(),
containsString("hybrid query cannot be nested in another hybrid query")
);

releaseResources(directory, w, reader);
}

@SneakyThrows
public void testWrappedHybridQuery_whenHybridNestedInDisjunctionQuery_thenFail() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
Expand Down
Loading