7
7
import java .util .ArrayList ;
8
8
import java .util .Arrays ;
9
9
import java .util .Collection ;
10
- import java .util .Collections ;
11
10
import java .util .Comparator ;
12
11
import java .util .HashMap ;
13
12
import java .util .Iterator ;
26
25
import org .apache .commons .lang3 .StringUtils ;
27
26
import org .apache .commons .lang3 .tuple .ImmutablePair ;
28
27
import org .apache .commons .lang3 .tuple .Pair ;
28
+ import org .opensearch .action .get .MultiGetItemResponse ;
29
+ import org .opensearch .action .get .MultiGetRequest ;
30
+ import org .opensearch .action .get .MultiGetResponse ;
31
+ import org .opensearch .common .CheckedConsumer ;
29
32
import org .opensearch .common .collect .Tuple ;
30
33
import org .opensearch .core .action .ActionListener ;
31
34
import org .opensearch .core .common .util .CollectionUtils ;
42
45
import com .google .common .collect .ImmutableMap ;
43
46
44
47
import lombok .extern .log4j .Log4j2 ;
48
+ import org .opensearch .neuralsearch .processor .optimization .InferenceFilter ;
45
49
import org .opensearch .neuralsearch .util .ProcessorDocumentUtils ;
46
50
47
51
/**
@@ -54,6 +58,8 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {
54
58
55
59
public static final String MODEL_ID_FIELD = "model_id" ;
56
60
public static final String FIELD_MAP_FIELD = "field_map" ;
61
+ public static final String INDEX_FIELD = "_index" ;
62
+ public static final String ID_FIELD = "_id" ;
57
63
private static final BiFunction <Object , Object , Object > REMAPPING_FUNCTION = (v1 , v2 ) -> {
58
64
if (v1 instanceof Collection && v2 instanceof Collection ) {
59
65
((Collection ) v1 ).addAll ((Collection ) v2 );
@@ -169,23 +175,67 @@ void preprocessIngestDocument(IngestDocument ingestDocument) {
169
175
*/
170
176
abstract void doBatchExecute (List <String > inferenceList , Consumer <List <?>> handler , Consumer <Exception > onException );
171
177
178
+ /**
179
+ * This is the function which does actual inference work for subBatchExecute interface.
180
+ * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
181
+ * @param handler a callback handler to handle inference results which is a list of objects.
182
+ */
172
183
@ Override
173
184
public void subBatchExecute (List <IngestDocumentWrapper > ingestDocumentWrappers , Consumer <List <IngestDocumentWrapper >> handler ) {
174
- if (CollectionUtils .isEmpty (ingestDocumentWrappers )) {
175
- handler .accept (Collections .emptyList ());
176
- return ;
185
+ try {
186
+ if (CollectionUtils .isEmpty (ingestDocumentWrappers )) {
187
+ handler .accept (ingestDocumentWrappers );
188
+ return ;
189
+ }
190
+
191
+ List <DataForInference > dataForInferences = getDataForInference (ingestDocumentWrappers );
192
+ List <String > inferenceList = constructInferenceTexts (dataForInferences );
193
+ if (inferenceList .isEmpty ()) {
194
+ handler .accept (ingestDocumentWrappers );
195
+ return ;
196
+ }
197
+ doSubBatchExecute (ingestDocumentWrappers , inferenceList , dataForInferences , handler );
198
+ } catch (Exception e ) {
199
+ updateWithExceptions (ingestDocumentWrappers , handler , e );
177
200
}
201
+ }
178
202
179
- List <DataForInference > dataForInferences = getDataForInference (ingestDocumentWrappers );
180
- List <String > inferenceList = constructInferenceTexts (dataForInferences );
181
- if (inferenceList .isEmpty ()) {
182
- handler .accept (ingestDocumentWrappers );
183
- return ;
203
+ /**
204
+ * This is a helper function for subBatchExecute, which invokes doBatchExecute for given inference list.
205
+ * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
206
+ * @param inferenceList a list of String for inference.
207
+ * @param dataForInferences a list of data for inference, which includes ingestDocumentWrapper, processMap, inferenceList.
208
+ */
209
+ protected void doSubBatchExecute (
210
+ List <IngestDocumentWrapper > ingestDocumentWrappers ,
211
+ List <String > inferenceList ,
212
+ List <DataForInference > dataForInferences ,
213
+ Consumer <List <IngestDocumentWrapper >> handler
214
+ ) {
215
+ try {
216
+ Tuple <List <String >, Map <Integer , Integer >> sortedResult = sortByLengthAndReturnOriginalOrder (inferenceList );
217
+ inferenceList = sortedResult .v1 ();
218
+ Map <Integer , Integer > originalOrder = sortedResult .v2 ();
219
+ doBatchExecute (
220
+ inferenceList ,
221
+ results -> batchExecuteHandler (results , ingestDocumentWrappers , dataForInferences , originalOrder , handler ),
222
+ exception -> {
223
+ updateWithExceptions (ingestDocumentWrappers , handler , exception );
224
+ }
225
+ );
226
+ } catch (Exception e ) {
227
+ updateWithExceptions (ingestDocumentWrappers , handler , e );
184
228
}
185
- Tuple <List <String >, Map <Integer , Integer >> sortedResult = sortByLengthAndReturnOriginalOrder (inferenceList );
186
- inferenceList = sortedResult .v1 ();
187
- Map <Integer , Integer > originalOrder = sortedResult .v2 ();
188
- doBatchExecute (inferenceList , results -> {
229
+ }
230
+
231
+ private void batchExecuteHandler (
232
+ List <?> results ,
233
+ List <IngestDocumentWrapper > ingestDocumentWrappers ,
234
+ List <DataForInference > dataForInferences ,
235
+ Map <Integer , Integer > originalOrder ,
236
+ Consumer <List <IngestDocumentWrapper >> handler
237
+ ) {
238
+ try {
189
239
int startIndex = 0 ;
190
240
results = restoreToOriginalOrder (results , originalOrder );
191
241
for (DataForInference dataForInference : dataForInferences ) {
@@ -202,16 +252,9 @@ public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers,
202
252
);
203
253
}
204
254
handler .accept (ingestDocumentWrappers );
205
- }, exception -> {
206
- for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
207
- // The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only
208
- // set exception to IngestDocumentWrapper which doesn't have exception before.
209
- if (ingestDocumentWrapper .getException () == null ) {
210
- ingestDocumentWrapper .update (ingestDocumentWrapper .getIngestDocument (), exception );
211
- }
212
- }
213
- handler .accept (ingestDocumentWrappers );
214
- });
255
+ } catch (Exception e ) {
256
+ updateWithExceptions (ingestDocumentWrappers , handler , e );
257
+ }
215
258
}
216
259
217
260
private Tuple <List <String >, Map <Integer , Integer >> sortByLengthAndReturnOriginalOrder (List <String > inferenceList ) {
@@ -238,7 +281,7 @@ private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> or
238
281
return sortedResults ;
239
282
}
240
283
241
- private List <String > constructInferenceTexts (List <DataForInference > dataForInferences ) {
284
+ protected List <String > constructInferenceTexts (List <DataForInference > dataForInferences ) {
242
285
List <String > inferenceTexts = new ArrayList <>();
243
286
for (DataForInference dataForInference : dataForInferences ) {
244
287
if (dataForInference .getIngestDocumentWrapper ().getException () != null
@@ -250,7 +293,7 @@ private List<String> constructInferenceTexts(List<DataForInference> dataForInfer
250
293
return inferenceTexts ;
251
294
}
252
295
253
- private List <DataForInference > getDataForInference (List <IngestDocumentWrapper > ingestDocumentWrappers ) {
296
+ protected List <DataForInference > getDataForInference (List <IngestDocumentWrapper > ingestDocumentWrappers ) {
254
297
List <DataForInference > dataForInferences = new ArrayList <>();
255
298
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
256
299
Map <String , Object > processMap = null ;
@@ -272,7 +315,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
272
315
273
316
@ Getter
274
317
@ AllArgsConstructor
275
- private static class DataForInference {
318
+ protected static class DataForInference {
276
319
private final IngestDocumentWrapper ingestDocumentWrapper ;
277
320
private final Map <String , Object > processMap ;
278
321
private final List <String > inferenceList ;
@@ -415,6 +458,36 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
415
458
nlpResult .forEach (ingestDocument ::setFieldValue );
416
459
}
417
460
461
+ /**
462
+ * This method creates a MultiGetRequest from a list of ingest documents to be fetched for comparison
463
+ * @param dataForInferences, list of data for inferences
464
+ * */
465
+ protected MultiGetRequest buildMultiGetRequest (List <DataForInference > dataForInferences ) {
466
+ MultiGetRequest multiGetRequest = new MultiGetRequest ();
467
+ for (DataForInference dataForInference : dataForInferences ) {
468
+ Object index = dataForInference .getIngestDocumentWrapper ().getIngestDocument ().getSourceAndMetadata ().get (INDEX_FIELD );
469
+ Object id = dataForInference .getIngestDocumentWrapper ().getIngestDocument ().getSourceAndMetadata ().get (ID_FIELD );
470
+ if (Objects .nonNull (index ) && Objects .nonNull (id )) {
471
+ multiGetRequest .add (index .toString (), id .toString ());
472
+ }
473
+ }
474
+ return multiGetRequest ;
475
+ }
476
+
477
+ /**
478
+ * This method creates a map of documents from MultiGetItemResponse where the key is document ID and value is corresponding document
479
+ * @param multiGetItemResponses, array of responses from Multi Get Request
480
+ * */
481
+ protected Map <String , Map <String , Object >> createDocumentMap (MultiGetItemResponse [] multiGetItemResponses ) {
482
+ Map <String , Map <String , Object >> existingDocuments = new HashMap <>();
483
+ for (MultiGetItemResponse item : multiGetItemResponses ) {
484
+ String id = item .getId ();
485
+ Map <String , Object > existingDocument = item .getResponse ().getSourceAsMap ();
486
+ existingDocuments .put (id , existingDocument );
487
+ }
488
+ return existingDocuments ;
489
+ }
490
+
418
491
@ SuppressWarnings ({ "unchecked" })
419
492
@ VisibleForTesting
420
493
Map <String , Object > buildNLPResult (Map <String , Object > processorMap , List <?> results , Map <String , Object > sourceAndMetadataMap ) {
@@ -504,6 +577,27 @@ private void processMapEntryValue(
504
577
}
505
578
}
506
579
580
+ // This method updates each ingestDocument with exceptions and accepts ingestDocumentWrappers.
581
+ // Ingestion fails when exception occurs while updating
582
+ protected void updateWithExceptions (
583
+ List <IngestDocumentWrapper > ingestDocumentWrappers ,
584
+ Consumer <List <IngestDocumentWrapper >> handler ,
585
+ Exception e
586
+ ) {
587
+ try {
588
+ for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
589
+ // The IngestDocumentWrapper might have already run into exception. So here we only
590
+ // set exception to IngestDocumentWrapper which doesn't have exception before.
591
+ if (ingestDocumentWrapper .getException () == null ) {
592
+ ingestDocumentWrapper .update (ingestDocumentWrapper .getIngestDocument (), e );
593
+ }
594
+ }
595
+ handler .accept (ingestDocumentWrappers );
596
+ } catch (Exception ex ) {
597
+ handler .accept (null );
598
+ }
599
+ }
600
+
507
601
private void processMapEntryValue (
508
602
List <?> results ,
509
603
IndexWrapper indexWrapper ,
@@ -582,11 +676,65 @@ private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceV
582
676
List <Map <String , Object >> keyToResult = new ArrayList <>();
583
677
sourceValue .stream ()
584
678
.filter (Objects ::nonNull ) // explicit null check is required since sourceValue can contain null values in cases where
585
- // sourceValue has been filtered
679
+ // sourceValue has been filtered
586
680
.forEachOrdered (x -> keyToResult .add (ImmutableMap .of (listTypeNestedMapKey , results .get (indexWrapper .index ++))));
587
681
return keyToResult ;
588
682
}
589
683
684
+ // This method validates and filters given inferenceList and dataForInferences after response is successfully retrieved from multi-get
685
+ // operation.
686
+ protected void multiGetResponseHandler (
687
+ MultiGetResponse response ,
688
+ List <IngestDocumentWrapper > ingestDocumentWrappers ,
689
+ List <String > inferenceList ,
690
+ List <DataForInference > dataForInferences ,
691
+ Consumer <List <IngestDocumentWrapper >> handler ,
692
+ InferenceFilter inferenceFilter
693
+ ) {
694
+ MultiGetItemResponse [] multiGetItemResponses = response .getResponses ();
695
+ if (multiGetItemResponses == null || multiGetItemResponses .length == 0 ) {
696
+ doSubBatchExecute (ingestDocumentWrappers , inferenceList , dataForInferences , handler );
697
+ return ;
698
+ }
699
+ // create a map of documents with key: doc_id and value: doc
700
+ Map <String , Map <String , Object >> existingDocuments = createDocumentMap (multiGetItemResponses );
701
+ List <DataForInference > filteredDataForInference = filterDataForInference (inferenceFilter , dataForInferences , existingDocuments );
702
+ List <String > filteredInferenceList = constructInferenceTexts (filteredDataForInference );
703
+ if (filteredInferenceList .isEmpty ()) {
704
+ handler .accept (ingestDocumentWrappers );
705
+ return ;
706
+ }
707
+ doSubBatchExecute (ingestDocumentWrappers , filteredInferenceList , filteredDataForInference , handler );
708
+ }
709
+
710
+ // This is a helper method to filter the given list of dataForInferences by comparing its documents with existingDocuments with
711
+ // given inferenceFilter
712
+ protected List <DataForInference > filterDataForInference (
713
+ InferenceFilter inferenceFilter ,
714
+ List <DataForInference > dataForInferences ,
715
+ Map <String , Map <String , Object >> existingDocuments
716
+ ) {
717
+ List <DataForInference > filteredDataForInference = new ArrayList <>();
718
+ for (DataForInference dataForInference : dataForInferences ) {
719
+ IngestDocumentWrapper ingestDocumentWrapper = dataForInference .getIngestDocumentWrapper ();
720
+ Map <String , Object > processMap = dataForInference .getProcessMap ();
721
+ Map <String , Object > document = ingestDocumentWrapper .getIngestDocument ().getSourceAndMetadata ();
722
+ Object id = document .get (ID_FIELD );
723
+ // insert non-filtered dataForInference if existing document does not exist
724
+ if (Objects .isNull (id ) || existingDocuments .containsKey (id .toString ()) == false ) {
725
+ filteredDataForInference .add (dataForInference );
726
+ continue ;
727
+ }
728
+ // filter dataForInference when existing document exists
729
+ String docId = id .toString ();
730
+ Map <String , Object > existingDocument = existingDocuments .get (docId );
731
+ Map <String , Object > filteredProcessMap = inferenceFilter .filter (existingDocument , document , processMap );
732
+ List <String > filteredInferenceList = createInferenceList (filteredProcessMap );
733
+ filteredDataForInference .add (new DataForInference (ingestDocumentWrapper , filteredProcessMap , filteredInferenceList ));
734
+ }
735
+ return filteredDataForInference ;
736
+ }
737
+
590
738
/**
591
739
* This method invokes inference call through mlCommonsClientAccessor and populates retrieved embeddings to ingestDocument
592
740
*
@@ -611,6 +759,25 @@ protected void makeInferenceCall(
611
759
);
612
760
}
613
761
762
+ protected <Response > ActionListener <Response > wrap (
763
+ final CheckedConsumer <Response , ? extends Exception > onResponse ,
764
+ final Consumer <Exception > onFailure
765
+ ) {
766
+ return new ActionListener <>() {
767
+ public void onResponse (Response response ) {
768
+ try {
769
+ onResponse .accept (response );
770
+ } catch (Exception e ) {
771
+ this .onFailure (e );
772
+ }
773
+ }
774
+
775
+ public void onFailure (Exception e ) {
776
+ onFailure .accept (e );
777
+ }
778
+ };
779
+ }
780
+
614
781
@ Override
615
782
public String getType () {
616
783
return type ;
0 commit comments