Skip to content

Commit fa9ff0b

Browse files
junqiu-lei authored and YeonghyeonKO committed
Support custom tags in semantic highlighter (opensearch-project#1254)
Signed-off-by: yeonghyeonKo <[email protected]>
1 parent 4466fc3 commit fa9ff0b

File tree

8 files changed

+319
-61
lines changed

8 files changed

+319
-61
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1414
- Optimize embedding generation in Sparse Encoding Processor ([#1246](https://github.com/opensearch-project/neural-search/pull/1246))
1515
- Optimize embedding generation in Text/Image Embedding Processor ([#1249](https://github.com/opensearch-project/neural-search/pull/1249))
1616
- Inner hits support with hybrid query ([#1253](https://github.com/opensearch-project/neural-search/pull/1253))
17+
- Support custom tags in semantic highlighter ([#1254](https://github.com/opensearch-project/neural-search/pull/1254))
1718

1819
### Enhancements
1920

src/main/java/org/opensearch/neuralsearch/highlight/SemanticHighlighter.java

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,8 +60,18 @@ public HighlightField highlight(FieldHighlightContext fieldContext) {
6060
return null;
6161
}
6262

63+
// The pre- and post- tags are provided by the user or defaulted to <em> and </em>
64+
String[] preTags = fieldContext.field.fieldOptions().preTags();
65+
String[] postTags = fieldContext.field.fieldOptions().postTags();
66+
6367
// Get highlighted text - allow any exceptions from this call to propagate
64-
String highlightedResponse = semanticHighlighterEngine.getHighlightedSentences(modelId, originalQueryText, fieldText);
68+
String highlightedResponse = semanticHighlighterEngine.getHighlightedSentences(
69+
modelId,
70+
originalQueryText,
71+
fieldText,
72+
preTags[0],
73+
postTags[0]
74+
);
6575

6676
if (highlightedResponse == null || highlightedResponse.isEmpty()) {
6777
log.warn("No highlighted text found for field {}", fieldContext.fieldName);

src/main/java/org/opensearch/neuralsearch/highlight/SemanticHighlighterEngine.java

Lines changed: 18 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import org.opensearch.neuralsearch.highlight.extractor.QueryTextExtractorRegistry;
1414
import org.opensearch.action.support.PlainActionFuture;
1515
import lombok.NonNull;
16+
import lombok.Builder;
1617

1718
import java.util.ArrayList;
1819
import java.util.List;
@@ -24,24 +25,18 @@
2425
* Engine class for semantic highlighting operations
2526
*/
2627
@Log4j2
28+
@Builder
2729
public class SemanticHighlighterEngine {
2830
private static final String MODEL_ID_FIELD = "model_id";
29-
private static final String DEFAULT_PRE_TAG = "<em>";
30-
private static final String DEFAULT_POST_TAG = "</em>";
3131
private static final String MODEL_INFERENCE_RESULT_KEY = "highlights";
3232
private static final String MODEL_INFERENCE_RESULT_START_KEY = "start";
3333
private static final String MODEL_INFERENCE_RESULT_END_KEY = "end";
3434

35+
@NonNull
3536
private final MLCommonsClientAccessor mlCommonsClient;
36-
private final QueryTextExtractorRegistry queryTextExtractorRegistry;
3737

38-
public SemanticHighlighterEngine(
39-
@NonNull MLCommonsClientAccessor mlCommonsClient,
40-
@NonNull QueryTextExtractorRegistry queryTextExtractorRegistry
41-
) {
42-
this.mlCommonsClient = mlCommonsClient;
43-
this.queryTextExtractorRegistry = queryTextExtractorRegistry;
44-
}
38+
@NonNull
39+
private final QueryTextExtractorRegistry queryTextExtractorRegistry;
4540

4641
/**
4742
* Gets the field text from the document
@@ -116,15 +111,17 @@ public String getModelId(Map<String, Object> options) {
116111
* @param modelId The ID of the model to use
117112
* @param question The search query
118113
* @param context The document text
114+
* @param preTag The pre tag to use for highlighting
115+
* @param postTag The post tag to use for highlighting
119116
* @return Formatted text with highlighting
120117
*/
121-
public String getHighlightedSentences(String modelId, String question, String context) {
118+
public String getHighlightedSentences(String modelId, String question, String context, String preTag, String postTag) {
122119
List<Map<String, Object>> results = fetchModelResults(modelId, question, context);
123120
if (results == null || results.isEmpty()) {
124121
return null;
125122
}
126123

127-
return applyHighlighting(context, results.getFirst());
124+
return applyHighlighting(context, results.getFirst(), preTag, postTag);
128125
}
129126

130127
/**
@@ -168,10 +165,12 @@ public List<Map<String, Object>> fetchModelResults(String modelId, String questi
168165
*
169166
* @param context The original document text
170167
* @param highlightResult The highlighting result from the ML model
168+
* @param preTag The pre tag to use for highlighting
169+
* @param postTag The post tag to use for highlighting
171170
* @return Formatted text with highlighting
172171
* @throws IllegalArgumentException if highlight positions are invalid
173172
*/
174-
public String applyHighlighting(String context, Map<String, Object> highlightResult) {
173+
public String applyHighlighting(String context, Map<String, Object> highlightResult, String preTag, String postTag) {
175174
// Get the "highlights" list from the result
176175
Object highlightsObj = highlightResult.get(MODEL_INFERENCE_RESULT_KEY);
177176

@@ -216,7 +215,7 @@ public String applyHighlighting(String context, Map<String, Object> highlightRes
216215
}
217216
}
218217

219-
return constructHighlightedText(context, validHighlights);
218+
return constructHighlightedText(context, validHighlights, preTag, postTag);
220219
}
221220

222221
/**
@@ -246,9 +245,11 @@ private void validateHighlightPositions(int start, int end, int textLength) {
246245
*
247246
* @param text The original text
248247
* @param highlights The list of valid highlight positions in pairs [start1, end1, start2, end2, ...]
248+
* @param preTag The pre tag to use for highlighting
249+
* @param postTag The post tag to use for highlighting
249250
* @return The highlighted text
250251
*/
251-
private String constructHighlightedText(String text, List<Integer> highlights) {
252+
private String constructHighlightedText(String text, List<Integer> highlights, String preTag, String postTag) {
252253
StringBuilder result = new StringBuilder();
253254
int currentPos = 0;
254255

@@ -264,9 +265,9 @@ private String constructHighlightedText(String text, List<Integer> highlights) {
264265
}
265266

266267
// Add the highlighted text with highlight tags
267-
result.append(DEFAULT_PRE_TAG);
268+
result.append(preTag);
268269
result.append(text, start, end);
269-
result.append(DEFAULT_POST_TAG);
270+
result.append(postTag);
270271

271272
// Update current position to end of this highlight
272273
currentPos = end;

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,10 @@ public Collection<Object> createComponents(
114114
NeuralQueryBuilder.initialize(clientAccessor);
115115
NeuralSparseQueryBuilder.initialize(clientAccessor);
116116
QueryTextExtractorRegistry queryTextExtractorRegistry = new QueryTextExtractorRegistry();
117-
SemanticHighlighterEngine semanticHighlighterEngine = new SemanticHighlighterEngine(clientAccessor, queryTextExtractorRegistry);
117+
SemanticHighlighterEngine semanticHighlighterEngine = SemanticHighlighterEngine.builder()
118+
.mlCommonsClient(clientAccessor)
119+
.queryTextExtractorRegistry(queryTextExtractorRegistry)
120+
.build();
118121
semanticHighlighter.initialize(semanticHighlighterEngine);
119122
HybridQueryExecutor.initialize(threadPool);
120123
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());

src/test/java/org/opensearch/neuralsearch/highlight/SemanticHighlighterEngineTests.java

Lines changed: 43 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,10 @@ public void setUp() throws Exception {
4545
super.setUp();
4646
mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
4747
queryTextExtractorRegistry = new QueryTextExtractorRegistry();
48-
highlighterEngine = new SemanticHighlighterEngine(mlCommonsClientAccessor, queryTextExtractorRegistry);
48+
highlighterEngine = SemanticHighlighterEngine.builder()
49+
.mlCommonsClient(mlCommonsClientAccessor)
50+
.queryTextExtractorRegistry(queryTextExtractorRegistry)
51+
.build();
4952

5053
// Setup default mock behavior
5154
setupDefaultMockBehavior();
@@ -123,7 +126,7 @@ public void testExtractOriginalQuery() {
123126
}
124127

125128
public void testGetHighlightedSentences() {
126-
String result = highlighterEngine.getHighlightedSentences(MODEL_ID, TEST_QUERY, TEST_CONTENT);
129+
String result = highlighterEngine.getHighlightedSentences(MODEL_ID, TEST_QUERY, TEST_CONTENT, "<em>", "</em>");
127130

128131
assertNotNull("Should return highlighted text", result);
129132
assertTrue("Should contain highlighting tags", result.contains("<em>") && result.contains("</em>"));
@@ -149,7 +152,7 @@ public void testApplyHighlighting() {
149152
resultMap.put("highlights", highlightsList);
150153

151154
String text = "This is a test string";
152-
String result = highlighterEngine.applyHighlighting(text, resultMap);
155+
String result = highlighterEngine.applyHighlighting(text, resultMap, "<em>", "</em>");
153156

154157
assertEquals("Should apply highlights correctly", "<em>This</em> is <em>a te</em>st string", result);
155158
}
@@ -170,7 +173,10 @@ public void testApplyHighlightingWithInvalidPositions() {
170173
highlightsList.add(highlight1);
171174
resultMap.put("highlights", highlightsList);
172175

173-
OpenSearchException exception = expectThrows(OpenSearchException.class, () -> highlighterEngine.applyHighlighting(text, resultMap));
176+
OpenSearchException exception = expectThrows(
177+
OpenSearchException.class,
178+
() -> highlighterEngine.applyHighlighting(text, resultMap, "<em>", "</em>")
179+
);
174180
assertEquals(
175181
"Should throw correct error message for invalid positions",
176182
String.format(
@@ -188,7 +194,7 @@ public void testApplyHighlightingWithInvalidPositions() {
188194
highlight2.put("end", 100);
189195
highlightsList.add(highlight2);
190196

191-
exception = expectThrows(OpenSearchException.class, () -> highlighterEngine.applyHighlighting(text, resultMap));
197+
exception = expectThrows(OpenSearchException.class, () -> highlighterEngine.applyHighlighting(text, resultMap, "<em>", "</em>"));
192198
assertEquals(
193199
"Should throw correct error message for invalid positions",
194200
String.format(
@@ -206,7 +212,7 @@ public void testApplyHighlightingWithInvalidPositions() {
206212
highlight3.put("end", 5);
207213
highlightsList.add(highlight3);
208214

209-
exception = expectThrows(OpenSearchException.class, () -> highlighterEngine.applyHighlighting(text, resultMap));
215+
exception = expectThrows(OpenSearchException.class, () -> highlighterEngine.applyHighlighting(text, resultMap, "<em>", "</em>"));
210216
assertEquals(
211217
"Should throw correct error message for invalid positions",
212218
String.format(
@@ -238,20 +244,15 @@ public void testApplyHighlightingWithUnsortedPositions() {
238244

239245
resultMap.put("highlights", highlightsList);
240246

241-
OpenSearchException exception = expectThrows(OpenSearchException.class, () -> highlighterEngine.applyHighlighting(text, resultMap));
247+
OpenSearchException exception = expectThrows(
248+
OpenSearchException.class,
249+
() -> highlighterEngine.applyHighlighting(text, resultMap, "<em>", "</em>")
250+
);
242251
assertEquals(
243252
"Should throw correct error message for unsorted positions",
244253
"Internal error while applying semantic highlight: received unsorted highlights from model",
245254
exception.getMessage()
246255
);
247-
// Verify that sorted positions work correctly
248-
highlightsList.clear();
249-
// Add highlights in sorted order
250-
highlightsList.add(highlight2); // start=0
251-
highlightsList.add(highlight1); // start=8
252-
253-
String result = highlighterEngine.applyHighlighting(text, resultMap);
254-
assertEquals("Should successfully highlight with sorted positions", "<em>This</em> is <em>a</em> test string", result);
255256
}
256257

257258
public void testApplyHighlightingWithInvalidHighlightMap() {
@@ -268,7 +269,10 @@ public void testApplyHighlightingWithInvalidHighlightMap() {
268269
resultMap.put("highlights", highlightsList);
269270

270271
String text = "This is a test string";
271-
ClassCastException exception = expectThrows(ClassCastException.class, () -> highlighterEngine.applyHighlighting(text, resultMap));
272+
ClassCastException exception = expectThrows(
273+
ClassCastException.class,
274+
() -> highlighterEngine.applyHighlighting(text, resultMap, "<em>", "</em>")
275+
);
272276
assertTrue(exception.getMessage().contains("cannot be cast to class java.lang.Number"));
273277
}
274278

@@ -286,7 +290,10 @@ public void testApplyHighlightingWithMissingPositions() {
286290
resultMap.put("highlights", highlightsList);
287291

288292
String text = "This is a test string";
289-
OpenSearchException exception = expectThrows(OpenSearchException.class, () -> highlighterEngine.applyHighlighting(text, resultMap));
293+
OpenSearchException exception = expectThrows(
294+
OpenSearchException.class,
295+
() -> highlighterEngine.applyHighlighting(text, resultMap, "<em>", "</em>")
296+
);
290297
assertTrue(exception.getMessage().contains("Missing start or end position"));
291298
}
292299

@@ -296,17 +303,31 @@ public void testApplyHighlightingWithEmptyHighlights() {
296303
resultMap.put("highlights", new ArrayList<>());
297304

298305
String text = "This is a test string";
299-
String result = highlighterEngine.applyHighlighting(text, resultMap);
306+
String result = highlighterEngine.applyHighlighting(text, resultMap, "<em>", "</em>");
300307
assertEquals("Should return original text when no highlights", text, result);
301308
}
302309

303310
public void testApplyHighlightingWithMissingHighlightsKey() {
304-
// Test with missing highlights key
305311
Map<String, Object> resultMap = new HashMap<>();
306-
// No highlights key
307-
308312
String text = "This is a test string";
309-
String result = highlighterEngine.applyHighlighting(text, resultMap);
313+
String result = highlighterEngine.applyHighlighting(text, resultMap, "<em>", "</em>");
310314
assertNull("Should return null when highlights key is missing", result);
311315
}
316+
317+
public void testCustomTags() {
318+
// Test with custom tags
319+
String result = highlighterEngine.getHighlightedSentences(MODEL_ID, TEST_QUERY, TEST_CONTENT, "<mark>", "</mark>");
320+
assertNotNull("Should return highlighted text", result);
321+
assertTrue("Should contain custom highlighting tags", result.contains("<mark>") && result.contains("</mark>"));
322+
assertFalse("Should not contain default highlighting tags", result.contains("<em>") || result.contains("</em>"));
323+
324+
// Test with different custom tags
325+
result = highlighterEngine.getHighlightedSentences(MODEL_ID, TEST_QUERY, TEST_CONTENT, "<span class='highlight'>", "</span>");
326+
assertNotNull("Should return highlighted text", result);
327+
assertTrue(
328+
"Should contain new custom highlighting tags",
329+
result.contains("<span class='highlight'>") && result.contains("</span>")
330+
);
331+
assertFalse("Should not contain previous highlighting tags", result.contains("<mark>") || result.contains("</mark>"));
332+
}
312333
}

src/test/java/org/opensearch/neuralsearch/highlight/SemanticHighlighterIT.java

Lines changed: 48 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,10 @@ public class SemanticHighlighterIT extends BaseNeuralSearchIT {
3333
private static final String TEST_CONTENT = "Machine learning is a field of artificial intelligence that uses statistical techniques. "
3434
+ "Natural language processing is a branch of artificial intelligence that helps computers understand human language. "
3535
+ "Deep learning is a subset of machine learning that uses neural networks with many layers.";
36+
private static final String DEFAULT_PRE_TAG = "<em>";
37+
private static final String DEFAULT_POST_TAG = "</em>";
38+
private static final String CUSTOM_PRE_TAG = "<test>";
39+
private static final String CUSTOM_POST_TAG = "</test>";
3640
private final float[] testVector = createRandomVector(TEST_DIMENSION);
3741

3842
@Before
@@ -73,7 +77,7 @@ private void initializeTestIndex() {
7377
* 5. Neural Query
7478
* 6. Hybrid Query
7579
*/
76-
public void testQueriesWithSemanticHighlighter() throws Exception {
80+
public void testQueriesWithSemanticHighlighter() {
7781
// Set up models for the test
7882
String textEmbeddingModelId = prepareModel();
7983
String sentenceHighlightingModelId = prepareSentenceHighlightingModel();
@@ -222,7 +226,9 @@ public void testQueriesWithSemanticHighlighter() throws Exception {
222226
* },
223227
* "options": {
224228
* "model_id": "sentence-highlighting-model-id"
225-
* }
229+
* },
230+
* "pre_tags": ["<test>"],
231+
* "post_tags": ["</test>"]
226232
* }
227233
* }
228234
*/
@@ -232,6 +238,8 @@ public void testQueriesWithSemanticHighlighter() throws Exception {
232238
.modelId(textEmbeddingModelId)
233239
.k(1)
234240
.build();
241+
242+
// First test with default highlighting tags
235243
searchResponse = searchWithSemanticHighlighter(
236244
TEST_BASIC_INDEX_NAME,
237245
neuralQueryBuilder,
@@ -241,6 +249,40 @@ public void testQueriesWithSemanticHighlighter() throws Exception {
241249
);
242250
verifyHighlightResults(searchResponse, TEST_QUERY_TEXT);
243251

252+
// Then test with custom highlighting tags
253+
Map<String, Map<String, Object>> customHighlightFields = Map.of(TEST_TEXT_FIELD_NAME, Map.of("type", "semantic"));
254+
Map<String, Object> customHighlightOptions = Map.of("model_id", sentenceHighlightingModelId);
255+
256+
searchResponse = searchWithHighlight(
257+
TEST_BASIC_INDEX_NAME,
258+
neuralQueryBuilder,
259+
10,
260+
customHighlightFields,
261+
customHighlightOptions,
262+
new String[] { CUSTOM_PRE_TAG },
263+
new String[] { CUSTOM_POST_TAG }
264+
);
265+
266+
// Verify results with custom tags
267+
Map<String, Object> customTagsHit = getFirstInnerHit(searchResponse);
268+
assertNotNull("Search response should contain hits", customTagsHit);
269+
270+
@SuppressWarnings("unchecked")
271+
Map<String, Object> customHighlight = (Map<String, Object>) customTagsHit.get("highlight");
272+
assertNotNull("Hit should contain highlight section", customHighlight);
273+
274+
@SuppressWarnings("unchecked")
275+
List<String> customHighlightedFields = (List<String>) customHighlight.get(TEST_TEXT_FIELD_NAME);
276+
assertNotNull("Highlight should contain the requested field", customHighlightedFields);
277+
assertFalse("Highlighted fields should not be empty", customHighlightedFields.isEmpty());
278+
279+
String customHighlightedText = customHighlightedFields.getFirst();
280+
assertTrue("Text should contain custom opening tag", customHighlightedText.contains(CUSTOM_PRE_TAG));
281+
assertTrue("Text should contain custom closing tag", customHighlightedText.contains(CUSTOM_POST_TAG));
282+
assertFalse("Text should not contain default opening tag", customHighlightedText.contains(DEFAULT_PRE_TAG));
283+
assertFalse("Text should not contain default closing tag", customHighlightedText.contains(DEFAULT_POST_TAG));
284+
verifyHighlightResults(searchResponse, TEST_QUERY_TEXT);
285+
244286
// 6. Test Hybrid Query
245287
/*
246288
* example:
@@ -323,10 +365,9 @@ private void verifyHighlightResults(Map<String, Object> searchResponse, String e
323365

324366
assertTrue("Highlighted text should contain semantically relevant content for query: " + expectedContent, hasRelevantContent);
325367

326-
// Verify the highlight tags are present
327-
assertTrue(
328-
"Highlighted text should contain proper highlight tags",
329-
highlightedText.contains("<em>") && highlightedText.contains("</em>")
330-
);
368+
// Verify the highlight tags are present - either default tags or custom tags
369+
boolean hasDefaultTags = highlightedText.contains(DEFAULT_PRE_TAG) && highlightedText.contains(DEFAULT_POST_TAG);
370+
boolean hasCustomTags = highlightedText.contains(CUSTOM_PRE_TAG) && highlightedText.contains(CUSTOM_POST_TAG);
371+
assertTrue("Highlighted text should contain either default or custom highlight tags", hasDefaultTags || hasCustomTags);
331372
}
332373
}

0 commit comments

Comments
 (0)