Skip to content

Update highlighting model translator to adapt new model #3699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,31 @@ public final class QAConstants {
// Context keys
public static final String KEY_SENTENCES = "sentences";

// Sentence highlighting model predict chunk input key
// The chunk number is carried as a string because it is passed via Input.add(key, value).
public static final String HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY = "chunk";
public static final String HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING = "0";

// Model input names (tokenizer outputs consumed by the model, plus sentence ids)
public static final String INPUT_IDS = "input_ids";
public static final String ATTENTION_MASK = "attention_mask";
public static final String TOKEN_TYPE_IDS = "token_type_ids";
public static final String SENTENCE_IDS = "sentence_ids";

// Default values for warm-up predictions
public static final String DEFAULT_WARMUP_QUESTION = "How is the weather?";
public static final String DEFAULT_WARMUP_CONTEXT = "The weather is nice, it is beautiful day. The sun is shining. The sky is blue.";

// Default model configuration
// *_KEY constants name entries in the model configuration; DEFAULT_* values apply when the key is absent.
public static final String TOKEN_MAX_LENGTH_KEY = "token_max_length";
public static final Integer DEFAULT_TOKEN_MAX_LENGTH = 512;
public static final String TOKEN_OVERLAP_STRIDE_LENGTH_KEY = "token_overlap_stride";
public static final Integer DEFAULT_TOKEN_OVERLAP_STRIDE_LENGTH = 128;
public static final String WITH_OVERFLOWING_TOKENS_KEY = "with_overflowing_tokens";
public static final Boolean DEFAULT_WITH_OVERFLOWING_TOKENS = true;
public static final String PADDING_KEY = "padding";
public static final Boolean DEFAULT_PADDING = false;
// File name of the HuggingFace tokenizer definition bundled with the model.
public static final String TOKENIZER_FILE_NAME = "tokenizer.json";
// Special token value used to ignore tokens in sentence ID mapping
public static final int IGNORE_TOKEN_ID = -100;
// Default index where the context is assumed to start within the tokenized sequence.
public static final int CONTEXT_START_DEFAULT_INDEX = 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,41 @@

package org.opensearch.ml.engine.algorithms.question_answering;

import static org.opensearch.ml.engine.ModelHelper.PYTORCH_ENGINE;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.DEFAULT_WARMUP_CONTEXT;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.DEFAULT_WARMUP_QUESTION;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_HIGHLIGHTS;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_POSITION;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.SENTENCE_HIGHLIGHTING_TYPE;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.DLModel;
import org.opensearch.ml.engine.annotation.Function;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;

/**
Expand All @@ -36,7 +48,12 @@
*/
@Log4j2
@Function(FunctionName.QUESTION_ANSWERING)
@Builder(toBuilder = true)
@AllArgsConstructor
@NoArgsConstructor
public class QuestionAnsweringModel extends DLModel {
// Model configuration captured during warmUp(); required to detect sentence highlighting models.
private MLModelConfig modelConfig;
// Translator cached by getTranslator(). NOTE(review): lazily initialized without
// synchronization — confirm initialization is single-threaded.
private Translator<Input, Output> translator;

@Override
public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
Expand All @@ -47,50 +64,193 @@ public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfi
throw new IllegalArgumentException("model id is null");
}

// Initialize model config from model if it exists, the model config field is required for sentence highlighting model.
if (modelConfig != null) {
this.modelConfig = modelConfig;
}

// Create input for the predictor
Input input = new Input();
input.add(DEFAULT_WARMUP_QUESTION);
input.add(DEFAULT_WARMUP_CONTEXT);

if (isSentenceHighlightingModel()) {
input.add(MLInput.QUESTION_FIELD, DEFAULT_WARMUP_QUESTION);
input.add(MLInput.CONTEXT_FIELD, DEFAULT_WARMUP_CONTEXT);
// Add initial chunk number key value pair which is needed for sentence highlighting model
input.add(HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY, HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING);
} else {
input.add(DEFAULT_WARMUP_QUESTION);
input.add(DEFAULT_WARMUP_CONTEXT);
}

// Run prediction to warm up the model
predictor.predict(input);
}

/**
 * Determines whether the supplied configuration marks a model as a
 * sentence highlighting model.
 *
 * @param modelConfig the model configuration to inspect; may be null
 * @return true when the config's model type matches the sentence highlighting
 *         type (case-insensitive), false otherwise (including null config)
 */
private boolean isSentenceHighlightingType(MLModelConfig modelConfig) {
    return modelConfig != null && SENTENCE_HIGHLIGHTING_TYPE.equalsIgnoreCase(modelConfig.getModelType());
}

/**
 * Routes a question answering request to the appropriate prediction path.
 *
 * @param modelId id of the model handling the request
 * @param mlInput input whose dataset must be a QuestionAnsweringInputDataSet
 * @return tensor output from either the sentence highlighting or the standard QA path
 * @throws TranslateException if prediction fails
 */
@Override
public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
    MLInputDataset inputDataSet = mlInput.getInputDataset();
    QuestionAnsweringInputDataSet qaInputDataSet = (QuestionAnsweringInputDataSet) inputDataSet;
    String question = qaInputDataSet.getQuestion();
    String context = qaInputDataSet.getContext();

    // Sentence highlighting models use a chunked, multi-pass flow; everything else
    // is a single-pass standard QA prediction.
    if (isSentenceHighlightingModel()) {
        return predictSentenceHighlightingQA(question, context);
    }
    return predictStandardQA(question, context);
}

// True when this instance was configured (via warmUp) as a sentence highlighting model.
private boolean isSentenceHighlightingModel() {
    if (modelConfig == null) {
        return false;
    }
    return SENTENCE_HIGHLIGHTING_TYPE.equalsIgnoreCase(modelConfig.getModelType());
}

/**
 * Runs a standard (non-highlighting) question answering prediction.
 *
 * @param question the question text
 * @param context  the context text to answer from
 * @return tensor output produced by the underlying predictor
 * @throws TranslateException if prediction or output parsing fails; the original
 *                            exception is preserved as the cause
 */
private ModelTensorOutput predictStandardQA(String question, String context) throws TranslateException {
    // Standard QA models take positional inputs: question first, then context.
    Input input = new Input();
    input.add(question);
    input.add(context);

    try {
        Output output = getPredictor().predict(input);
        ModelTensors tensors = parseModelTensorOutput(output, null);
        return new ModelTensorOutput(List.of(tensors));
    } catch (Exception e) {
        log.error("Error processing standard QA model prediction", e);
        throw new TranslateException("Failed to process standard QA model prediction", e);
    }
}

/**
 * Runs sentence highlighting question answering across all tokenizer chunks.
 *
 * The initial chunk is processed first; the tokenizer's overflow encodings then
 * determine how many additional chunks follow. Highlights from every chunk are
 * merged into a single output.
 *
 * @param question the question text
 * @param context  the context text to highlight
 * @return combined highlight output across all chunks
 * @throws TranslateException if any chunk fails; the original exception is kept as cause
 */
private ModelTensorOutput predictSentenceHighlightingQA(String question, String context) throws TranslateException {
    SentenceHighlightingQATranslator highlightTranslator =
        (SentenceHighlightingQATranslator) getTranslator(PYTORCH_ENGINE, this.modelConfig);

    try {
        List<Map<String, Object>> collectedHighlights = new ArrayList<>();

        // The initial chunk must be processed first to get the overflow encodings.
        processChunk(question, context, HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING, collectedHighlights);

        // Tokenize again to discover how many overflow chunks exist.
        Encoding fullEncoding = highlightTranslator.getTokenizer().encode(question, context);
        Encoding[] overflow = fullEncoding.getOverflowing();

        // Process each overflow chunk (chunk numbers start at 1).
        if (overflow != null) {
            for (int chunkIndex = 0; chunkIndex < overflow.length; chunkIndex++) {
                processChunk(question, context, String.valueOf(chunkIndex + 1), collectedHighlights);
            }
        }

        return createHighlightOutput(collectedHighlights);
    } catch (Exception e) {
        log.error("Error processing sentence highlighting model prediction", e);
        throw new TranslateException("Failed to process chunks for sentence highlighting", e);
    }
}

/**
 * Predicts highlights for one chunk and appends them to the accumulator list.
 *
 * @param question      the question text
 * @param context       the context text
 * @param chunkNumber   chunk index to process, encoded as a string
 * @param allHighlights accumulator that receives this chunk's highlights
 * @throws TranslateException if prediction or output parsing fails
 */
private void processChunk(String question, String context, String chunkNumber, List<Map<String, Object>> allHighlights)
    throws TranslateException {
    Input chunkInput = new Input();
    chunkInput.add(MLInput.QUESTION_FIELD, question);
    chunkInput.add(MLInput.CONTEXT_FIELD, context);
    chunkInput.add(HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY, chunkNumber);

    // batchPredict returns the complete result set; plain predict only returns the
    // first result, which can lose relevant highlights.
    List<Output> chunkOutputs = getPredictor().batchPredict(List.of(chunkInput));
    for (Output chunkOutput : chunkOutputs) {
        ModelTensors parsed = parseModelTensorOutput(chunkOutput, null);
        allHighlights.addAll(extractHighlights(parsed));
    }
}

/**
 * Extract highlights from model tensors output.
 *
 * Scans every tensor's data map and collects the value stored under
 * FIELD_HIGHLIGHTS, if present.
 *
 * @param tensors The model tensors to extract highlights from
 * @return List of highlight data maps (empty if no tensor carries highlights)
 * @throws TranslateException if the highlights value is not in the expected format
 */
private List<Map<String, Object>> extractHighlights(ModelTensors tensors) throws TranslateException {
List<Map<String, Object>> highlights = new ArrayList<>();

for (ModelTensor tensor : tensors.getMlModelTensors()) {
Map<String, ?> dataAsMap = tensor.getDataAsMap();
if (dataAsMap != null && dataAsMap.containsKey(FIELD_HIGHLIGHTS)) {
try {
// NOTE(review): due to type erasure this unchecked cast cannot itself throw
// ClassCastException for wrong element types — a CCE would only surface later
// when elements are accessed; the catch below guards the List-level cast only.
List<Map<String, Object>> tensorHighlights = (List<Map<String, Object>>) dataAsMap.get(FIELD_HIGHLIGHTS);
highlights.addAll(tensorHighlights);
} catch (ClassCastException e) {
log.error("Failed to cast highlights data to expected format", e);
throw new TranslateException("Failed to cast highlights data to expected format", e);
}
}
}

return highlights;
}

/**
 * Wraps deduplicated, position-sorted highlights into a single tensor output.
 *
 * @param highlights collected highlight maps from all processed chunks
 * @return a ModelTensorOutput holding one tensor whose data map contains the highlights
 */
private ModelTensorOutput createHighlightOutput(List<Map<String, Object>> highlights) {
    // Drop duplicate positions and order the remaining highlights by position.
    List<Map<String, Object>> uniqueSortedHighlights = removeDuplicatesAndSort(highlights);

    Map<String, Object> payload = new HashMap<>();
    payload.put(FIELD_HIGHLIGHTS, uniqueSortedHighlights);

    ModelTensor highlightsTensor = ModelTensor.builder().name(FIELD_HIGHLIGHTS).dataAsMap(payload).build();
    return new ModelTensorOutput(List.of(new ModelTensors(List.of(highlightsTensor))));
}

/**
 * Deduplicates highlights by their position value and orders them ascending by position.
 *
 * The first highlight seen for a given position wins; later duplicates are ignored.
 *
 * @param highlights raw highlight maps, possibly containing repeated positions
 * @return a new list with one highlight per position, sorted by position
 */
private List<Map<String, Object>> removeDuplicatesAndSort(List<Map<String, Object>> highlights) {
    // First occurrence wins: keep the earliest highlight seen for each position key.
    Map<Number, Map<String, Object>> byPosition = new HashMap<>();
    for (Map<String, Object> candidate : highlights) {
        byPosition.putIfAbsent((Number) candidate.get(FIELD_POSITION), candidate);
    }

    // Sort the surviving highlights ascending by numeric position.
    List<Map<String, Object>> ordered = new ArrayList<>(byPosition.values());
    ordered.sort((left, right) -> {
        double leftPos = ((Number) left.get(FIELD_POSITION)).doubleValue();
        double rightPos = ((Number) right.get(FIELD_POSITION)).doubleValue();
        return Double.compare(leftPos, rightPos);
    });
    return ordered;
}

@Override
public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException {
if (isSentenceHighlightingType(modelConfig)) {
return SentenceHighlightingQATranslator.createDefault();
if (translator == null) {
if (modelConfig != null && SENTENCE_HIGHLIGHTING_TYPE.equalsIgnoreCase(modelConfig.getModelType())) {
translator = SentenceHighlightingQATranslator.create(modelConfig);
} else {
translator = new QuestionAnsweringTranslator();
}
}
return new QuestionAnsweringTranslator();
return translator;
}

@Override
Expand Down
Loading
Loading