Skip to content

Commit 61347b0

Browse files
committed
using lombok in InferenceRequest DTO
Signed-off-by: br3no <[email protected]>
1 parent 2ac1f09 commit 61347b0

File tree

9 files changed

+64
-150
lines changed

9 files changed

+64
-150
lines changed

src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java

Lines changed: 17 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
import java.util.function.Consumer;
1515
import java.util.stream.Collectors;
1616

17+
import lombok.Builder;
18+
import lombok.Getter;
19+
import lombok.Singular;
1720
import org.opensearch.common.CheckedConsumer;
1821
import org.opensearch.common.cache.Cache;
1922
import org.opensearch.common.cache.CacheBuilder;
@@ -47,19 +50,24 @@
4750
@Log4j2
4851
public class MLCommonsClientAccessor {
4952

53+
public static final int MAXIMUM_CACHE_ENTRIES = 10_000;
54+
5055
/**
5156
* Inference parameters for calls to the MLCommons client.
5257
*/
58+
@Getter
59+
@Builder
5360
public static class InferenceRequest {
5461

5562
private static final List<String> DEFAULT_TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
5663

57-
private final String modelId;
58-
private final List<String> inputTexts;
59-
private final MLAlgoParams mlAlgoParams;
60-
private final List<String> targetResponseFilters;
61-
private final Map<String, String> inputObjects;
62-
private final String queryText;
64+
private final String modelId; // required
65+
@Singular
66+
private List<String> inputTexts;
67+
private MLAlgoParams mlAlgoParams;
68+
private List<String> targetResponseFilters;
69+
private Map<String, String> inputObjects;
70+
private String queryText;
6371

6472
public InferenceRequest(
6573
@NonNull String modelId,
@@ -76,124 +84,12 @@ public InferenceRequest(
7684
this.inputObjects = inputObjects;
7785
this.queryText = queryText;
7886
}
79-
80-
public String getModelId() {
81-
return modelId;
82-
}
83-
84-
public List<String> getInputTexts() {
85-
return inputTexts;
86-
}
87-
88-
public MLAlgoParams getMlAlgoParams() {
89-
return mlAlgoParams;
90-
}
91-
92-
public List<String> getTargetResponseFilters() {
93-
return targetResponseFilters;
94-
}
95-
96-
public Map<String, String> getInputObjects() {
97-
return inputObjects;
98-
}
99-
100-
public String getQueryText() {
101-
return queryText;
102-
}
103-
104-
/**
105-
* Builder for {@link InferenceRequest}. Supports fluent construction of the request object.
106-
*/
107-
public static class Builder {
108-
109-
private final String modelId;
110-
private List<String> inputTexts;
111-
private MLAlgoParams mlAlgoParams;
112-
private List<String> targetResponseFilters;
113-
private Map<String, String> inputObjects;
114-
private String queryText;
115-
116-
/**
117-
* @param modelId the model id to use for inference
118-
*/
119-
public Builder(String modelId) {
120-
this.modelId = modelId;
121-
}
122-
123-
/**
124-
* @param inputTexts a {@link List} of input texts to use for inference
125-
* @return this builder
126-
*/
127-
public Builder inputTexts(List<String> inputTexts) {
128-
this.inputTexts = inputTexts;
129-
return this;
130-
}
131-
132-
/**
133-
* @param inputText an input text to add to the list of input texts. Repeated calls will add
134-
* more input texts.
135-
* @return this builder
136-
*/
137-
public Builder inputText(String inputText) {
138-
if (this.inputTexts != null) {
139-
this.inputTexts.add(inputText);
140-
} else {
141-
this.inputTexts = new ArrayList<>();
142-
this.inputTexts.add(inputText);
143-
}
144-
return this;
145-
}
146-
147-
/**
148-
* @param mlAlgoParams the {@link MLAlgoParams} to use for inference.
149-
* @return this builder
150-
*/
151-
public Builder mlAlgoParams(MLAlgoParams mlAlgoParams) {
152-
this.mlAlgoParams = mlAlgoParams;
153-
return this;
154-
}
155-
156-
/**
157-
* @param targetResponseFilters a {@link List} of target response filters to use for
158-
* inference
159-
* @return this builder
160-
*/
161-
public Builder targetResponseFilters(List<String> targetResponseFilters) {
162-
this.targetResponseFilters = targetResponseFilters;
163-
return this;
164-
}
165-
166-
/**
167-
* @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs
168-
* to happen
169-
* @return this builder
170-
*/
171-
public Builder inputObjects(Map<String, String> inputObjects) {
172-
this.inputObjects = inputObjects;
173-
return this;
174-
}
175-
176-
/**
177-
* @param queryText the query text to use for similarity inference
178-
* @return this builder
179-
*/
180-
public Builder queryText(String queryText) {
181-
this.queryText = queryText;
182-
return this;
183-
}
184-
185-
/**
186-
* @return a new {@link InferenceRequest} object with the parameters set in this builder
187-
*/
188-
public InferenceRequest build() {
189-
return new InferenceRequest(modelId, inputTexts, mlAlgoParams, targetResponseFilters, inputObjects, queryText);
190-
}
191-
192-
}
19387
}
19488

19589
private final MachineLearningNodeClient mlClient;
196-
private final Cache<String, Boolean> modelAsymmetryCache = CacheBuilder.<String, Boolean>builder().setMaximumWeight(10_000).build();
90+
private final Cache<String, Boolean> modelAsymmetryCache = CacheBuilder.<String, Boolean>builder()
91+
.setMaximumWeight(MAXIMUM_CACHE_ENTRIES)
92+
.build();
19793

19894
/**
19995
* Wrapper around {@link #inferenceSentencesMap} that expects a single input text and produces a

src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public void doExecute(
5050
BiConsumer<IngestDocument, Exception> handler
5151
) {
5252
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
53-
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(),
53+
InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
5454
ActionListener.wrap(resultMaps -> {
5555
setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
5656
handler.accept(ingestDocument, null);
@@ -61,7 +61,7 @@ public void doExecute(
6161
@Override
6262
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
6363
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
64-
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(),
64+
InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
6565
ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException)
6666
);
6767
}

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public void doExecute(
5656
BiConsumer<IngestDocument, Exception> handler
5757
) {
5858
mlCommonsClientAccessor.inferenceSentences(
59-
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(),
59+
InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(),
6060
ActionListener.wrap(vectors -> {
6161
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
6262
handler.accept(ingestDocument, null);
@@ -67,7 +67,7 @@ public void doExecute(
6767
@Override
6868
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
6969
mlCommonsClientAccessor.inferenceSentences(
70-
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(),
70+
InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(),
7171
ActionListener.wrap(handler::accept, onException)
7272
);
7373
}

src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest
115115
handler.accept(ingestDocument, null);
116116
} else {
117117
mlCommonsClientAccessor.inferenceSentencesMap(
118-
new InferenceRequest.Builder(this.modelId).inputObjects(inferenceMap).build(),
118+
InferenceRequest.builder().modelId(this.modelId).inputObjects(inferenceMap).build(),
119119
ActionListener.wrap(vectors -> {
120120
setVectorFieldsToDocument(ingestDocument, vectors);
121121
handler.accept(ingestDocument, null);

src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ public void rescoreSearchResponse(
7474
List<?> ctxList = (List<?>) ctxObj;
7575
List<String> contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList());
7676
mlCommonsClientAccessor.inferenceSimilarity(
77-
new InferenceRequest.Builder(modelId).queryText((String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD))
77+
InferenceRequest.builder()
78+
.modelId(modelId)
79+
.queryText((String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD))
7880
.inputTexts(contexts)
7981
.build(),
8082
listener

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
340340
}
341341
queryRewriteContext.registerAsyncAction(
342342
((client, actionListener) -> ML_CLIENT.inferenceSentencesMap(
343-
new InferenceRequest.Builder(modelId()).inputObjects(inferenceInput).mlAlgoParams(QUERY_PARAMETERS).build(),
343+
InferenceRequest.builder().modelId(modelId()).inputObjects(inferenceInput).mlAlgoParams(QUERY_PARAMETERS).build(),
344344
ActionListener.wrap(floatList -> {
345345
vectorSetOnce.set(vectorAsListToArray(floatList));
346346
actionListener.onResponse(null);

src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
342342
// it splits the tokens using a threshold defined by a ratio of the maximum score of tokens, updating the token set
343343
// accordingly.
344344
return ((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult(
345-
new InferenceRequest.Builder(modelId()).inputTexts(List.of(queryText)).build(),
345+
InferenceRequest.builder().modelId(modelId()).inputTexts(List.of(queryText)).build(),
346346
ActionListener.wrap(mapResultList -> {
347347
Map<String, Float> queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
348348
if (Objects.nonNull(twoPhaseSharedQueryToken)) {

0 commit comments

Comments
 (0)