Skip to content

Commit e638504

Browse files
will-hwang authored and ryanbogan committed
Implement Optimized embedding generation in text embedding processor (opensearch-project#1238)
* implement single document update scenario for text embedding processor (opensearch-project#1191) Signed-off-by: Will Hwang <[email protected]> * implement batch document update scenario for text embedding processor (opensearch-project#1217) Signed-off-by: Will Hwang <[email protected]> --------- Signed-off-by: Will Hwang <[email protected]>
1 parent dff2a71 commit e638504

25 files changed

+2899
-242
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1010
- Support filter function for HybridQueryBuilder and NeuralQueryBuilder ([#1206](https://github.com/opensearch-project/neural-search/pull/1206))
1111
- Add Z Score normalization technique ([#1224](https://github.com/opensearch-project/neural-search/pull/1224))
1212
- Support semantic sentence highlighter ([#1193](https://github.com/opensearch-project/neural-search/pull/1193))
13+
- Optimize embedding generation in Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
1314

1415
### Enhancements
1516

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -140,7 +140,12 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
140140
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
141141
return Map.of(
142142
TextEmbeddingProcessor.TYPE,
143-
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
143+
new TextEmbeddingProcessorFactory(
144+
parameters.client,
145+
clientAccessor,
146+
parameters.env,
147+
parameters.ingestService.getClusterService()
148+
),
144149
SparseEncodingProcessor.TYPE,
145150
new SparseEncodingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
146151
TextImageEmbeddingProcessor.TYPE,

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java

Lines changed: 241 additions & 40 deletions
Large diffs are not rendered by default.

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java

Lines changed: 81 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9,13 +9,20 @@
99
import java.util.function.BiConsumer;
1010
import java.util.function.Consumer;
1111

12+
import org.opensearch.action.get.GetAction;
13+
import org.opensearch.action.get.GetRequest;
14+
import org.opensearch.action.get.MultiGetAction;
1215
import org.opensearch.cluster.service.ClusterService;
1316
import org.opensearch.core.action.ActionListener;
17+
import org.opensearch.core.common.util.CollectionUtils;
1418
import org.opensearch.env.Environment;
1519
import org.opensearch.ingest.IngestDocument;
20+
import org.opensearch.ingest.IngestDocumentWrapper;
1621
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
1722

1823
import lombok.extern.log4j.Log4j2;
24+
import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter;
25+
import org.opensearch.transport.client.OpenSearchClient;
1926

2027
/**
2128
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
@@ -26,33 +33,57 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {
2633

2734
public static final String TYPE = "text_embedding";
2835
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
36+
public static final String SKIP_EXISTING = "skip_existing";
37+
public static final boolean DEFAULT_SKIP_EXISTING = false;
38+
private static final String INDEX_FIELD = "_index";
39+
private static final String ID_FIELD = "_id";
40+
private final OpenSearchClient openSearchClient;
41+
private final boolean skipExisting;
42+
private final TextEmbeddingInferenceFilter textEmbeddingInferenceFilter;
2943

3044
public TextEmbeddingProcessor(
3145
String tag,
3246
String description,
3347
int batchSize,
3448
String modelId,
3549
Map<String, Object> fieldMap,
50+
boolean skipExisting,
51+
TextEmbeddingInferenceFilter textEmbeddingInferenceFilter,
52+
OpenSearchClient openSearchClient,
3653
MLCommonsClientAccessor clientAccessor,
3754
Environment environment,
3855
ClusterService clusterService
3956
) {
4057
super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
58+
this.skipExisting = skipExisting;
59+
this.textEmbeddingInferenceFilter = textEmbeddingInferenceFilter;
60+
this.openSearchClient = openSearchClient;
4161
}
4262

4363
@Override
4464
public void doExecute(
4565
IngestDocument ingestDocument,
46-
Map<String, Object> ProcessMap,
66+
Map<String, Object> processMap,
4767
List<String> inferenceList,
4868
BiConsumer<IngestDocument, Exception> handler
4969
) {
50-
mlCommonsClientAccessor.inferenceSentences(
51-
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
52-
ActionListener.wrap(vectors -> {
53-
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
54-
handler.accept(ingestDocument, null);
55-
}, e -> { handler.accept(null, e); })
70+
// skip existing flag is turned off. Call model inference without filtering
71+
if (skipExisting == false) {
72+
makeInferenceCall(ingestDocument, processMap, inferenceList, handler);
73+
return;
74+
}
75+
// if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings are copied
76+
String index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD).toString();
77+
String id = ingestDocument.getSourceAndMetadata().get(ID_FIELD).toString();
78+
openSearchClient.execute(
79+
GetAction.INSTANCE,
80+
new GetRequest(index, id),
81+
ActionListener.wrap(
82+
response -> getResponseHandler(response, ingestDocument, processMap, inferenceList, handler, textEmbeddingInferenceFilter),
83+
e -> {
84+
handler.accept(null, e);
85+
}
86+
)
5687
);
5788
}
5889

@@ -63,4 +94,47 @@ public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler
6394
ActionListener.wrap(handler::accept, onException)
6495
);
6596
}
97+
98+
@Override
99+
public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
100+
try {
101+
if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
102+
handler.accept(ingestDocumentWrappers);
103+
return;
104+
}
105+
List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
106+
List<String> inferenceList = constructInferenceTexts(dataForInferences);
107+
if (inferenceList.isEmpty()) {
108+
handler.accept(ingestDocumentWrappers);
109+
return;
110+
}
111+
// skip existing flag is turned off. Call doSubBatchExecute without filtering
112+
if (skipExisting == false) {
113+
doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
114+
return;
115+
}
116+
// skipExisting flag is turned on, eligible inference texts in dataForInferences will be compared and filtered after embeddings
117+
// are copied
118+
openSearchClient.execute(
119+
MultiGetAction.INSTANCE,
120+
buildMultiGetRequest(dataForInferences),
121+
ActionListener.wrap(
122+
response -> multiGetResponseHandler(
123+
response,
124+
ingestDocumentWrappers,
125+
inferenceList,
126+
dataForInferences,
127+
handler,
128+
textEmbeddingInferenceFilter
129+
),
130+
e -> {
131+
// When exception is thrown in for MultiGetAction, set exception to all ingestDocumentWrappers
132+
updateWithExceptions(getIngestDocumentWrappers(dataForInferences), handler, e);
133+
}
134+
)
135+
);
136+
} catch (Exception e) {
137+
updateWithExceptions(ingestDocumentWrappers, handler, e);
138+
}
139+
}
66140
}

src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,11 @@
44
*/
55
package org.opensearch.neuralsearch.processor.factory;
66

7+
import static org.opensearch.ingest.ConfigurationUtils.readBooleanProperty;
78
import static org.opensearch.ingest.ConfigurationUtils.readMap;
89
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;
10+
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.SKIP_EXISTING;
11+
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.DEFAULT_SKIP_EXISTING;
912
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE;
1013
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD;
1114
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD;
@@ -17,24 +20,30 @@
1720
import org.opensearch.ingest.AbstractBatchingProcessor;
1821
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
1922
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
23+
import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter;
24+
import org.opensearch.transport.client.OpenSearchClient;
2025

2126
/**
2227
* Factory for text embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
2328
*/
2429
public final class TextEmbeddingProcessorFactory extends AbstractBatchingProcessor.Factory {
2530

31+
private final OpenSearchClient openSearchClient;
32+
2633
private final MLCommonsClientAccessor clientAccessor;
2734

2835
private final Environment environment;
2936

3037
private final ClusterService clusterService;
3138

3239
public TextEmbeddingProcessorFactory(
40+
final OpenSearchClient openSearchClient,
3341
final MLCommonsClientAccessor clientAccessor,
3442
final Environment environment,
3543
final ClusterService clusterService
3644
) {
3745
super(TYPE);
46+
this.openSearchClient = openSearchClient;
3847
this.clientAccessor = clientAccessor;
3948
this.environment = environment;
4049
this.clusterService = clusterService;
@@ -43,7 +52,20 @@ public TextEmbeddingProcessorFactory(
4352
@Override
4453
protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
4554
String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
46-
Map<String, Object> filedMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
47-
return new TextEmbeddingProcessor(tag, description, batchSize, modelId, filedMap, clientAccessor, environment, clusterService);
55+
Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
56+
boolean skipExisting = readBooleanProperty(TYPE, tag, config, SKIP_EXISTING, DEFAULT_SKIP_EXISTING);
57+
return new TextEmbeddingProcessor(
58+
tag,
59+
description,
60+
batchSize,
61+
modelId,
62+
fieldMap,
63+
skipExisting,
64+
skipExisting ? new TextEmbeddingInferenceFilter(fieldMap) : null,
65+
openSearchClient,
66+
clientAccessor,
67+
environment,
68+
clusterService
69+
);
4870
}
4971
}

0 commit comments

Comments
 (0)