@@ -2857,20 +2857,44 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
2857
2857
}
2858
2858
2859
2859
#define SIN_COS_N_COUNT WHISPER_N_FFT
2860
- static float sin_vals[SIN_COS_N_COUNT];
2861
- static float cos_vals[SIN_COS_N_COUNT];
2860
+ namespace {
2861
+ struct whisper_global_cache {
2862
+ // In FFT, we frequently use sine and cosine operations with the same values.
2863
+ // We can use precalculated values to speed up the process.
2864
+ float sin_vals[SIN_COS_N_COUNT];
2865
+ float cos_vals[SIN_COS_N_COUNT];
2866
+
2867
+ // Hann window (Use cosf to eliminate difference)
2868
+ // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
2869
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
2870
+ float hann_window[WHISPER_N_FFT];
2871
+ float hann_window2x[WHISPER_N_FFT * 2 ];
2872
+
2873
+ whisper_global_cache () {
2874
+ fill_sin_cos_table ();
2875
+ #define FILL_HANN_WINDOW (arr ) fill_hann_window(sizeof (arr) / sizeof (arr[0 ]), true , arr)
2876
+ FILL_HANN_WINDOW (hann_window);
2877
+ FILL_HANN_WINDOW (hann_window2x);
2878
+ }
2879
+
2880
+ void fill_sin_cos_table () {
2881
+ for (int i = 0 ; i < SIN_COS_N_COUNT; i++) {
2882
+ double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
2883
+ sin_vals[i] = sinf (theta);
2884
+ cos_vals[i] = cosf (theta);
2885
+ }
2886
+ }
2862
2887
2863
- // In FFT, we frequently use sine and cosine operations with the same values.
2864
- // We can use precalculated values to speed up the process.
2865
- static void fill_sin_cos_table () {
2866
- static bool is_filled = false ;
2867
- if (is_filled) return ;
2868
- for (int i = 0 ; i < SIN_COS_N_COUNT; i++) {
2869
- double theta = (2 *M_PI*i)/SIN_COS_N_COUNT;
2870
- sin_vals[i] = sinf (theta);
2871
- cos_vals[i] = cosf (theta);
2888
// Write a Hann window into output[0 .. length).
//
//   length   - number of samples to write; nothing is written when <= 0.
//   periodic - true:  the phase is divided by `length`     (periodic window)
//              false: the phase is divided by `length - 1` (symmetric window)
//   output   - destination buffer; the caller must provide room for `length`
//              floats.
//
// cosf is used for the cosine term.
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
void fill_hann_window(int length, bool periodic, float * output) {
    const int denom = periodic ? length : length - 1;

    if (denom <= 0) {
        // Only reachable with samples to write when length == 1 and the
        // window is symmetric: the phase would be 0/0 and the sample NaN.
        // A single-sample window is defined as 1.
        if (length == 1) {
            output[0] = 1.0f;
        }
        return;
    }

    for (int i = 0; i < length; i++) {
        output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / denom));
    }
}
2873
- is_filled = true ;
2897
+ } global_cache ;
2874
2898
}
2875
2899
2876
2900
// naive Discrete Fourier Transform
@@ -2888,8 +2912,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
2888
2912
2889
2913
for (int n = 0 ; n < N; n++) {
2890
2914
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
2891
- re += in[n]*cos_vals[idx]; // cos(t)
2892
- im -= in[n]*sin_vals[idx]; // sin(t)
2915
+ re += in[n]*global_cache. cos_vals [idx]; // cos(t)
2916
+ im -= in[n]*global_cache. sin_vals [idx]; // sin(t)
2893
2917
}
2894
2918
2895
2919
out[k*2 + 0 ] = re;
@@ -2940,8 +2964,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2940
2964
const int sin_cos_step = SIN_COS_N_COUNT / N;
2941
2965
for (int k = 0 ; k < N/2 ; k++) {
2942
2966
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
2943
- float re = cos_vals[idx]; // cos(t)
2944
- float im = -sin_vals[idx]; // sin(t)
2967
+ float re = global_cache. cos_vals [idx]; // cos(t)
2968
+ float im = -global_cache. sin_vals [idx]; // sin(t)
2945
2969
2946
2970
float re_odd = odd_fft[2 *k + 0 ];
2947
2971
float im_odd = odd_fft[2 *k + 1 ];
@@ -2954,22 +2978,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2954
2978
}
2955
2979
}
2956
2980
2957
- static bool hann_window (int length, bool periodic, std::vector<float > & output) {
2958
- if (output.size () < static_cast <size_t >(length)) {
2959
- output.resize (length);
2960
- }
2961
- int offset = -1 ;
2962
- if (periodic) {
2963
- offset = 0 ;
2964
- }
2965
- for (int i = 0 ; i < length; i++) {
2966
- output[i] = 0.5 *(1.0 - cosf ((2.0 *M_PI*i)/(length + offset)));
2967
- }
2968
-
2969
- return true ;
2970
- }
2971
-
2972
- static void log_mel_spectrogram_worker_thread (int ith, const std::vector<float > & hann, const std::vector<float > & samples,
2981
+ static void log_mel_spectrogram_worker_thread (int ith, const float * hann, const std::vector<float > & samples,
2973
2982
int n_samples, int frame_size, int frame_step, int n_threads,
2974
2983
const whisper_filters & filters, whisper_mel & mel) {
2975
2984
std::vector<float > fft_in (frame_size, 0.0 );
@@ -2984,7 +2993,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
2984
2993
for (; i < std::min (n_samples / frame_step + 1 , mel.n_len ); i += n_threads) {
2985
2994
const int offset = i * frame_step;
2986
2995
2987
- // apply Hanning window (~10% faster)
2996
+ // apply Hann window (~10% faster)
2988
2997
for (int j = 0 ; j < std::min (frame_size, n_samples - offset); j++) {
2989
2998
fft_in[j] = hann[j] * samples[offset + j];
2990
2999
}
@@ -3051,12 +3060,16 @@ static bool log_mel_spectrogram(
3051
3060
whisper_mel & mel) {
3052
3061
const int64_t t_start_us = ggml_time_us ();
3053
3062
3054
- // Hanning window (Use cosf to eliminate difference)
3055
- // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
3056
- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
3057
- std::vector<float > hann;
3058
- hann_window (frame_size, true , hann);
3059
-
3063
+ // Hann window
3064
+ const float * hann = nullptr ;
3065
+ if (frame_size == WHISPER_N_FFT) {
3066
+ hann = global_cache.hann_window ;
3067
+ } else if (frame_size == 2 * WHISPER_N_FFT) {
3068
+ hann = global_cache.hann_window2x ;
3069
+ } else {
3070
+ WHISPER_ASSERT (false && " Unsupported frame_size" );
3071
+ return false ;
3072
+ }
3060
3073
3061
3074
// Calculate the length of padding
3062
3075
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30 ;
@@ -3086,7 +3099,7 @@ static bool log_mel_spectrogram(
3086
3099
std::vector<std::thread> workers (n_threads - 1 );
3087
3100
for (int iw = 0 ; iw < n_threads - 1 ; ++iw) {
3088
3101
workers[iw] = std::thread (
3089
- log_mel_spectrogram_worker_thread, iw + 1 , std::cref ( hann) , samples_padded,
3102
+ log_mel_spectrogram_worker_thread, iw + 1 , hann, samples_padded,
3090
3103
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
3091
3104
std::cref (filters), std::ref (mel));
3092
3105
}
@@ -3246,8 +3259,6 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
3246
3259
#endif
3247
3260
3248
3261
struct whisper_state * whisper_init_state (whisper_context * ctx) {
3249
- fill_sin_cos_table ();
3250
-
3251
3262
whisper_state * state = new whisper_state;
3252
3263
3253
3264
state->backend = whisper_backend_init (ctx->params );
@@ -7235,7 +7246,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
7235
7246
// operation (after median filter)
7236
7247
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
7237
7248
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
7238
- w = ggml_norm (gctx, w, 1e-9 );
7249
+ w = ggml_norm (gctx, w, 1e-9f );
7239
7250
w = ggml_permute (gctx, ggml_permute (gctx, w, 2 , 1 , 0 ,3 ), 0 , 2 , 1 , 3 );
7240
7251
7241
7252
// Pass median filter - this is done over AUDIO_TOKENS dimension.
0 commit comments