Skip to content

Commit 02780be

Browse files
committed
Resolve PR feedback 2
Signed-off-by: Junqiu Lei <[email protected]>
1 parent 6df723b commit 02780be

File tree

7 files changed

+406
-138
lines changed

7 files changed

+406
-138
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModel.java

Lines changed: 61 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.DEFAULT_WARMUP_CONTEXT;
1010
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.DEFAULT_WARMUP_QUESTION;
1111
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_HIGHLIGHTS;
12+
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_POSITION;
1213
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY;
1314
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING;
1415
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.SENTENCE_HIGHLIGHTING_TYPE;
@@ -124,11 +125,18 @@ private ModelTensorOutput predictSentenceHighlightingQA(String question, String
124125
try {
125126
List<Map<String, Object>> allHighlights = new ArrayList<>();
126127

127-
// Process initial chunk
128-
processInitialChunk(question, context, allHighlights);
128+
// We need to process the initial chunk first to get the overflow encodings
129+
processChunk(question, context, HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING, allHighlights);
129130

130-
// Process overflow chunks if any
131-
processOverflowChunks(question, context, translator, allHighlights);
131+
Encoding encodings = translator.getTokenizer().encode(question, context);
132+
Encoding[] overflowEncodings = encodings.getOverflowing();
133+
134+
// Process overflow chunks if overflow encodings are present
135+
if (overflowEncodings != null && overflowEncodings.length > 0) {
136+
for (int i = 0; i < overflowEncodings.length; i++) {
137+
processChunk(question, context, String.valueOf(i + 1), allHighlights);
138+
}
139+
}
132140

133141
return createHighlightOutput(allHighlights);
134142
} catch (Exception e) {
@@ -137,103 +145,24 @@ private ModelTensorOutput predictSentenceHighlightingQA(String question, String
137145
}
138146
}
139147

140-
private void processInitialChunk(String question, String context, List<Map<String, Object>> allHighlights) throws TranslateException {
141-
Input initialInput = new Input();
142-
initialInput.add(MLInput.QUESTION_FIELD, question);
143-
initialInput.add(MLInput.CONTEXT_FIELD, context);
144-
initialInput.add(HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY, HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING);
145-
146-
List<Output> initialOutputs = getPredictor().batchPredict(List.of(initialInput));
147-
for (Output output : initialOutputs) {
148-
ModelTensors tensors = parseModelTensorOutput(output, null);
149-
allHighlights.addAll(extractHighlights(tensors));
150-
}
151-
}
148+
private void processChunk(String question, String context, String chunkNumber, List<Map<String, Object>> allHighlights)
149+
throws TranslateException {
150+
Input chunkInput = new Input();
151+
chunkInput.add(MLInput.QUESTION_FIELD, question);
152+
chunkInput.add(MLInput.CONTEXT_FIELD, context);
153+
chunkInput.add(HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY, chunkNumber);
152154

153-
private void processOverflowChunks(
154-
String question,
155-
String context,
156-
SentenceHighlightingQATranslator translator,
157-
List<Map<String, Object>> allHighlights
158-
) throws TranslateException {
159-
Encoding encodings = translator.getTokenizer().encode(question, context);
160-
Encoding[] overflowEncodings = encodings.getOverflowing();
155+
// Use batchPredict to process the chunk for complete results; predict only returns the first result, which can cause loss of relevant
156+
// results
157+
List<Output> outputs = getPredictor().batchPredict(List.of(chunkInput));
161158

162-
if (overflowEncodings == null || overflowEncodings.length == 0) {
159+
if (outputs.isEmpty()) {
163160
return;
164161
}
165162

166-
List<Input> overflowInputs = createOverflowInputs(question, context, overflowEncodings.length);
167-
processOverflowInputs(overflowInputs, allHighlights);
168-
}
169-
170-
private List<Input> createOverflowInputs(String question, String context, int numOverflowChunks) {
171-
List<Input> overflowInputs = new ArrayList<>();
172-
for (int i = 0; i < numOverflowChunks; i++) {
173-
Input chunkInput = new Input();
174-
chunkInput.add(MLInput.QUESTION_FIELD, question);
175-
chunkInput.add(MLInput.CONTEXT_FIELD, context);
176-
chunkInput.add(HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY, String.valueOf(i + 1));
177-
178-
overflowInputs.add(chunkInput);
179-
}
180-
return overflowInputs;
181-
}
182-
183-
private void processOverflowInputs(List<Input> overflowInputs, List<Map<String, Object>> allHighlights) throws TranslateException {
184-
try {
185-
processOverflowInputsBatch(overflowInputs, allHighlights);
186-
} catch (IllegalArgumentException e) {
187-
log.info("Batch processing of chunks failed. Processing chunks individually: {}", e.getMessage());
188-
processOverflowInputsIndividually(overflowInputs, allHighlights);
189-
}
190-
}
191-
192-
private void processOverflowInputsBatch(List<Input> overflowInputs, List<Map<String, Object>> allHighlights) throws TranslateException {
193-
List<Output> overflowOutputs = getPredictor().batchPredict(overflowInputs);
194-
for (Output output : overflowOutputs) {
195-
try {
196-
ModelTensors tensors = parseModelTensorOutput(output, null);
197-
allHighlights.addAll(extractHighlights(tensors));
198-
} catch (Exception e) {
199-
log.warn("Error processing output from chunk", e);
200-
}
201-
}
202-
}
203-
204-
private void processOverflowInputsIndividually(List<Input> overflowInputs, List<Map<String, Object>> allHighlights)
205-
throws TranslateException {
206-
for (int i = 0; i < overflowInputs.size(); i++) {
207-
processOverflowChunkWithBatch(i + 1, overflowInputs.get(i), allHighlights);
208-
}
209-
}
210-
211-
/**
212-
* Process a single overflow chunk using batchPredict and add any extracted highlights
213-
*
214-
* @param chunkIndex The index of the overflow chunk (1-based)
215-
* @param chunkInput The prepared input for this chunk
216-
* @param highlights Collection to add extracted highlights to
217-
*/
218-
private void processOverflowChunkWithBatch(int chunkIndex, Input chunkInput, List<Map<String, Object>> highlights)
219-
throws TranslateException {
220-
try {
221-
// Use batchPredict instead of predict to avoid the bug
222-
List<Output> outputs = getPredictor().batchPredict(List.of(chunkInput));
223-
if (outputs.isEmpty()) {
224-
log.warn("No output returned for chunk {}", chunkIndex);
225-
return;
226-
}
227-
228-
// Process all outputs from this chunk
229-
for (Output output : outputs) {
230-
ModelTensors tensors = parseModelTensorOutput(output, null);
231-
List<Map<String, Object>> chunkHighlights = extractHighlights(tensors);
232-
highlights.addAll(chunkHighlights);
233-
}
234-
} catch (Exception e) {
235-
log.error("Error processing overflow chunk {}", chunkIndex, e);
236-
throw new TranslateException("Failed to process overflow chunk " + chunkIndex, e);
163+
for (Output output : outputs) {
164+
ModelTensors tensors = parseModelTensorOutput(output, null);
165+
allHighlights.addAll(extractHighlights(tensors));
237166
}
238167
}
239168

@@ -270,13 +199,48 @@ private List<Map<String, Object>> extractHighlights(ModelTensors tensors) throws
270199
*/
271200
private ModelTensorOutput createHighlightOutput(List<Map<String, Object>> highlights) {
272201
Map<String, Object> combinedData = new HashMap<>();
273-
combinedData.put(FIELD_HIGHLIGHTS, highlights);
202+
203+
// Remove duplicates and sort by position
204+
List<Map<String, Object>> uniqueSortedHighlights = removeDuplicatesAndSort(highlights);
205+
206+
combinedData.put(FIELD_HIGHLIGHTS, uniqueSortedHighlights);
274207

275208
ModelTensor combinedTensor = ModelTensor.builder().name(FIELD_HIGHLIGHTS).dataAsMap(combinedData).build();
276209

277210
return new ModelTensorOutput(List.of(new ModelTensors(List.of(combinedTensor))));
278211
}
279212

213+
/**
214+
* Removes duplicate sentences and sorts them by position
215+
*
216+
* @param highlights The list of highlights to process
217+
* @return List of unique highlights sorted by position
218+
*/
219+
private List<Map<String, Object>> removeDuplicatesAndSort(List<Map<String, Object>> highlights) {
220+
// Use a map to detect duplicates by position
221+
Map<Number, Map<String, Object>> uniqueMap = new HashMap<>();
222+
223+
// Add each highlight to the map, using position as the key
224+
for (Map<String, Object> highlight : highlights) {
225+
Number position = (Number) highlight.get(FIELD_POSITION);
226+
if (!uniqueMap.containsKey(position)) {
227+
uniqueMap.put(position, highlight);
228+
}
229+
}
230+
231+
// Convert back to list
232+
List<Map<String, Object>> uniqueHighlights = new ArrayList<>(uniqueMap.values());
233+
234+
// Sort by position
235+
uniqueHighlights.sort((a, b) -> {
236+
Number posA = (Number) a.get(FIELD_POSITION);
237+
Number posB = (Number) b.get(FIELD_POSITION);
238+
return Double.compare(posA.doubleValue(), posB.doubleValue());
239+
});
240+
241+
return uniqueHighlights;
242+
}
243+
280244
@Override
281245
public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException {
282246
if (translator == null) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/SentenceHighlightingQATranslator.java

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,20 @@
6565

6666
/**
6767
* Translator for sentence highlighting question answering model.
68-
*
68+
*
69+
* This translator processes input for semantic sentence highlighting models that identify
70+
* relevant sentences within a context document based on a user query or question.
71+
*
72+
* The translator performs the following key functions:
73+
* 1. Tokenizes the question and context using Hugging Face tokenizer
74+
* 2. Segments the context into sentences
75+
* 3. Maps tokens to their corresponding sentence IDs
76+
* 4. Handles chunking for long contexts that exceed the model's maximum token length
77+
* 5. Processes model outputs to identify and highlight sentences that answer the question
78+
*
79+
* The highlighted sentences are returned with their text and position information within
80+
* the original context, which allows for easy visualization and extraction of relevant
81+
* information from the document.
6982
*/
7083
@Log4j2
7184
@Getter
@@ -166,6 +179,17 @@ public void setArguments(Map<String, ?> arguments) {
166179
// No arguments needed for this translator
167180
}
168181

182+
/**
183+
* Prepares the translator by initializing the tokenizer with the appropriate configuration.
184+
*
185+
* The tokenizer is configured to handle chunking for long contexts that exceed the model's
186+
* maximum token length. Even when processing individual chunks, the full context is always
187+
* passed to the model in the input stage, ensuring that sentence segmentation and token-to-sentence
188+
* mapping is consistent across all chunks.
189+
*
190+
* @param ctx The translator context which provides access to the model path
191+
* @throws IOException If there is an error loading the tokenizer
192+
*/
169193
@Override
170194
public NDList processInput(TranslatorContext ctx, Input input) {
171195
try {
@@ -301,6 +325,13 @@ private int[] createWordLevelSentenceIds(List<Sentence> sentences, String contex
301325
return wordSentenceIds;
302326
}
303327

328+
/**
329+
* Processes the model's output to extract highlighted sentences.
330+
*
331+
* @param ctx The translator context containing sentence information
332+
* @param list The model's output predictions
333+
* @return Formatted output with highlighted sentence details or error information
334+
*/
304335
@Override
305336
public Output processOutput(TranslatorContext ctx, NDList list) {
306337
try {

0 commit comments

Comments
 (0)