Skip to content

Commit e77f881

Browse files
committed
implement batch document update scenario for text embedding processor (opensearch-project#1217)
Signed-off-by: Will Hwang <[email protected]>
1 parent d4b46c8 commit e77f881

File tree

10 files changed

+537
-77
lines changed

10 files changed

+537
-77
lines changed

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java

Lines changed: 194 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import java.util.ArrayList;
88
import java.util.Arrays;
99
import java.util.Collection;
10-
import java.util.Collections;
1110
import java.util.Comparator;
1211
import java.util.HashMap;
1312
import java.util.Iterator;
@@ -26,6 +25,10 @@
2625
import org.apache.commons.lang3.StringUtils;
2726
import org.apache.commons.lang3.tuple.ImmutablePair;
2827
import org.apache.commons.lang3.tuple.Pair;
28+
import org.opensearch.action.get.MultiGetItemResponse;
29+
import org.opensearch.action.get.MultiGetRequest;
30+
import org.opensearch.action.get.MultiGetResponse;
31+
import org.opensearch.common.CheckedConsumer;
2932
import org.opensearch.common.collect.Tuple;
3033
import org.opensearch.core.action.ActionListener;
3134
import org.opensearch.core.common.util.CollectionUtils;
@@ -42,6 +45,7 @@
4245
import com.google.common.collect.ImmutableMap;
4346

4447
import lombok.extern.log4j.Log4j2;
48+
import org.opensearch.neuralsearch.processor.optimization.InferenceFilter;
4549
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;
4650

4751
/**
@@ -54,6 +58,8 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {
5458

5559
public static final String MODEL_ID_FIELD = "model_id";
5660
public static final String FIELD_MAP_FIELD = "field_map";
61+
public static final String INDEX_FIELD = "_index";
62+
public static final String ID_FIELD = "_id";
5763
private static final BiFunction<Object, Object, Object> REMAPPING_FUNCTION = (v1, v2) -> {
5864
if (v1 instanceof Collection && v2 instanceof Collection) {
5965
((Collection) v1).addAll((Collection) v2);
@@ -169,23 +175,67 @@ void preprocessIngestDocument(IngestDocument ingestDocument) {
169175
*/
170176
abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);
171177

178+
/**
179+
* Runs inference for a sub-batch of ingest documents and hands the updated wrappers to the handler.
180+
* @param ingestDocumentWrappers a list of IngestDocuments in a batch.
181+
* @param handler a callback that receives the list of IngestDocumentWrappers once processing is finished.
182+
*/
172183
@Override
173184
public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
174-
if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
175-
handler.accept(Collections.emptyList());
176-
return;
185+
try {
186+
if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
187+
handler.accept(ingestDocumentWrappers);
188+
return;
189+
}
190+
191+
List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
192+
List<String> inferenceList = constructInferenceTexts(dataForInferences);
193+
if (inferenceList.isEmpty()) {
194+
handler.accept(ingestDocumentWrappers);
195+
return;
196+
}
197+
doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
198+
} catch (Exception e) {
199+
updateWithExceptions(ingestDocumentWrappers, handler, e);
177200
}
201+
}
178202

179-
List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
180-
List<String> inferenceList = constructInferenceTexts(dataForInferences);
181-
if (inferenceList.isEmpty()) {
182-
handler.accept(ingestDocumentWrappers);
183-
return;
203+
/**
204+
* This is a helper function for subBatchExecute, which invokes doBatchExecute for given inference list.
205+
* @param ingestDocumentWrappers a list of IngestDocuments in a batch.
206+
* @param inferenceList a list of String for inference.
207+
* @param dataForInferences a list of data for inference, which includes ingestDocumentWrapper, processMap, inferenceList.
208+
*/
209+
protected void doSubBatchExecute(
210+
List<IngestDocumentWrapper> ingestDocumentWrappers,
211+
List<String> inferenceList,
212+
List<DataForInference> dataForInferences,
213+
Consumer<List<IngestDocumentWrapper>> handler
214+
) {
215+
try {
216+
Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
217+
inferenceList = sortedResult.v1();
218+
Map<Integer, Integer> originalOrder = sortedResult.v2();
219+
doBatchExecute(
220+
inferenceList,
221+
results -> batchExecuteHandler(results, ingestDocumentWrappers, dataForInferences, originalOrder, handler),
222+
exception -> {
223+
updateWithExceptions(ingestDocumentWrappers, handler, exception);
224+
}
225+
);
226+
} catch (Exception e) {
227+
updateWithExceptions(ingestDocumentWrappers, handler, e);
184228
}
185-
Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
186-
inferenceList = sortedResult.v1();
187-
Map<Integer, Integer> originalOrder = sortedResult.v2();
188-
doBatchExecute(inferenceList, results -> {
229+
}
230+
231+
private void batchExecuteHandler(
232+
List<?> results,
233+
List<IngestDocumentWrapper> ingestDocumentWrappers,
234+
List<DataForInference> dataForInferences,
235+
Map<Integer, Integer> originalOrder,
236+
Consumer<List<IngestDocumentWrapper>> handler
237+
) {
238+
try {
189239
int startIndex = 0;
190240
results = restoreToOriginalOrder(results, originalOrder);
191241
for (DataForInference dataForInference : dataForInferences) {
@@ -202,16 +252,9 @@ public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers,
202252
);
203253
}
204254
handler.accept(ingestDocumentWrappers);
205-
}, exception -> {
206-
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
207-
// The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only
208-
// set exception to IngestDocumentWrapper which doesn't have exception before.
209-
if (ingestDocumentWrapper.getException() == null) {
210-
ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), exception);
211-
}
212-
}
213-
handler.accept(ingestDocumentWrappers);
214-
});
255+
} catch (Exception e) {
256+
updateWithExceptions(ingestDocumentWrappers, handler, e);
257+
}
215258
}
216259

217260
private Tuple<List<String>, Map<Integer, Integer>> sortByLengthAndReturnOriginalOrder(List<String> inferenceList) {
@@ -238,7 +281,7 @@ private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> or
238281
return sortedResults;
239282
}
240283

241-
private List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
284+
protected List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
242285
List<String> inferenceTexts = new ArrayList<>();
243286
for (DataForInference dataForInference : dataForInferences) {
244287
if (dataForInference.getIngestDocumentWrapper().getException() != null
@@ -250,7 +293,7 @@ private List<String> constructInferenceTexts(List<DataForInference> dataForInfer
250293
return inferenceTexts;
251294
}
252295

253-
private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
296+
protected List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
254297
List<DataForInference> dataForInferences = new ArrayList<>();
255298
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
256299
Map<String, Object> processMap = null;
@@ -272,7 +315,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
272315

273316
@Getter
274317
@AllArgsConstructor
275-
private static class DataForInference {
318+
protected static class DataForInference {
276319
private final IngestDocumentWrapper ingestDocumentWrapper;
277320
private final Map<String, Object> processMap;
278321
private final List<String> inferenceList;
@@ -415,6 +458,36 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
415458
nlpResult.forEach(ingestDocument::setFieldValue);
416459
}
417460

461+
/**
462+
* This method creates a MultiGetRequest from a list of ingest documents to be fetched for comparison
463+
* @param dataForInferences list of data for inference
464+
* */
465+
protected MultiGetRequest buildMultiGetRequest(List<DataForInference> dataForInferences) {
466+
MultiGetRequest multiGetRequest = new MultiGetRequest();
467+
for (DataForInference dataForInference : dataForInferences) {
468+
Object index = dataForInference.getIngestDocumentWrapper().getIngestDocument().getSourceAndMetadata().get(INDEX_FIELD);
469+
Object id = dataForInference.getIngestDocumentWrapper().getIngestDocument().getSourceAndMetadata().get(ID_FIELD);
470+
if (Objects.nonNull(index) && Objects.nonNull(id)) {
471+
multiGetRequest.add(index.toString(), id.toString());
472+
}
473+
}
474+
return multiGetRequest;
475+
}
476+
477+
/**
478+
* This method creates a map of documents from MultiGetItemResponse where the key is document ID and value is corresponding document
479+
* @param multiGetItemResponses array of item responses from the multi-get request
480+
* */
481+
protected Map<String, Map<String, Object>> createDocumentMap(MultiGetItemResponse[] multiGetItemResponses) {
482+
Map<String, Map<String, Object>> existingDocuments = new HashMap<>();
483+
for (MultiGetItemResponse item : multiGetItemResponses) {
484+
String id = item.getId();
485+
Map<String, Object> existingDocument = item.getResponse().getSourceAsMap();
486+
existingDocuments.put(id, existingDocument);
487+
}
488+
return existingDocuments;
489+
}
490+
418491
@SuppressWarnings({ "unchecked" })
419492
@VisibleForTesting
420493
Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
@@ -504,6 +577,27 @@ private void processMapEntryValue(
504577
}
505578
}
506579

580+
// This method updates each ingestDocument with exceptions and accepts ingestDocumentWrappers.
581+
// If an exception occurs while updating, the handler is invoked with null so that ingestion fails.
582+
protected void updateWithExceptions(
583+
List<IngestDocumentWrapper> ingestDocumentWrappers,
584+
Consumer<List<IngestDocumentWrapper>> handler,
585+
Exception e
586+
) {
587+
try {
588+
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
589+
// The IngestDocumentWrapper might have already run into exception. So here we only
590+
// set exception to IngestDocumentWrapper which doesn't have exception before.
591+
if (ingestDocumentWrapper.getException() == null) {
592+
ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
593+
}
594+
}
595+
handler.accept(ingestDocumentWrappers);
596+
} catch (Exception ex) {
597+
handler.accept(null);
598+
}
599+
}
600+
507601
private void processMapEntryValue(
508602
List<?> results,
509603
IndexWrapper indexWrapper,
@@ -582,11 +676,65 @@ private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceV
582676
List<Map<String, Object>> keyToResult = new ArrayList<>();
583677
sourceValue.stream()
584678
.filter(Objects::nonNull) // explicit null check is required since sourceValue can contain null values in cases where
585-
// sourceValue has been filtered
679+
// sourceValue has been filtered
586680
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
587681
return keyToResult;
588682
}
589683

684+
// This method validates and filters given inferenceList and dataForInferences after response is successfully retrieved from multi-get
685+
// operation.
686+
protected void multiGetResponseHandler(
687+
MultiGetResponse response,
688+
List<IngestDocumentWrapper> ingestDocumentWrappers,
689+
List<String> inferenceList,
690+
List<DataForInference> dataForInferences,
691+
Consumer<List<IngestDocumentWrapper>> handler,
692+
InferenceFilter inferenceFilter
693+
) {
694+
MultiGetItemResponse[] multiGetItemResponses = response.getResponses();
695+
if (multiGetItemResponses == null || multiGetItemResponses.length == 0) {
696+
doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
697+
return;
698+
}
699+
// create a map of documents with key: doc_id and value: doc
700+
Map<String, Map<String, Object>> existingDocuments = createDocumentMap(multiGetItemResponses);
701+
List<DataForInference> filteredDataForInference = filterDataForInference(inferenceFilter, dataForInferences, existingDocuments);
702+
List<String> filteredInferenceList = constructInferenceTexts(filteredDataForInference);
703+
if (filteredInferenceList.isEmpty()) {
704+
handler.accept(ingestDocumentWrappers);
705+
return;
706+
}
707+
doSubBatchExecute(ingestDocumentWrappers, filteredInferenceList, filteredDataForInference, handler);
708+
}
709+
710+
// This is a helper method to filter the given list of dataForInferences by comparing its documents with existingDocuments with
711+
// given inferenceFilter
712+
protected List<DataForInference> filterDataForInference(
713+
InferenceFilter inferenceFilter,
714+
List<DataForInference> dataForInferences,
715+
Map<String, Map<String, Object>> existingDocuments
716+
) {
717+
List<DataForInference> filteredDataForInference = new ArrayList<>();
718+
for (DataForInference dataForInference : dataForInferences) {
719+
IngestDocumentWrapper ingestDocumentWrapper = dataForInference.getIngestDocumentWrapper();
720+
Map<String, Object> processMap = dataForInference.getProcessMap();
721+
Map<String, Object> document = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata();
722+
Object id = document.get(ID_FIELD);
723+
// keep the dataForInference unfiltered when the document has no id or no existing document was fetched for it
724+
if (Objects.isNull(id) || existingDocuments.containsKey(id.toString()) == false) {
725+
filteredDataForInference.add(dataForInference);
726+
continue;
727+
}
728+
// filter dataForInference when existing document exists
729+
String docId = id.toString();
730+
Map<String, Object> existingDocument = existingDocuments.get(docId);
731+
Map<String, Object> filteredProcessMap = inferenceFilter.filter(existingDocument, document, processMap);
732+
List<String> filteredInferenceList = createInferenceList(filteredProcessMap);
733+
filteredDataForInference.add(new DataForInference(ingestDocumentWrapper, filteredProcessMap, filteredInferenceList));
734+
}
735+
return filteredDataForInference;
736+
}
737+
590738
/**
591739
* This method invokes inference call through mlCommonsClientAccessor and populates retrieved embeddings to ingestDocument
592740
*
@@ -611,6 +759,25 @@ protected void makeInferenceCall(
611759
);
612760
}
613761

762+
protected <Response> ActionListener<Response> wrap(
763+
final CheckedConsumer<Response, ? extends Exception> onResponse,
764+
final Consumer<Exception> onFailure
765+
) {
766+
return new ActionListener<>() {
767+
public void onResponse(Response response) {
768+
try {
769+
onResponse.accept(response);
770+
} catch (Exception e) {
771+
this.onFailure(e);
772+
}
773+
}
774+
775+
public void onFailure(Exception e) {
776+
onFailure.accept(e);
777+
}
778+
};
779+
}
780+
614781
@Override
615782
public String getType() {
616783
return type;

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313

1414
import org.opensearch.action.get.GetAction;
1515
import org.opensearch.action.get.GetRequest;
16+
import org.opensearch.action.get.MultiGetAction;
1617
import org.opensearch.cluster.service.ClusterService;
1718
import org.opensearch.core.action.ActionListener;
19+
import org.opensearch.core.common.util.CollectionUtils;
1820
import org.opensearch.env.Environment;
1921
import org.opensearch.ingest.IngestDocument;
22+
import org.opensearch.ingest.IngestDocumentWrapper;
2023
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
2124

2225
import lombok.extern.log4j.Log4j2;
@@ -74,7 +77,7 @@ public void doExecute(
7477
// if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings are copied
7578
String index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD).toString();
7679
String id = ingestDocument.getSourceAndMetadata().get(ID_FIELD).toString();
77-
openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> {
80+
openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), wrap(response -> {
7881
final Map<String, Object> existingDocument = response.getSourceAsMap();
7982
if (existingDocument == null || existingDocument.isEmpty()) {
8083
makeInferenceCall(ingestDocument, processMap, inferenceList, handler);
@@ -106,4 +109,47 @@ public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler
106109
ActionListener.wrap(handler::accept, onException)
107110
);
108111
}
112+
113+
@Override
114+
public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
115+
try {
116+
if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
117+
handler.accept(ingestDocumentWrappers);
118+
return;
119+
}
120+
List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
121+
List<String> inferenceList = constructInferenceTexts(dataForInferences);
122+
if (inferenceList.isEmpty()) {
123+
handler.accept(ingestDocumentWrappers);
124+
return;
125+
}
126+
// skipExisting flag is turned off: call doSubBatchExecute without filtering
127+
if (skipExisting == false) {
128+
doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
129+
return;
130+
}
131+
// skipExisting flag is turned on, eligible inference texts in dataForInferences will be compared and filtered after embeddings
132+
// are copied
133+
openSearchClient.execute(
134+
MultiGetAction.INSTANCE,
135+
buildMultiGetRequest(dataForInferences),
136+
wrap(
137+
response -> multiGetResponseHandler(
138+
response,
139+
ingestDocumentWrappers,
140+
inferenceList,
141+
dataForInferences,
142+
handler,
143+
textEmbeddingInferenceFilter
144+
),
145+
e -> {
146+
// When MultiGetAction fails, set the exception on every ingestDocumentWrapper that does not already have one
147+
updateWithExceptions(ingestDocumentWrappers, handler, e);
148+
}
149+
)
150+
);
151+
} catch (Exception e) {
152+
updateWithExceptions(ingestDocumentWrappers, handler, e);
153+
}
154+
}
109155
}

0 commit comments

Comments
 (0)