Skip to content

Commit 47ab304

Browse files
iboBiThalay
authored andcommitted
whisper: use global cache for sin/cos vals and Hann window (ggml-org#2194)
- also rename Hanning to Hann as it's named after Julius von Hann as per Wikipedia
1 parent 2f9b67a commit 47ab304

File tree

1 file changed

+54
-43
lines changed

1 file changed

+54
-43
lines changed

whisper.cpp

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2857,20 +2857,44 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
28572857
}
28582858

28592859
#define SIN_COS_N_COUNT WHISPER_N_FFT
2860-
static float sin_vals[SIN_COS_N_COUNT];
2861-
static float cos_vals[SIN_COS_N_COUNT];
2860+
namespace {
2861+
struct whisper_global_cache {
2862+
// In FFT, we frequently use sine and cosine operations with the same values.
2863+
// We can use precalculated values to speed up the process.
2864+
float sin_vals[SIN_COS_N_COUNT];
2865+
float cos_vals[SIN_COS_N_COUNT];
2866+
2867+
// Hann window (Use cosf to eliminate difference)
2868+
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
2869+
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
2870+
float hann_window[WHISPER_N_FFT];
2871+
float hann_window2x[WHISPER_N_FFT * 2];
2872+
2873+
whisper_global_cache() {
2874+
fill_sin_cos_table();
2875+
#define FILL_HANN_WINDOW(arr) fill_hann_window(sizeof(arr) / sizeof(arr[0]), true, arr)
2876+
FILL_HANN_WINDOW(hann_window);
2877+
FILL_HANN_WINDOW(hann_window2x);
2878+
}
2879+
2880+
void fill_sin_cos_table() {
2881+
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
2882+
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
2883+
sin_vals[i] = sinf(theta);
2884+
cos_vals[i] = cosf(theta);
2885+
}
2886+
}
28622887

2863-
// In FFT, we frequently use sine and cosine operations with the same values.
2864-
// We can use precalculated values to speed up the process.
2865-
static void fill_sin_cos_table() {
2866-
static bool is_filled = false;
2867-
if (is_filled) return;
2868-
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
2869-
double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
2870-
sin_vals[i] = sinf(theta);
2871-
cos_vals[i] = cosf(theta);
2888+
void fill_hann_window(int length, bool periodic, float* output) {
2889+
int offset = -1;
2890+
if (periodic) {
2891+
offset = 0;
2892+
}
2893+
for (int i = 0; i < length; i++) {
2894+
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
2895+
}
28722896
}
2873-
is_filled = true;
2897+
} global_cache;
28742898
}
28752899

28762900
// naive Discrete Fourier Transform
@@ -2888,8 +2912,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
28882912

28892913
for (int n = 0; n < N; n++) {
28902914
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
2891-
re += in[n]*cos_vals[idx]; // cos(t)
2892-
im -= in[n]*sin_vals[idx]; // sin(t)
2915+
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
2916+
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
28932917
}
28942918

28952919
out[k*2 + 0] = re;
@@ -2940,8 +2964,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
29402964
const int sin_cos_step = SIN_COS_N_COUNT / N;
29412965
for (int k = 0; k < N/2; k++) {
29422966
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
2943-
float re = cos_vals[idx]; // cos(t)
2944-
float im = -sin_vals[idx]; // sin(t)
2967+
float re = global_cache.cos_vals[idx]; // cos(t)
2968+
float im = -global_cache.sin_vals[idx]; // sin(t)
29452969

29462970
float re_odd = odd_fft[2*k + 0];
29472971
float im_odd = odd_fft[2*k + 1];
@@ -2954,22 +2978,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
29542978
}
29552979
}
29562980

2957-
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
2958-
if (output.size() < static_cast<size_t>(length)) {
2959-
output.resize(length);
2960-
}
2961-
int offset = -1;
2962-
if (periodic) {
2963-
offset = 0;
2964-
}
2965-
for (int i = 0; i < length; i++) {
2966-
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
2967-
}
2968-
2969-
return true;
2970-
}
2971-
2972-
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
2981+
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
29732982
int n_samples, int frame_size, int frame_step, int n_threads,
29742983
const whisper_filters & filters, whisper_mel & mel) {
29752984
std::vector<float> fft_in(frame_size, 0.0);
@@ -2984,7 +2993,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
29842993
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
29852994
const int offset = i * frame_step;
29862995

2987-
// apply Hanning window (~10% faster)
2996+
// apply Hann window (~10% faster)
29882997
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
29892998
fft_in[j] = hann[j] * samples[offset + j];
29902999
}
@@ -3051,12 +3060,16 @@ static bool log_mel_spectrogram(
30513060
whisper_mel & mel) {
30523061
const int64_t t_start_us = ggml_time_us();
30533062

3054-
// Hanning window (Use cosf to eliminate difference)
3055-
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
3056-
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
3057-
std::vector<float> hann;
3058-
hann_window(frame_size, true, hann);
3059-
3063+
// Hann window
3064+
const float * hann = nullptr;
3065+
if (frame_size == WHISPER_N_FFT) {
3066+
hann = global_cache.hann_window;
3067+
} else if (frame_size == 2 * WHISPER_N_FFT) {
3068+
hann = global_cache.hann_window2x;
3069+
} else {
3070+
WHISPER_ASSERT(false && "Unsupported frame_size");
3071+
return false;
3072+
}
30603073

30613074
// Calculate the length of padding
30623075
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
@@ -3086,7 +3099,7 @@ static bool log_mel_spectrogram(
30863099
std::vector<std::thread> workers(n_threads - 1);
30873100
for (int iw = 0; iw < n_threads - 1; ++iw) {
30883101
workers[iw] = std::thread(
3089-
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
3102+
log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
30903103
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
30913104
std::cref(filters), std::ref(mel));
30923105
}
@@ -3246,8 +3259,6 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
32463259
#endif
32473260

32483261
struct whisper_state * whisper_init_state(whisper_context * ctx) {
3249-
fill_sin_cos_table();
3250-
32513262
whisper_state * state = new whisper_state;
32523263

32533264
state->backend = whisper_backend_init(ctx->params);
@@ -7235,7 +7246,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
72357246
// operation (after median filter)
72367247
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
72377248
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
7238-
w = ggml_norm(gctx, w, 1e-9);
7249+
w = ggml_norm(gctx, w, 1e-9f);
72397250
w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
72407251

72417252
// Pass median filter - this is done over AUDIO_TOKENS dimension.

0 commit comments

Comments
 (0)