
Commit 030dc3b

correct output token position
1 parent e0eb4b8 commit 030dc3b

File tree: 3 files changed, +31 −24 lines

src/llama-graph.cpp

Lines changed: 28 additions & 22 deletions
@@ -167,9 +167,15 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
-    if (cparams.embeddings && (
-            cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
-            cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
+    if (!cparams.embeddings) {
+        return;
+    }
+
+    const bool is_last_tok = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
+                             arch == LLM_ARCH_QWEN3; // qwen3 reranking & embedding models use last token
+
+    if (is_last_tok) {
+        // set output to the last token of each sequence
         const int64_t n_tokens     = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
         const int64_t n_seqs       = ubatch->n_seqs;
@@ -180,23 +186,33 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_tokens * ggml_element_size(cls));
 
+        std::vector<int> last_pos(n_tokens, -1);
+        std::vector<int> last_row(n_tokens, -1);
+
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
             // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
                 const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
 
-                if (pos == 0) {
-                    data[seq_id] = s*n_seq_tokens + i;
+                if (pos >= last_pos[seq_id]) {
+                    last_pos[seq_id] = pos;
+                    last_row[seq_id] = s*n_seq_tokens + i;
                 }
             }
         }
-    }
 
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+        for (int i = 0; i < n_tokens; ++i) {
+            if (last_row[i] >= 0) {
+                data[i] = last_row[i];
+            }
+        }
+
+    } else {
+        // set output to first token of each sequence
         const int64_t n_tokens     = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
         const int64_t n_seqs       = ubatch->n_seqs;
@@ -207,30 +223,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_tokens * ggml_element_size(cls));
 
-        std::vector<int> last_pos(n_tokens, -1);
-        std::vector<int> last_row(n_tokens, -1);
-
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
             // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
                 const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
 
-                if (pos >= last_pos[seq_id]) {
-                    last_pos[seq_id] = pos;
-                    last_row[seq_id] = s*n_seq_tokens + i;
+                if (pos == 0) {
+                    data[seq_id] = s*n_seq_tokens + i;
                 }
             }
         }
-
-        for (int i = 0; i < n_tokens; ++i) {
-            if (last_row[i] >= 0) {
-                data[i] = last_row[i];
-            }
-        }
     }
 }
 
@@ -943,7 +949,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_cls() const {
-    auto inp = std::make_unique<llm_graph_input_cls>(cparams);
+    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
 
     auto & cur = inp->cls;
 
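The restructured set_input now returns early when embeddings are disabled, takes the last token of each sequence when the pooling type is LLAMA_POOLING_TYPE_LAST or the architecture is Qwen3, and otherwise keeps the original first-token (pos == 0) behavior for CLS and RANK. The last-token pass has to track the highest position seen per sequence, since batch rows are not guaranteed to arrive in position order. Below is a minimal, self-contained sketch of that selection logic, with the s/i loops flattened into one loop over batch rows; the toy seq_id/pos arrays stand in for the real llama_ubatch layout:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Toy batch: two sequences of three tokens each, laid out one after the other.
    // row:    0  1  2  3  4  5
    // seq_id: 0  0  0  1  1  1
    // pos:    0  1  2  0  1  2
    const int n_tokens = 6;
    const std::vector<int> seq_id = {0, 0, 0, 1, 1, 1};
    const std::vector<int> pos    = {0, 1, 2, 0, 1, 2};

    // cls[seq] will hold the batch row of the highest-position token of seq,
    // mirroring the uint32_t cls tensor filled in set_input.
    std::vector<uint32_t> cls(n_tokens, 0);
    std::vector<int> last_pos(n_tokens, -1);
    std::vector<int> last_row(n_tokens, -1);

    // First pass: remember the row holding the largest position per sequence.
    for (int i = 0; i < n_tokens; ++i) {
        if (pos[i] >= last_pos[seq_id[i]]) {
            last_pos[seq_id[i]] = pos[i];
            last_row[seq_id[i]] = i;
        }
    }

    // Second pass: publish the winning rows into the cls indices.
    for (int i = 0; i < n_tokens; ++i) {
        if (last_row[i] >= 0) {
            cls[i] = (uint32_t) last_row[i];
        }
    }

    // Sequence 0's last token is row 2, sequence 1's is row 5.
    printf("cls[0] = %u, cls[1] = %u\n", (unsigned) cls[0], (unsigned) cls[1]);
    return 0;
}
```

The pooling graph then gathers one hidden-state row per sequence through these indices, which is why a wrong entry here shows up as a wrong output token position for Qwen3 embedding and reranking models.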
src/llama-graph.h

Lines changed: 2 additions & 1 deletion
@@ -177,13 +177,14 @@ class llm_graph_input_mean : public llm_graph_input_i {
 
 class llm_graph_input_cls : public llm_graph_input_i {
 public:
-    llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
+    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : arch(arch), cparams(cparams) {}
     virtual ~llm_graph_input_cls() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * cls; // I32 [n_batch]
 
+    const llm_arch arch;
     const llama_cparams & cparams;
 };

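Note that the new arch member is initialized ahead of cparams, matching the member declaration order: class members are always initialized in declaration order regardless of the initializer list, so keeping the two in sync avoids a -Wreorder warning. The value is threaded in from build_inp_cls in llama-graph.cpp above, which is what lets set_input special-case LLM_ARCH_QWEN3 without holding a reference to the whole model.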
src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
@@ -7052,7 +7052,7 @@ struct llm_build_qwen3 : public llm_graph_context {
                     Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
         }
 
-        if (il == n_layer - 1) {
+        if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
            // skip computing output for unused tokens
            ggml_tensor * inp_out_ids = build_inp_out_ids();
            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
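This guard makes the last-layer row pruning conditional on pooling being disabled. Previously the Qwen3 graph always called ggml_get_rows with inp_out_ids to drop hidden states for tokens whose logits were not requested; with pooling enabled, the row a sequence pools over (its last token, per the change above) is not necessarily among the requested outputs, so the pruning could discard exactly the rows the pooling layer later gathers through the cls indices. Restricting it to LLAMA_POOLING_TYPE_NONE keeps every token's hidden state alive whenever a pooling layer follows.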
