
Smooth Sampling / Quadratic Sampling support #6445


Open · wants to merge 4 commits into master
6 changes: 4 additions & 2 deletions common/sampling.cpp
@@ -146,6 +146,8 @@ static void sampler_queue(
const float temp = params.temp;
const float dynatemp_range = params.dynatemp_range;
const float dynatemp_exponent = params.dynatemp_exponent;
+ const float smoothing_factor = params.smoothing_factor;
+ const float smoothing_curve = params.smoothing_curve;
const int32_t top_k = params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
@@ -161,10 +163,10 @@ static void sampler_queue(
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
case llama_sampler_type::TEMPERATURE:
- if (dynatemp_range > 0) {
+ if (dynatemp_range > 0 || smoothing_factor > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
- llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
+ llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent, smoothing_factor, smoothing_curve);
} else {
llama_sample_temp(ctx_main, &cur_p, temp);
}
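The change above only alters which sampler the TEMPERATURE stage dispatches to: any positive smoothing_factor now routes the candidates through llama_sample_entropy, even when dynatemp_range is zero (in which case dynatemp_min and dynatemp_max both collapse to temp). A minimal sketch of that decision, with illustrative names not taken from the PR:

```cpp
#include <cstdio>

// Sketch of the dispatch added above; names are illustrative, not from the PR.
enum class temp_path { plain_temp, entropy_sampler };

static temp_path pick_temperature_path(float dynatemp_range, float smoothing_factor) {
    // A positive dynatemp range or a positive smoothing factor sends the
    // candidates through llama_sample_entropy; otherwise the plain
    // fixed-temperature scaling is kept unchanged.
    return (dynatemp_range > 0.0f || smoothing_factor > 0.0f) ? temp_path::entropy_sampler
                                                              : temp_path::plain_temp;
}

int main() {
    printf("%d\n", pick_temperature_path(0.0f, 0.25f) == temp_path::entropy_sampler); // 1
    printf("%d\n", pick_temperature_path(0.0f, 0.00f) == temp_path::entropy_sampler); // 0
}
```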
2 changes: 2 additions & 0 deletions common/sampling.h
@@ -32,6 +32,8 @@ typedef struct llama_sampling_params {
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+ float smoothing_factor = 0.0f; // controls the quadratic adjustment in smooth sampling
+ float smoothing_curve = 1.0f; // controls the quadratic adjustment in smooth sampling
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
2 changes: 2 additions & 0 deletions examples/server/server.cpp
@@ -846,6 +846,8 @@ struct server_context {
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
+ slot.sparams.smoothing_factor = json_value(data, "smoothing_factor", default_sparams.smoothing_factor);
+ slot.sparams.smoothing_curve = json_value(data, "smoothing_curve", default_sparams.smoothing_curve);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
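Since the slot parameters above are read straight from the request JSON, the new fields can be set per request. A hypothetical request body, built here with the same nlohmann::json type server.cpp uses; the keys match what server.cpp reads above, and the values are purely illustrative:

```cpp
#include <cstdio>
#include "json.hpp"  // nlohmann::json, the JSON library used by server.cpp

using json = nlohmann::json;

int main() {
    // Hypothetical completion request body with the new sampling fields.
    json body = {
        {"prompt",           "Once upon a time"},
        {"temperature",      1.0},
        {"smoothing_factor", 0.25},  // > 0 enables the quadratic transform
        {"smoothing_curve",  1.0}    // 1.0 keeps the curve purely quadratic
    };
    printf("%s\n", body.dump(2).c_str());
}
```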
19 changes: 17 additions & 2 deletions llama.cpp
@@ -13510,14 +13510,29 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
}
}

- void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
+ void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val, float smoothing_factor, float smoothing_curve) {
const int64_t t_start_sample_us = ggml_time_us();

// no need to do anything if there is only one (or zero) candidates
- if(candidates_p->size <= 1) {
+ if (candidates_p->size <= 1) {
return;
}

+ // Apply smoothing if smoothing_factor is > 0. Do not change base implementation otherwise.
+ if (smoothing_factor > 0 && candidates_p->size > 1) {
+ llama_sample_softmax(ctx, candidates_p);
+ float h = candidates_p->data[0].logit; // Find the maximum logit for h to be added after the transformation
+
+ // Apply the modified quadratic transformation using the smoothing_factor and smoothing_curve
+ for (size_t i = 0; i < candidates_p->size; ++i) {
+ float logit_shifted = candidates_p->data[i].logit - h;
+ float k = (3 - smoothing_curve) / 2;
+ float s = (smoothing_curve - 1) / 2;
+ candidates_p->data[i].logit = -(k * smoothing_factor * logit_shifted * logit_shifted) + (s * smoothing_factor * logit_shifted * logit_shifted * logit_shifted) + h;
+ }
+ llama_sample_softmax(ctx, candidates_p);
+ }
+
// Calculate maximum possible entropy
float max_entropy = -logf(1.0f / candidates_p->size);

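The added block is the whole of smooth/quadratic sampling: each logit is re-expressed as its offset x from the top logit h, mapped through -(k·f·x²) + (s·f·x³) + h with f = smoothing_factor, and the distribution is re-normalized. With smoothing_curve = 1 this gives k = 1 and s = 0, i.e. a pure downward parabola centered on the top token. A standalone sketch (toy logits and illustrative names, not the PR's llama_token_data structures) showing the effect:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Toy re-implementation of the quadratic transform above, operating on plain logits.
static void quadratic_smooth(std::vector<float> & logits, float factor, float curve) {
    if (logits.size() <= 1 || factor <= 0.0f) {
        return; // disabled, matching the smoothing_factor > 0 guard in the PR
    }
    const float h = *std::max_element(logits.begin(), logits.end()); // top logit
    const float k = (3.0f - curve) / 2.0f;
    const float s = (curve - 1.0f) / 2.0f;
    for (float & l : logits) {
        const float x = l - h; // offset from the top logit, always <= 0
        l = -(k * factor * x * x) + (s * factor * x * x * x) + h;
    }
}

static std::vector<float> softmax(std::vector<float> logits) {
    const float mx = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float & l : logits) { l = std::exp(l - mx); sum += l; }
    for (float & l : logits) { l /= sum; }
    return logits;
}

int main() {
    std::vector<float> logits = {5.0f, 4.5f, 3.0f, 0.0f};
    const auto before = softmax(logits);
    quadratic_smooth(logits, /*factor=*/0.3f, /*curve=*/1.0f); // curve = 1 -> pure quadratic
    const auto after = softmax(logits);
    for (size_t i = 0; i < logits.size(); ++i) {
        printf("token %zu: p_before = %.3f  p_after = %.3f\n", i, before[i], after[i]);
    }
}
```

With these toy numbers the two near-top tokens end up with almost equal probability while the tail token is suppressed much harder, which is the point of the factor: the penalty grows quadratically with distance from the top logit, so it barely touches strong candidates but crushes weak ones. The curve parameter mixes in a cubic term that changes how fast that penalty grows further down the tail.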
6 changes: 4 additions & 2 deletions llama.h
@@ -963,13 +963,15 @@ extern "C" {
float p,
size_t min_keep);

- /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
+ /// @details Dynamic temperature implementation + Smooth Sampling implementations wrapped into one function, no research papers available
LLAMA_API void llama_sample_entropy(
struct llama_context * ctx,
llama_token_data_array * candidates_p,
float min_temp,
float max_temp,
- float exponent_val);
+ float exponent_val,
+ float smoothing_factor,
+ float smoothing_curve);

LLAMA_API void llama_sample_temp(
struct llama_context * ctx,
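Because the public signature changes, every existing caller of llama_sample_entropy has to pass the two extra arguments. A minimal call-site sketch (ctx and cur_p are assumed to come from an existing sampling loop, not defined here); passing a non-positive smoothing_factor keeps the original dynamic-temperature-only behaviour, since the quadratic block is skipped in that case:

```cpp
// Call-site sketch only; ctx and cur_p exist in the surrounding sampling code.
llama_sample_entropy(ctx, &cur_p,
                     /*min_temp=*/        0.5f,
                     /*max_temp=*/        1.5f,
                     /*exponent_val=*/    1.0f,
                     /*smoothing_factor=*/0.0f,   // <= 0 disables the quadratic transform
                     /*smoothing_curve=*/ 1.0f);
```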