From 9f0c7cdb8f78ac372fd29bfecd2c32f06fdf16d3 Mon Sep 17 00:00:00 2001
From: Ivan Stepanov
Date: Thu, 6 Apr 2023 17:05:38 +0300
Subject: [PATCH 1/2] Always sort logits before nucleus sampling

---
 llama.cpp | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 581a8399d0229..d8c8f474c9731 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1236,19 +1236,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
         }
     }
 
-    if (top_k > 0 && top_k < n_logits) {
-        sample_top_k(logits_id, top_k);
-    }
-
-    float maxl = -std::numeric_limits<float>::infinity();
-    for (const auto & kv : logits_id) {
-        maxl = Max(maxl, kv.first);
-    }
+    sample_top_k(logits_id, top_k > 0 ? std::min(top_k, n_logits) : n_logits);
 
     // compute probs for the top k tokens
     std::vector<float> probs;
     probs.reserve(logits_id.size());
 
+    float maxl = logits_id[0].first;
     double sum = 0.0;
     for (const auto & kv : logits_id) {
         const float p = expf(kv.first - maxl);
@@ -1272,15 +1266,14 @@ static llama_vocab::id llama_sample_top_p_top_k(
             }
         }
 
-        cumsum = 1.0/cumsum;
         for (int i = 0; i < (int) probs.size(); i++) {
-            probs[i] *= cumsum;
+            probs[i] /= cumsum;
         }
     }
 
     //printf("\n");
     //for (int i = 0; i < (int) 10; i++) {
-    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+    //    printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
     //}
     //printf("\n\n");
     //exit(0);

From 2ceeccf8dcb2632d8cd3a87fb4515ea7ef71b421 Mon Sep 17 00:00:00 2001
From: Ivan Stepanov
Date: Thu, 6 Apr 2023 18:03:37 +0300
Subject: [PATCH 2/2] remove second normalization

- fix windows build
- remove normalization since std::discrete_distribution does not require it

---
 llama.cpp | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index d8c8f474c9731..978327a5b50d1 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1236,7 +1236,7 @@ static llama_vocab::id llama_sample_top_p_top_k(
         }
     }
 
-    sample_top_k(logits_id, top_k > 0 ? std::min(top_k, n_logits) : n_logits);
+    sample_top_k(logits_id, top_k > 0 ? Min(top_k, n_logits) : n_logits);
 
     // compute probs for the top k tokens
     std::vector<float> probs;
@@ -1265,10 +1265,6 @@ static llama_vocab::id llama_sample_top_p_top_k(
             break;
         }
     }
-
-        for (int i = 0; i < (int) probs.size(); i++) {
-            probs[i] /= cumsum;
-        }
     }
 
     //printf("\n");
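
Note (not part of the patches above): the second commit removes the explicit normalization on the grounds that std::discrete_distribution does not need pre-normalized weights; it scales them internally. Below is a minimal standalone sketch of that behaviour, with arbitrary example weights:

// Sketch: std::discrete_distribution accepts unnormalized weights and
// normalizes them itself, so dividing probs by cumsum beforehand is redundant.
#include <cstdio>
#include <random>
#include <vector>

int main() {
    // Arbitrary unnormalized weights (they sum to 10, not 1).
    std::vector<double> weights = {1.0, 2.0, 3.0, 4.0};

    std::discrete_distribution<int> dist(weights.begin(), weights.end());

    // probabilities() exposes the normalized weights the distribution samples from.
    for (double p : dist.probabilities()) {
        printf("%.2f ", p);  // prints 0.10 0.20 0.30 0.40
    }
    printf("\n");

    std::mt19937 rng(1234);
    printf("sampled index: %d\n", dist(rng));  // index drawn proportionally to its weight
    return 0;
}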