Skip to content

Commit c9d06bc

Browse files
bzhangamYeonghyeonKO
authored andcommitted
Add semantic ingest processor. (opensearch-project#1309)
Signed-off-by: Bo Zhang <[email protected]> Signed-off-by: yeonghyeonKo <[email protected]>
1 parent 3d41263 commit c9d06bc

29 files changed

+1971
-131
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
88
### Features
99
- Implement analyzer based neural sparse query ([#1088](https://github.com/opensearch-project/neural-search/pull/1088) [#1279](https://github.com/opensearch-project/neural-search/pull/1279))
1010
- [Semantic Field] Add semantic mapping transformer. ([#1276](https://github.com/opensearch-project/neural-search/pull/1276))
11+
- [Semantic Field] Add semantic ingest processor. ([#1309](https://github.com/opensearch-project/neural-search/pull/1309))
1112

1213
### Enhancements
1314
- [Performance Improvement] Add custom bulk scorer for hybrid query (2-3x faster) ([#1289](https://github.com/opensearch-project/neural-search/pull/1289))

build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ dependencies {
288288
testFixturesImplementation fileTree(dir: knnJarDirectory, include: ["opensearch-knn-${opensearch_build}.jar", "remote-index-build-client-${opensearch_build}.jar"])
289289
testImplementation fileTree(dir: knnJarDirectory, include: ["opensearch-knn-${opensearch_build}.jar", "remote-index-build-client-${opensearch_build}.jar"])
290290
testImplementation "org.opensearch.plugin:parent-join-client:${opensearch_version}"
291+
testImplementation 'org.assertj:assertj-core:3.24.2'
291292
}
292293

293294
// In order to add the jar to the classpath, we need to unzip the

src/main/java/org/opensearch/neuralsearch/constants/MappingConstants.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,9 @@ public class MappingConstants {
2121
* Name for properties. An object field will define subfields as properties.
2222
*/
2323
public static final String PROPERTIES = "properties";
24+
25+
/**
26+
* Separator in a field path.
27+
*/
28+
public static final String PATH_SEPARATOR = ".";
2429
}

src/main/java/org/opensearch/neuralsearch/mappingtransformer/SemanticMappingTransformer.java

Lines changed: 6 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,15 @@
2121
import java.util.Locale;
2222
import java.util.Map;
2323
import java.util.Set;
24-
import java.util.concurrent.ConcurrentHashMap;
25-
import java.util.concurrent.atomic.AtomicBoolean;
26-
import java.util.concurrent.atomic.AtomicInteger;
2724
import java.util.stream.Collectors;
2825

2926
import static org.opensearch.neuralsearch.constants.MappingConstants.PROPERTIES;
30-
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.MODEL_ID;
3127
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SEMANTIC_INFO_FIELD_NAME;
3228
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.collectSemanticField;
3329
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.extractModelIdToFieldPathMap;
3430
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.getProperties;
31+
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.validateModelId;
32+
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.validateSemanticInfoFieldName;
3533

3634
/**
3735
* SemanticMappingTransformer transforms the index mapping for the semantic field to auto add the semantic info fields
@@ -168,101 +166,17 @@ private void validateSemanticFields(@NonNull final Map<String, Map<String, Objec
168166
}
169167
}
170168

171-
private String validateModelId(@NonNull final String semanticFieldPath, @NonNull final Map<String, Object> semanticFieldConfig) {
172-
Object modelId = semanticFieldConfig.get(SemanticFieldConstants.MODEL_ID);
173-
if (modelId == null) {
174-
return String.format(Locale.ROOT, "%s is required for the semantic field at %s", MODEL_ID, semanticFieldPath);
175-
}
176-
177-
if (modelId instanceof String == false || ((String) modelId).isEmpty()) {
178-
return String.format(Locale.ROOT, "%s should be a non-empty string for the semantic field at %s", MODEL_ID, semanticFieldPath);
179-
}
180-
181-
return null;
182-
}
183-
184-
private String validateSemanticInfoFieldName(
185-
@NonNull final String semanticFieldPath,
186-
@NonNull final Map<String, Object> semanticFieldConfig
187-
) {
188-
if (semanticFieldConfig.containsKey(SEMANTIC_INFO_FIELD_NAME)) {
189-
final Object semanticInfoFieldName = semanticFieldConfig.get(SEMANTIC_INFO_FIELD_NAME);
190-
if (semanticInfoFieldName instanceof String semanticInfoFieldNameStr) {
191-
if (semanticInfoFieldNameStr.isEmpty()) {
192-
return String.format(
193-
Locale.ROOT,
194-
"%s cannot be an empty string for the semantic field at %s",
195-
SEMANTIC_INFO_FIELD_NAME,
196-
semanticFieldPath
197-
198-
);
199-
}
200-
201-
// OpenSearch allows to define a field name with "." in the index mapping and will unflatten it later
202-
// but in our case it's not necessary to support "." in the custom semantic info field name. So add this
203-
// validation to block it.
204-
if (semanticInfoFieldNameStr.contains(".")) {
205-
return String.format(
206-
Locale.ROOT,
207-
"%s should not contain '.' for the semantic field at %s",
208-
SEMANTIC_INFO_FIELD_NAME,
209-
semanticFieldPath
210-
211-
);
212-
}
213-
} else {
214-
return String.format(
215-
Locale.ROOT,
216-
"%s should be a non-empty string for the semantic field at %s",
217-
SEMANTIC_INFO_FIELD_NAME,
218-
semanticFieldPath
219-
220-
);
221-
}
222-
}
223-
// SEMANTIC_INFO_FIELD_NAME is an optional field. If it does not exist we simply return null to show no error.
224-
return null;
225-
}
226-
227169
private void fetchModelAndModifyMapping(
228170
@NonNull final Map<String, Map<String, Object>> semanticFieldPathToConfigMap,
229171
@NonNull final Map<String, Object> mappings,
230172
@NonNull final ActionListener<Void> listener
231173
) {
232174
final Map<String, List<String>> modelIdToFieldPathMap = extractModelIdToFieldPathMap(semanticFieldPathToConfigMap);
233-
if (modelIdToFieldPathMap.isEmpty()) {
234-
listener.onResponse(null);
235-
}
236-
final AtomicInteger counter = new AtomicInteger(modelIdToFieldPathMap.size());
237-
final AtomicBoolean hasError = new AtomicBoolean(false);
238-
final List<String> errors = new ArrayList<>();
239-
final Map<String, MLModel> modelIdToConfigMap = new ConcurrentHashMap<>(modelIdToFieldPathMap.size());
240175

241-
// We can have multiple semantic fields with different model ids, and we should get model config for each model
242-
for (String modelId : modelIdToFieldPathMap.keySet()) {
243-
mlClientAccessor.getModel(modelId, ActionListener.wrap(mlModel -> {
244-
modelIdToConfigMap.put(modelId, mlModel);
245-
if (counter.decrementAndGet() == 0) {
246-
try {
247-
if (hasError.get()) {
248-
listener.onFailure(new RuntimeException(String.join("; ", errors)));
249-
} else {
250-
modifyMappings(modelIdToConfigMap, mappings, modelIdToFieldPathMap, semanticFieldPathToConfigMap);
251-
listener.onResponse(null);
252-
}
253-
} catch (Exception e) {
254-
errors.add(e.getMessage());
255-
listener.onFailure(new RuntimeException(String.join("; ", errors)));
256-
}
257-
}
258-
}, e -> {
259-
hasError.set(true);
260-
errors.add(e.getMessage());
261-
if (counter.decrementAndGet() == 0) {
262-
listener.onFailure(new RuntimeException(String.join("; ", errors)));
263-
}
264-
}));
265-
}
176+
mlClientAccessor.getModels(modelIdToFieldPathMap.keySet(), modelIdToConfigMap -> {
177+
modifyMappings(modelIdToConfigMap, mappings, modelIdToFieldPathMap, semanticFieldPathToConfigMap);
178+
listener.onResponse(null);
179+
}, listener::onFailure);
266180
}
267181

268182
private void modifyMappings(

src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
import java.util.Collections;
1313
import java.util.List;
1414
import java.util.Map;
15+
import java.util.Set;
16+
import java.util.concurrent.ConcurrentHashMap;
17+
import java.util.concurrent.atomic.AtomicBoolean;
18+
import java.util.concurrent.atomic.AtomicInteger;
19+
import java.util.function.Consumer;
1520
import java.util.stream.Collectors;
1621

1722
import org.opensearch.core.action.ActionListener;
@@ -316,6 +321,71 @@ public void getModel(@NonNull final String modelId, @NonNull final ActionListene
316321
retryableGetModel(modelId, 0, listener);
317322
}
318323

324+
/**
325+
* Get model info for multiple model ids. It will send multiple getModel requests to get the model info in parallel.
326+
* It will fail if any one of the get model request fail. Only return the success result if all model info is
327+
* successfully retrieved.
328+
* @param modelIds a set of model ids
329+
* @param onSuccess onSuccess consumer
330+
* @param onFailure onFailure consumer
331+
*/
332+
public void getModels(
333+
@NonNull final Set<String> modelIds,
334+
@NonNull final Consumer<Map<String, MLModel>> onSuccess,
335+
@NonNull final Consumer<Exception> onFailure
336+
) {
337+
if (modelIds.isEmpty()) {
338+
try {
339+
onSuccess.accept(Collections.emptyMap());
340+
} catch (Exception e) {
341+
onFailure.accept(e);
342+
}
343+
return;
344+
}
345+
346+
final Map<String, MLModel> modelMap = new ConcurrentHashMap<>();
347+
final AtomicInteger counter = new AtomicInteger(modelIds.size());
348+
final AtomicBoolean hasError = new AtomicBoolean(false);
349+
final List<String> errors = Collections.synchronizedList(new ArrayList<>());
350+
351+
for (String modelId : modelIds) {
352+
try {
353+
getModel(modelId, ActionListener.wrap(model -> {
354+
modelMap.put(modelId, model);
355+
if (counter.decrementAndGet() == 0) {
356+
if (hasError.get()) {
357+
onFailure.accept(new RuntimeException(String.join(";", errors)));
358+
} else {
359+
try {
360+
onSuccess.accept(modelMap);
361+
} catch (Exception e) {
362+
onFailure.accept(e);
363+
}
364+
}
365+
}
366+
}, e -> { handleGetModelException(hasError, errors, modelId, e, counter, onFailure); }));
367+
} catch (Exception e) {
368+
handleGetModelException(hasError, errors, modelId, e, counter, onFailure);
369+
}
370+
}
371+
372+
}
373+
374+
private void handleGetModelException(
375+
AtomicBoolean hasError,
376+
List<String> errors,
377+
String modelId,
378+
Exception e,
379+
AtomicInteger counter,
380+
@NonNull Consumer<Exception> onFailure
381+
) {
382+
hasError.set(true);
383+
errors.add("Failed to fetch model [" + modelId + "]: " + e.getMessage());
384+
if (counter.decrementAndGet() == 0) {
385+
onFailure.accept(new RuntimeException(String.join(";", errors)));
386+
}
387+
}
388+
319389
private void retryableGetModel(@NonNull final String modelId, final int retryTime, @NonNull final ActionListener<MLModel> listener) {
320390
mlClient.getModel(
321391
modelId,

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.opensearch.index.mapper.MappingTransformer;
2929
import org.opensearch.neuralsearch.mapper.SemanticFieldMapper;
3030
import org.opensearch.neuralsearch.mappingtransformer.SemanticMappingTransformer;
31+
import org.opensearch.neuralsearch.processor.factory.SemanticFieldProcessorFactory;
3132
import org.opensearch.plugins.MapperPlugin;
3233
import org.opensearch.transport.client.Client;
3334
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
@@ -300,4 +301,17 @@ public Map<String, Mapper.TypeParser> getMappers() {
300301
public List<MappingTransformer> getMappingTransformers() {
301302
return List.of(new SemanticMappingTransformer(clientAccessor, xContentRegistry));
302303
}
304+
305+
@Override
306+
public Map<String, Processor.Factory> getSystemIngestProcessors(Processor.Parameters parameters) {
307+
return Map.of(
308+
SemanticFieldProcessorFactory.PROCESSOR_FACTORY_TYPE,
309+
new SemanticFieldProcessorFactory(
310+
clientAccessor,
311+
parameters.env,
312+
parameters.ingestService.getClusterService(),
313+
parameters.analysisRegistry
314+
)
315+
);
316+
}
303317
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.processor.dto;
6+
7+
import lombok.Data;
8+
9+
import java.util.List;
10+
11+
import static org.opensearch.neuralsearch.constants.MappingConstants.PATH_SEPARATOR;
12+
import static org.opensearch.neuralsearch.constants.SemanticInfoFieldConstants.CHUNKS_EMBEDDING_FIELD_NAME;
13+
import static org.opensearch.neuralsearch.constants.SemanticInfoFieldConstants.CHUNKS_FIELD_NAME;
14+
import static org.opensearch.neuralsearch.constants.SemanticInfoFieldConstants.MODEL_FIELD_NAME;
15+
16+
/**
17+
* SemanticFieldInfo is a data transfer object to help hold semantic field info
18+
*/
19+
@Data
20+
public class SemanticFieldInfo {
21+
/**
22+
* The raw string value of the semantic field
23+
*/
24+
private String value;
25+
/**
26+
* The model id of the semantic field which will be used to generate the embedding
27+
*/
28+
private String modelId;
29+
/**
30+
* The full path to the semantic field
31+
*/
32+
private String fullPath;
33+
/**
34+
* The full path to the semantic info fields
35+
*/
36+
private String semanticInfoFullPath;
37+
/**
38+
* The chunked strings of the original string value of the semantic field
39+
*/
40+
private List<String> chunks;
41+
42+
/**
43+
* @return full path to the chunks field of the semantic field
44+
*/
45+
public String getFullPathForChunks() {
46+
return new StringBuilder().append(semanticInfoFullPath).append(PATH_SEPARATOR).append(CHUNKS_FIELD_NAME).toString();
47+
}
48+
49+
/**
50+
* @param index index of the chunk the embedding is in
51+
* @return full path to the embedding field of the semantic field
52+
*/
53+
public String getFullPathForEmbedding(int index) {
54+
return new StringBuilder().append(semanticInfoFullPath)
55+
.append(PATH_SEPARATOR)
56+
.append(CHUNKS_FIELD_NAME)
57+
.append(PATH_SEPARATOR)
58+
.append(index)
59+
.append(PATH_SEPARATOR)
60+
.append(CHUNKS_EMBEDDING_FIELD_NAME)
61+
.toString();
62+
}
63+
64+
/**
65+
* @return full path to the model info fields
66+
*/
67+
public String getFullPathForModelInfo() {
68+
return new StringBuilder().append(semanticInfoFullPath).append(PATH_SEPARATOR).append(MODEL_FIELD_NAME).toString();
69+
}
70+
}

0 commit comments

Comments
 (0)