Skip to content

Commit 390d04c

Browse files
committed
After further review round
Signed-off-by: br3no <[email protected]>
1 parent a5977b2 commit 390d04c

15 files changed

+504
-326
lines changed

src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java

Lines changed: 252 additions & 178 deletions
Large diffs are not rendered by default.

src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.opensearch.env.Environment;
1515
import org.opensearch.ingest.IngestDocument;
1616
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
17+
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
1718
import org.opensearch.neuralsearch.util.TokenWeightUtil;
1819

1920
import lombok.extern.log4j.Log4j2;
@@ -48,17 +49,19 @@ public void doExecute(
4849
List<String> inferenceList,
4950
BiConsumer<IngestDocument, Exception> handler
5051
) {
51-
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
52-
setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
53-
handler.accept(ingestDocument, null);
54-
}, e -> { handler.accept(null, e); }));
52+
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
53+
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(),
54+
ActionListener.wrap(resultMaps -> {
55+
setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
56+
handler.accept(ingestDocument, null);
57+
}, e -> { handler.accept(null, e); })
58+
);
5559
}
5660

5761
@Override
5862
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
5963
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
60-
this.modelId,
61-
inferenceList,
64+
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(),
6265
ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException)
6366
);
6467
}

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,23 @@
1818
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
1919

2020
import lombok.extern.log4j.Log4j2;
21+
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
2122

2223
/**
23-
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
24-
* and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results.
24+
* This processor is used for user input data text embedding processing, model_id can be used to
25+
* indicate which model the user uses, and field_map can be used to indicate which fields need text
26+
* embedding and the corresponding keys for the text embedding results.
2527
*/
2628
@Log4j2
2729
public final class TextEmbeddingProcessor extends InferenceProcessor {
2830

2931
public static final String TYPE = "text_embedding";
3032
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
3133

34+
private static final AsymmetricTextEmbeddingParameters PASSAGE_PARAMETERS = AsymmetricTextEmbeddingParameters.builder()
35+
.embeddingContentType(EmbeddingContentType.PASSAGE)
36+
.build();
37+
3238
public TextEmbeddingProcessor(
3339
String tag,
3440
String description,
@@ -50,9 +56,7 @@ public void doExecute(
5056
BiConsumer<IngestDocument, Exception> handler
5157
) {
5258
mlCommonsClientAccessor.inferenceSentences(
53-
this.modelId,
54-
inferenceList,
55-
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(),
59+
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(),
5660
ActionListener.wrap(vectors -> {
5761
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
5862
handler.accept(ingestDocument, null);
@@ -62,6 +66,9 @@ public void doExecute(
6266

6367
@Override
6468
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
65-
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(handler::accept, onException));
69+
mlCommonsClientAccessor.inferenceSentences(
70+
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(),
71+
ActionListener.wrap(handler::accept, onException)
72+
);
6673
}
6774
}

src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.google.common.annotations.VisibleForTesting;
2626

2727
import lombok.extern.log4j.Log4j2;
28+
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
2829
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;
2930

3031
/**
@@ -113,10 +114,13 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest
113114
if (inferenceMap.isEmpty()) {
114115
handler.accept(ingestDocument, null);
115116
} else {
116-
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceMap, ActionListener.wrap(vectors -> {
117-
setVectorFieldsToDocument(ingestDocument, vectors);
118-
handler.accept(ingestDocument, null);
119-
}, e -> { handler.accept(null, e); }));
117+
mlCommonsClientAccessor.inferenceSentencesMap(
118+
new InferenceRequest.Builder(this.modelId).inputObjects(inferenceMap).build(),
119+
ActionListener.wrap(vectors -> {
120+
setVectorFieldsToDocument(ingestDocument, vectors);
121+
handler.accept(ingestDocument, null);
122+
}, e -> { handler.accept(null, e); })
123+
);
120124
}
121125
} catch (Exception e) {
122126
handler.accept(null, e);

src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.opensearch.action.search.SearchResponse;
1313
import org.opensearch.core.action.ActionListener;
1414
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
15+
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
1516
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
1617
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
1718
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
@@ -73,9 +74,9 @@ public void rescoreSearchResponse(
7374
List<?> ctxList = (List<?>) ctxObj;
7475
List<String> contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList());
7576
mlCommonsClientAccessor.inferenceSimilarity(
76-
modelId,
77-
(String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD),
78-
contexts,
77+
new InferenceRequest.Builder(modelId).queryText((String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD))
78+
.inputTexts(contexts)
79+
.build(),
7980
listener
8081
);
8182
}

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,12 @@
5757
import lombok.Setter;
5858
import lombok.experimental.Accessors;
5959
import lombok.extern.log4j.Log4j2;
60+
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
6061

6162
/**
62-
* NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a wrapper around a
63-
* k-NN vector query. It uses a ML language model to produce a dense vector from a query string that is then used as
64-
* the query vector for the k-NN search.
63+
* NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a
64+
* wrapper around a k-NN vector query. It uses an ML language model to produce a dense vector from a
65+
* query string that is then used as the query vector for the k-NN search.
6566
*/
6667

6768
@Log4j2
@@ -86,6 +87,9 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>
8687
static final ParseField K_FIELD = new ParseField("k");
8788

8889
private static final int DEFAULT_K = 10;
90+
private static final AsymmetricTextEmbeddingParameters QUERY_PARAMETERS = AsymmetricTextEmbeddingParameters.builder()
91+
.embeddingContentType(EmbeddingContentType.QUERY)
92+
.build();
8993

9094
private static MLCommonsClientAccessor ML_CLIENT;
9195

@@ -335,10 +339,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
335339
inferenceInput.put(INPUT_IMAGE, queryImage());
336340
}
337341
queryRewriteContext.registerAsyncAction(
338-
((client, actionListener) -> ML_CLIENT.inferenceSentences(
339-
modelId(),
340-
inferenceInput,
341-
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(),
342+
((client, actionListener) -> ML_CLIENT.inferenceSentencesMap(
343+
new InferenceRequest.Builder(modelId()).inputObjects(inferenceInput).mlAlgoParams(QUERY_PARAMETERS).build(),
342344
ActionListener.wrap(floatList -> {
343345
vectorSetOnce.set(vectorAsListToArray(floatList));
344346
actionListener.onResponse(null);
@@ -368,8 +370,12 @@ protected Query doToQuery(QueryShardContext queryShardContext) {
368370

369371
@Override
370372
protected boolean doEquals(NeuralQueryBuilder obj) {
371-
if (this == obj) return true;
372-
if (obj == null || getClass() != obj.getClass()) return false;
373+
if (this == obj) {
374+
return true;
375+
}
376+
if (obj == null || getClass() != obj.getClass()) {
377+
return false;
378+
}
373379
EqualsBuilder equalsBuilder = new EqualsBuilder();
374380
equalsBuilder.append(fieldName, obj.fieldName);
375381
equalsBuilder.append(queryText, obj.queryText);

src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.opensearch.index.query.QueryRewriteContext;
3838
import org.opensearch.index.query.QueryShardContext;
3939
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
40+
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
4041
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
4142
import org.opensearch.neuralsearch.util.TokenWeightUtil;
4243

@@ -341,8 +342,7 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
341342
// it splits the tokens using a threshold defined by a ratio of the maximum score of tokens, updating the token set
342343
// accordingly.
343344
return ((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult(
344-
modelId(),
345-
List.of(queryText),
345+
new InferenceRequest.Builder(modelId()).inputTexts(List.of(queryText)).build(),
346346
ActionListener.wrap(mapResultList -> {
347347
Map<String, Float> queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
348348
if (Objects.nonNull(twoPhaseSharedQueryToken)) {

0 commit comments

Comments
 (0)