
Commit c02f53d

add rerank prompt
1 parent 8edd2cf commit c02f53d

6 files changed, +52 additions and −20 deletions


common/common.cpp

Lines changed: 5 additions & 1 deletion
@@ -907,8 +907,12 @@ struct common_init_result common_init_from_params(common_params & params) {
 
         bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
         bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
+        bool has_rerank_prompt = llama_model_chat_template(model, "rerank_prefix") != NULL ||
+                                 llama_model_chat_template(model, "rerank_suffix") != NULL;
 
-        if (!has_eos && !has_sep) {
+        if (has_rerank_prompt) {
+            // OK, do nothing
+        } else if (!has_eos && !has_sep) {
             LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
             ok = false;
         } else if (!has_eos) {
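
For illustration, here is a minimal sketch (a hypothetical helper, not part of this commit) of how the new detection behaves at load time: a model that ships rerank_prefix/rerank_suffix templates passes the reranking check even if its vocab has neither an EOS nor a SEP token.

// Sketch only: mirrors the check added above; assumes a loaded `model`.
static bool reranking_supported(const llama_model * model) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const bool has_rerank_prompt =
        llama_model_chat_template(model, "rerank_prefix") != NULL ||
        llama_model_chat_template(model, "rerank_suffix") != NULL;

    // Prompt-based rerankers (e.g. Qwen3-Reranker) need no EOS/SEP token.
    if (has_rerank_prompt) {
        return true;
    }
    // Token-based rerankers still need at least one of EOS / SEP.
    return llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL ||
           llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
}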

convert_hf_to_gguf.py

Lines changed: 8 additions & 2 deletions
@@ -3073,6 +3073,7 @@ class Qwen3Model(Qwen2Model):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # a bit hacky, but currently the only way to detect if this is a rerank model
+        # ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
         readme_path = self.dir_model / "README.md"
         readme_text = ""
         if readme_path.exists():
@@ -3086,7 +3087,6 @@ def _find_rerank_config(self):
         tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
         self.token_false_id = tokenizer.convert_tokens_to_ids("no")
         self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
-        self.sep_token_id = tokenizer.convert_tokens_to_ids("\\n")  # unused, but needed for rerank check
         self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
         logger.info(f"gguf: token_false_id = {self.token_false_id}, token_true_id = {self.token_true_id}")
         logger.info(f"gguf: sep_token_id = {self.sep_token_id}")
@@ -3097,8 +3097,14 @@ def set_gguf_parameters(self):
         is_rerank = self.token_false_id is not None and self.token_true_id is not None
         if is_rerank:
             self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
-            self.gguf_writer.add_sep_token_id(self.sep_token_id)
             self.gguf_writer.add_classifier_output_labels(["yes", "no"])
+            self.gguf_writer.add_chat_template([{
+                "name": "rerank_prefix",
+                "template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n",
+            }, {
+                "name": "rerank_suffix",
+                "template": "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
+            }])
 
     def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
         # extract "yes" and "no" tokens from the output lm_head tensor
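
Assuming the gguf-py writer stores each named entry under its own metadata key (an assumption, though it matches the lookup added in src/llama-arch.cpp and src/llama-model.cpp below), the converted Qwen3-Reranker GGUF would carry keys along these lines (prefix value truncated here for readability):

tokenizer.chat_template.rerank_prefix = "<|im_start|>system\nJudge whether the Document meets the requirements ... <|im_end|>\n<|im_start|>user\n"
tokenizer.chat_template.rerank_suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"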

src/llama-arch.cpp

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_HF_JSON,          "tokenizer.huggingface.json" },
     { LLM_KV_TOKENIZER_RWKV,             "tokenizer.rwkv.world" },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE,    "tokenizer.chat_template" },
-    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,  "tokenizer.chat_template.%s" },
+    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,  "tokenizer.chat_template." }, // FIXME: cannot add %s because it will be replaced by arch name
     { LLM_KV_TOKENIZER_FIM_PRE_ID,       "tokenizer.ggml.fim_pre_token_id" },
     { LLM_KV_TOKENIZER_FIM_SUF_ID,       "tokenizer.ggml.fim_suf_token_id" },
     { LLM_KV_TOKENIZER_FIM_MID_ID,       "tokenizer.ggml.fim_mid_token_id" },

src/llama-model.cpp

Lines changed: 2 additions & 1 deletion
@@ -13793,7 +13793,8 @@ uint64_t llama_model_size(const llama_model * model) {
 }
 
 const char * llama_model_chat_template(const llama_model * model, const char * name) {
-    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+    const auto key = name
+        ? LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + std::string(name)
         : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
     const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
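
With this change, the lookup key is built by appending the requested name to the plain "tokenizer.chat_template." prefix; the FIXME in src/llama-arch.cpp exists because LLM_KV substitutes the arch name into any %s in the key string. A usage sketch under that assumption, given a model converted with the templates above:

// Sketch only: retrieving the named rerank templates from a loaded `model`.
const char * prefix = llama_model_chat_template(model, "rerank_prefix"); // key: tokenizer.chat_template.rerank_prefix
const char * suffix = llama_model_chat_template(model, "rerank_suffix"); // key: tokenizer.chat_template.rerank_suffix
const char * chat   = llama_model_chat_template(model, /* name */ NULL); // key: tokenizer.chat_template (default)
// Each call returns NULL when the corresponding key is absent from the GGUF metadata.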

tools/server/server.cpp

Lines changed: 1 addition & 1 deletion
@@ -4715,7 +4715,7 @@ int main(int argc, char ** argv) {
         auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
         tasks.reserve(tokenized_docs.size());
         for (size_t i = 0; i < tokenized_docs.size(); i++) {
-            auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
+            auto tmp = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
             server_task task = server_task(SERVER_TASK_TYPE_RERANK);
             task.id = ctx_server.queue_tasks.get_new_id();
             task.index = i;

tools/server/utils.hpp

Lines changed: 35 additions & 14 deletions
@@ -260,23 +260,44 @@ static size_t validate_utf8(const std::string& text) {
 // template utils
 //
 
-// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
-static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
+// format rerank task:
+// - using SEP token: [BOS]query[EOS][SEP]doc[EOS]
+// - using prompt: <rerank_prefix>query<rerank_suffix>doc
+static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
+    const llama_vocab * vocab = llama_model_get_vocab(model);
     llama_tokens result;
 
-    // Get EOS token - use SEP token as fallback if EOS is not available
-    llama_token eos_token = llama_vocab_eos(vocab);
-    if (eos_token == LLAMA_TOKEN_NULL) {
-        eos_token = llama_vocab_sep(vocab);
-    }
+    if (llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL) {
+        // Get EOS token - use SEP token as fallback if EOS is not available
+        llama_token eos_token = llama_vocab_eos(vocab);
+        if (eos_token == LLAMA_TOKEN_NULL) {
+            eos_token = llama_vocab_sep(vocab);
+        }
+
+        result.reserve(doc.size() + query.size() + 4);
+        result.push_back(llama_vocab_bos(vocab));
+        result.insert(result.end(), query.begin(), query.end());
+        result.push_back(eos_token);
+        result.push_back(llama_vocab_sep(vocab));
+        result.insert(result.end(), doc.begin(), doc.end());
+        result.push_back(eos_token);
+    } else {
+        // using prompt template
+        const char * prefix = llama_model_chat_template(model, "rerank_prefix");
+        const char * suffix = llama_model_chat_template(model, "rerank_suffix");
+
+        if (prefix == NULL && suffix == NULL) {
+            throw std::runtime_error("Rerank prompt template not found in the model\n");
+        }
 
-    result.reserve(doc.size() + query.size() + 4);
-    result.push_back(llama_vocab_bos(vocab));
-    result.insert(result.end(), query.begin(), query.end());
-    result.push_back(eos_token);
-    result.push_back(llama_vocab_sep(vocab));
-    result.insert(result.end(), doc.begin(), doc.end());
-    result.push_back(eos_token);
+        const llama_tokens prefix_tokens = prefix ? common_tokenize(vocab, prefix, true, false) : llama_tokens();
+        const llama_tokens suffix_tokens = suffix ? common_tokenize(vocab, suffix, false, false) : llama_tokens();
+        result.reserve(prefix_tokens.size() + query.size() + suffix_tokens.size() + doc.size());
+        result.insert(result.end(), prefix_tokens.begin(), prefix_tokens.end());
+        result.insert(result.end(), query.begin(), query.end());
+        result.insert(result.end(), suffix_tokens.begin(), suffix_tokens.end());
+        result.insert(result.end(), doc.begin(), doc.end());
+    }
 
     return result;
 }
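
To make the prompt-template path concrete, here is a rough sketch of the text that format_rerank assembles for a Qwen3-Reranker model. It does plain string concatenation rather than the token-level code above, and the query/document strings are made-up examples:

// Sketch only: shows the <rerank_prefix>query<rerank_suffix>doc layout produced above.
#include <iostream>
#include <string>

int main() {
    // Prefix/suffix as written by convert_hf_to_gguf.py for Qwen3-Reranker.
    const std::string prefix =
        "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query "
        "and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
        "<|im_start|>user\n";
    const std::string suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";

    // Hypothetical query/document pair.
    const std::string query = "what is a panda?";
    const std::string doc   = "The giant panda is a bear species endemic to China.";

    // format_rerank (prompt path) concatenates: prefix, query, suffix, doc.
    std::cout << prefix << query << suffix << doc << std::endl;
    return 0;
}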
