9
9
import java .util .function .BiConsumer ;
10
10
import java .util .function .Consumer ;
11
11
12
+ import org .opensearch .action .get .GetAction ;
13
+ import org .opensearch .action .get .GetRequest ;
14
+ import org .opensearch .action .get .MultiGetAction ;
12
15
import org .opensearch .cluster .service .ClusterService ;
13
16
import org .opensearch .core .action .ActionListener ;
17
+ import org .opensearch .core .common .util .CollectionUtils ;
14
18
import org .opensearch .env .Environment ;
15
19
import org .opensearch .ingest .IngestDocument ;
20
+ import org .opensearch .ingest .IngestDocumentWrapper ;
16
21
import org .opensearch .neuralsearch .ml .MLCommonsClientAccessor ;
17
22
18
23
import lombok .extern .log4j .Log4j2 ;
24
+ import org .opensearch .neuralsearch .processor .optimization .TextEmbeddingInferenceFilter ;
25
+ import org .opensearch .transport .client .OpenSearchClient ;
19
26
20
27
/**
21
28
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
@@ -26,33 +33,57 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {
26
33
27
34
    public static final String TYPE = "text_embedding";
    public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
    // Processor config key that toggles reuse of already-indexed embeddings instead of re-running inference.
    public static final String SKIP_EXISTING = "skip_existing";
    public static final boolean DEFAULT_SKIP_EXISTING = false;
    // Metadata keys used to locate the previously indexed copy of the document being ingested.
    private static final String INDEX_FIELD = "_index";
    private static final String ID_FIELD = "_id";
    // Client used to fetch the stored document(s) when skip_existing is enabled.
    private final OpenSearchClient openSearchClient;
    private final boolean skipExisting;
    // Filter that compares stored vs. incoming text and copies over embeddings that are still valid.
    private final TextEmbeddingInferenceFilter textEmbeddingInferenceFilter;
29
43
30
44
public TextEmbeddingProcessor (
31
45
String tag ,
32
46
String description ,
33
47
int batchSize ,
34
48
String modelId ,
35
49
Map <String , Object > fieldMap ,
50
+ boolean skipExisting ,
51
+ TextEmbeddingInferenceFilter textEmbeddingInferenceFilter ,
52
+ OpenSearchClient openSearchClient ,
36
53
MLCommonsClientAccessor clientAccessor ,
37
54
Environment environment ,
38
55
ClusterService clusterService
39
56
) {
40
57
super (tag , description , batchSize , TYPE , LIST_TYPE_NESTED_MAP_KEY , modelId , fieldMap , clientAccessor , environment , clusterService );
58
+ this .skipExisting = skipExisting ;
59
+ this .textEmbeddingInferenceFilter = textEmbeddingInferenceFilter ;
60
+ this .openSearchClient = openSearchClient ;
41
61
}
42
62
43
63
@ Override
44
64
public void doExecute (
45
65
IngestDocument ingestDocument ,
46
- Map <String , Object > ProcessMap ,
66
+ Map <String , Object > processMap ,
47
67
List <String > inferenceList ,
48
68
BiConsumer <IngestDocument , Exception > handler
49
69
) {
50
- mlCommonsClientAccessor .inferenceSentences (
51
- TextInferenceRequest .builder ().modelId (this .modelId ).inputTexts (inferenceList ).build (),
52
- ActionListener .wrap (vectors -> {
53
- setVectorFieldsToDocument (ingestDocument , ProcessMap , vectors );
54
- handler .accept (ingestDocument , null );
55
- }, e -> { handler .accept (null , e ); })
70
+ // skip existing flag is turned off. Call model inference without filtering
71
+ if (skipExisting == false ) {
72
+ makeInferenceCall (ingestDocument , processMap , inferenceList , handler );
73
+ return ;
74
+ }
75
+ // if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings are copied
76
+ String index = ingestDocument .getSourceAndMetadata ().get (INDEX_FIELD ).toString ();
77
+ String id = ingestDocument .getSourceAndMetadata ().get (ID_FIELD ).toString ();
78
+ openSearchClient .execute (
79
+ GetAction .INSTANCE ,
80
+ new GetRequest (index , id ),
81
+ ActionListener .wrap (
82
+ response -> getResponseHandler (response , ingestDocument , processMap , inferenceList , handler , textEmbeddingInferenceFilter ),
83
+ e -> {
84
+ handler .accept (null , e );
85
+ }
86
+ )
56
87
);
57
88
}
58
89
@@ -63,4 +94,47 @@ public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler
63
94
ActionListener .wrap (handler ::accept , onException )
64
95
);
65
96
}
97
+
98
+ @ Override
99
+ public void subBatchExecute (List <IngestDocumentWrapper > ingestDocumentWrappers , Consumer <List <IngestDocumentWrapper >> handler ) {
100
+ try {
101
+ if (CollectionUtils .isEmpty (ingestDocumentWrappers )) {
102
+ handler .accept (ingestDocumentWrappers );
103
+ return ;
104
+ }
105
+ List <DataForInference > dataForInferences = getDataForInference (ingestDocumentWrappers );
106
+ List <String > inferenceList = constructInferenceTexts (dataForInferences );
107
+ if (inferenceList .isEmpty ()) {
108
+ handler .accept (ingestDocumentWrappers );
109
+ return ;
110
+ }
111
+ // skip existing flag is turned off. Call doSubBatchExecute without filtering
112
+ if (skipExisting == false ) {
113
+ doSubBatchExecute (ingestDocumentWrappers , inferenceList , dataForInferences , handler );
114
+ return ;
115
+ }
116
+ // skipExisting flag is turned on, eligible inference texts in dataForInferences will be compared and filtered after embeddings
117
+ // are copied
118
+ openSearchClient .execute (
119
+ MultiGetAction .INSTANCE ,
120
+ buildMultiGetRequest (dataForInferences ),
121
+ ActionListener .wrap (
122
+ response -> multiGetResponseHandler (
123
+ response ,
124
+ ingestDocumentWrappers ,
125
+ inferenceList ,
126
+ dataForInferences ,
127
+ handler ,
128
+ textEmbeddingInferenceFilter
129
+ ),
130
+ e -> {
131
+ // When exception is thrown in for MultiGetAction, set exception to all ingestDocumentWrappers
132
+ updateWithExceptions (getIngestDocumentWrappers (dataForInferences ), handler , e );
133
+ }
134
+ )
135
+ );
136
+ } catch (Exception e ) {
137
+ updateWithExceptions (ingestDocumentWrappers , handler , e );
138
+ }
139
+ }
66
140
}
0 commit comments