Implement no_speech_thold #2625


Merged — 2 commits, Dec 17, 2024
2 changes: 1 addition & 1 deletion include/whisper.h
@@ -534,7 +534,7 @@ extern "C" {
float temperature_inc;
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
float logprob_thold;
float no_speech_thold; // TODO: not implemented
float no_speech_thold;

struct {
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
40 changes: 37 additions & 3 deletions src/whisper.cpp
@@ -867,6 +867,7 @@ struct whisper_state {
whisper_token tid_last;

std::vector<float> energy; // PCM signal energy
float no_speech_prob = 0.0f;

// [EXPERIMENTAL] Token-level timestamps with DTW
whisper_aheads_masks aheads_masks;
@@ -5647,6 +5648,35 @@ int whisper_full_with_state(
return -8;
}

// Calculate no_speech probability after first decode
{
const float * logits = state->logits.data();
const int n_vocab = ctx->vocab.n_vocab;

// Find max element for numerical stability
float max_logit = -INFINITY;
for (int i = 0; i < n_vocab; ++i) {
max_logit = std::max(max_logit, logits[i]);
}

// Calculate softmax
float sum_exp = 0.0f;
std::vector<float> probs(n_vocab);
for (int i = 0; i < n_vocab; ++i) {
float exp_val = expf(logits[i] - max_logit);
sum_exp += exp_val;
probs[i] = exp_val;
}

// Normalize
for (int i = 0; i < n_vocab; ++i) {
probs[i] /= sum_exp;
}

// Get probability of no_speech token
state->no_speech_prob = probs[whisper_token_nosp(ctx)];
}

Member:

This likely has to be done inside whisper_process_logits in order to avoid computing the softmax again just for this probability.

Contributor Author:

Unfortunately we cannot reuse the softmax computed inside whisper_process_logits, since no_speech_prob has to be calculated before any logits filtering; otherwise we get wrong no_speech_prob values. The same approach is used in OpenAI's whisper: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L689-L703
Since this no_speech_prob calculation is done only for the first token in the sequence, it will not cause a big performance impact.
On a related note, I have now modularized the probs calculation and am reusing the same code as whisper_process_logits.
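The ordering constraint described above can be illustrated with a small standalone sketch (the function names here are illustrative, not whisper.cpp internals): suppressing a token before the softmax redistributes probability mass to the remaining tokens, so a no_speech_prob read from the filtered distribution would come out inflated.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax, mirroring the PR's probability calculation:
// subtract the max logit before exponentiating, then normalize.
std::vector<float> softmax(std::vector<float> logits) {
    float max_logit = logits[0];
    for (float l : logits) max_logit = std::max(max_logit, l);
    float sum = 0.0f;
    for (float & l : logits) { l = expf(l - max_logit); sum += l; }
    for (float & l : logits) l /= sum;
    return logits;
}

// Probability of one token, optionally after suppressing another token
// (as logit filtering does) by forcing its logit to -INFINITY.
float prob_of(std::vector<float> logits, int token, int suppressed = -1) {
    if (suppressed >= 0) logits[suppressed] = -INFINITY;
    return softmax(logits)[token];
}
```

For example, the probability of token 1 in {1.0, 2.0, 0.5} rises once token 2 is suppressed, which is exactly why no_speech_prob must be taken from the raw logits.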

{
const int64_t t_start_sample_us = ggml_time_us();

@@ -6038,7 +6068,8 @@ int whisper_full_with_state(
if (it != (int) temperatures.size() - 1) {
const auto & decoder = state->decoders[best_decoder_id];

if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
if (decoder.failed ||
(decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
Member:

  • The log message is no longer correct
  • The comparison for the speech prob is wrong

Contributor Author:

  • Thanks. I have corrected the log message to print the no_speech_prob and no_speech_thold values as well.
  • The comparison logic is on par with the OpenAI implementation: avg_logprobs being less than the threshold is considered a failure only for a speech segment. If the segment is non-speech, it is instead considered a successful prediction of silence. Here is the relevant code from OpenAI; I have merged it into one condition by inverting the comparison, but the logic is the same:
    https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L209-L220
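The merged condition can be sketched as a small standalone predicate (the parameter names are illustrative stand-ins for the decoder state, not whisper.cpp's actual fields): a low average log-probability triggers the temperature fallback only when the no-speech probability is below the threshold, i.e. only when the segment looks like speech.

```cpp
// Fallback decision, merging OpenAI's two-step check into one condition:
// a low avg_logprob counts as a failure only when the segment is speech
// (no_speech_prob below the threshold); otherwise the low-confidence text
// is accepted as a successful prediction of silence.
bool needs_fallback(bool failed, float avg_logprob, float logprob_thold,
                    float no_speech_prob, float no_speech_thold) {
    return failed ||
           (avg_logprob < logprob_thold && no_speech_prob < no_speech_thold);
}
```

With typical defaults (logprob_thold = -1.0, no_speech_thold = 0.6), a segment with avg_logprob = -1.5 falls back when no_speech_prob is low but is accepted as silence when no_speech_prob is high.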

WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
success = false;
state->n_fail_p++;
@@ -6068,6 +6099,9 @@ int whisper_full_with_state(
// [EXPERIMENTAL] Token-level timestamps with DTW
const auto n_segments_before = state->result_all.size();

const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
best_decoder.sequence.avg_logprobs < params.logprob_thold);

//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);

// update prompt_past
@@ -6076,11 +6110,11 @@
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
}

for (int i = 0; i < result_len; ++i) {
for (int i = 0; i < result_len && !is_no_speech; ++i) {
prompt_past.push_back(tokens_cur[i].id);
}

if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
