@@ -167,9 +167,15 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
167
167
}
168
168
169
169
void llm_graph_input_cls::set_input (const llama_ubatch * ubatch) {
170
- if (cparams.embeddings && (
171
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
172
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
170
+ if (!cparams.embeddings ) {
171
+ return ;
172
+ }
173
+
174
+ const bool is_last_tok = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
175
+ arch == LLM_ARCH_QWEN3; // qwen3 reranking & embedding models use last token
176
+
177
+ if (is_last_tok) {
178
+ // set output to the last token of each sequence
173
179
const int64_t n_tokens = ubatch->n_tokens ;
174
180
const int64_t n_seq_tokens = ubatch->n_seq_tokens ;
175
181
const int64_t n_seqs = ubatch->n_seqs ;
@@ -180,23 +186,33 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
180
186
uint32_t * data = (uint32_t *) cls->data ;
181
187
memset (cls->data , 0 , n_tokens * ggml_element_size (cls));
182
188
189
+ std::vector<int > last_pos (n_tokens, -1 );
190
+ std::vector<int > last_row (n_tokens, -1 );
191
+
183
192
for (int s = 0 ; s < n_seqs; ++s) {
184
193
const llama_seq_id seq_id = ubatch->seq_id [s][0 ];
185
194
186
195
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
187
- GGML_ASSERT (seq_id < n_tokens && " seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK " );
196
+ GGML_ASSERT (seq_id < n_tokens && " seq_id cannot be larger than n_tokens with pooling_type == LAST " );
188
197
189
198
for (int i = 0 ; i < n_seq_tokens; ++i) {
190
199
const llama_pos pos = ubatch->pos [s*n_seq_tokens + i];
191
200
192
- if (pos == 0 ) {
193
- data[seq_id] = s*n_seq_tokens + i;
201
+ if (pos >= last_pos[seq_id]) {
202
+ last_pos[seq_id] = pos;
203
+ last_row[seq_id] = s*n_seq_tokens + i;
194
204
}
195
205
}
196
206
}
197
- }
198
207
199
- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
208
+ for (int i = 0 ; i < n_tokens; ++i) {
209
+ if (last_row[i] >= 0 ) {
210
+ data[i] = last_row[i];
211
+ }
212
+ }
213
+
214
+ } else {
215
+ // set output to first token of each sequence
200
216
const int64_t n_tokens = ubatch->n_tokens ;
201
217
const int64_t n_seq_tokens = ubatch->n_seq_tokens ;
202
218
const int64_t n_seqs = ubatch->n_seqs ;
@@ -207,30 +223,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
207
223
uint32_t * data = (uint32_t *) cls->data ;
208
224
memset (cls->data , 0 , n_tokens * ggml_element_size (cls));
209
225
210
- std::vector<int > last_pos (n_tokens, -1 );
211
- std::vector<int > last_row (n_tokens, -1 );
212
-
213
226
for (int s = 0 ; s < n_seqs; ++s) {
214
227
const llama_seq_id seq_id = ubatch->seq_id [s][0 ];
215
228
216
229
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
217
- GGML_ASSERT (seq_id < n_tokens && " seq_id cannot be larger than n_tokens with pooling_type == LAST " );
230
+ GGML_ASSERT (seq_id < n_tokens && " seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK " );
218
231
219
232
for (int i = 0 ; i < n_seq_tokens; ++i) {
220
233
const llama_pos pos = ubatch->pos [s*n_seq_tokens + i];
221
234
222
- if (pos >= last_pos[seq_id]) {
223
- last_pos[seq_id] = pos;
224
- last_row[seq_id] = s*n_seq_tokens + i;
235
+ if (pos == 0 ) {
236
+ data[seq_id] = s*n_seq_tokens + i;
225
237
}
226
238
}
227
239
}
228
-
229
- for (int i = 0 ; i < n_tokens; ++i) {
230
- if (last_row[i] >= 0 ) {
231
- data[i] = last_row[i];
232
- }
233
- }
234
240
}
235
241
}
236
242
@@ -943,7 +949,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
943
949
}
944
950
945
951
ggml_tensor * llm_graph_context::build_inp_cls () const {
946
- auto inp = std::make_unique<llm_graph_input_cls>(cparams);
952
+ auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch );
947
953
948
954
auto & cur = inp->cls ;
949
955
0 commit comments