Skip to content

Commit fbae8dc

Browse files
committed
whisper : use backend registry
1 parent c800966 commit fbae8dc

File tree

1 file changed

+60
-123
lines changed

1 file changed

+60
-123
lines changed

src/whisper.cpp

Lines changed: 60 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,19 @@
11
#include "whisper.h"
22

3-
#ifdef WHISPER_USE_COREML
4-
#include "coreml/whisper-encoder.h"
5-
#endif
6-
73
#include "ggml-cpu.h"
84

9-
#ifdef GGML_USE_METAL
10-
#include "ggml-metal.h"
11-
#endif
12-
13-
#ifdef GGML_USE_CUDA
14-
#include "ggml-cuda.h"
15-
#endif
16-
17-
#ifdef GGML_USE_SYCL
18-
#include "ggml-sycl.h"
19-
#endif
20-
21-
#ifdef GGML_USE_VULKAN
22-
#include "ggml-vulkan.h"
23-
#endif
5+
#include "ggml.h"
6+
#include "ggml-alloc.h"
7+
#include "ggml-backend.h"
248

25-
#ifdef GGML_USE_BLAS
26-
#include "ggml-blas.h"
9+
#ifdef WHISPER_USE_COREML
10+
#include "coreml/whisper-encoder.h"
2711
#endif
2812

2913
#ifdef WHISPER_USE_OPENVINO
3014
#include "openvino/whisper-openvino-encoder.h"
3115
#endif
3216

33-
#ifdef GGML_USE_CANN
34-
#include "ggml-cann.h"
35-
#endif
36-
37-
#include "ggml.h"
38-
#include "ggml-alloc.h"
39-
#include "ggml-backend.h"
40-
4117
#include <atomic>
4218
#include <algorithm>
4319
#include <cassert>
@@ -195,14 +171,13 @@ static bool ggml_graph_compute_helper(
195171

196172
for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
197173
ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
198-
if (ggml_backend_is_cpu(backend)) {
199-
ggml_backend_cpu_set_n_threads(backend, n_threads);
200-
}
201-
#ifdef GGML_USE_BLAS
202-
if (ggml_backend_is_blas(backend)) {
203-
ggml_backend_blas_set_n_threads(backend, n_threads);
174+
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
175+
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
176+
177+
auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
178+
if (fn_set_n_threads) {
179+
fn_set_n_threads(backend, n_threads);
204180
}
205-
#endif
206181
}
207182

208183
bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
@@ -1260,61 +1235,24 @@ static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & pa
12601235

12611236
ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
12621237

1263-
#ifdef GGML_USE_CUDA
1264-
if (params.use_gpu) {
1265-
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
1266-
result = ggml_backend_cuda_init(params.gpu_device);
1267-
if (!result) {
1268-
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
1269-
}
1270-
}
1271-
#endif
1272-
1273-
#ifdef GGML_USE_METAL
1274-
if (params.use_gpu) {
1275-
WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
1276-
result = ggml_backend_metal_init();
1277-
if (!result) {
1278-
WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
1279-
} else if (!ggml_backend_metal_supports_family(result, 7)) {
1280-
WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
1281-
ggml_backend_free(result);
1282-
result = NULL;
1283-
}
1284-
}
1285-
#endif
1286-
1287-
#ifdef GGML_USE_SYCL
1288-
if (params.use_gpu) {
1289-
WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
1290-
result = ggml_backend_sycl_init(params.gpu_device);
1291-
if (!result) {
1292-
WHISPER_LOG_ERROR("%s: ggml_backend_sycl_init() failed\n", __func__);
1293-
}
1294-
}
1295-
#endif
1296-
1297-
#ifdef GGML_USE_VULKAN
12981238
if (params.use_gpu) {
1299-
WHISPER_LOG_INFO("%s: using Vulkan backend\n", __func__);
1300-
result = ggml_backend_vk_init(params.gpu_device);
1301-
if (!result) {
1302-
WHISPER_LOG_ERROR("%s: ggml_backend_vk_init() failed\n", __func__);
1303-
}
1304-
}
1305-
#endif
1306-
1307-
#ifdef GGML_USE_CANN
1308-
if (params.use_gpu) {
1309-
WHISPER_LOG_INFO("%s: using CANN backend\n", __func__);
1310-
result = ggml_backend_cann_init(params.gpu_device);
1311-
if (!result) {
1312-
WHISPER_LOG_ERROR("%s: ggml_backend_cann_init() failed\n", __func__);
1239+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
1240+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
1241+
switch (ggml_backend_dev_type(dev)) {
1242+
case GGML_BACKEND_DEVICE_TYPE_CPU:
1243+
case GGML_BACKEND_DEVICE_TYPE_ACCEL:
1244+
// skip CPU backends
1245+
break;
1246+
case GGML_BACKEND_DEVICE_TYPE_GPU:
1247+
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
1248+
result = ggml_backend_dev_init(dev, nullptr);
1249+
if (!result) {
1250+
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
1251+
}
1252+
break;
1253+
}
13131254
}
13141255
}
1315-
#endif
1316-
1317-
GGML_UNUSED(params);
13181256

13191257
return result;
13201258
}
@@ -1328,17 +1266,19 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
13281266
result.push_back(backend_gpu);
13291267
}
13301268

1331-
#ifdef GGML_USE_BLAS
1332-
{
1333-
WHISPER_LOG_INFO("%s: using BLAS backend\n", __func__);
1334-
ggml_backend_t backend_blas = ggml_backend_blas_init();
1335-
if (!backend_blas) {
1336-
WHISPER_LOG_ERROR("%s: ggml_backend_blas_init() failed\n", __func__);
1337-
} else {
1338-
result.push_back(backend_blas);
1269+
// ACCEL backends
1270+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
1271+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
1272+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
1273+
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
1274+
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
1275+
if (!backend) {
1276+
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
1277+
continue;
1278+
}
1279+
result.push_back(backend);
13391280
}
13401281
}
1341-
#endif
13421282

13431283
GGML_UNUSED(params);
13441284

@@ -1348,33 +1288,26 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
13481288
}
13491289

13501290
static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
1351-
ggml_backend_buffer_type_t result = nullptr;
1352-
1353-
params.use_gpu || (result = ggml_backend_cpu_buffer_type());
1354-
1355-
#ifdef GGML_USE_CUDA
1356-
result || (result = ggml_backend_cuda_buffer_type(params.gpu_device));
1357-
#endif
1358-
1359-
#ifdef GGML_USE_METAL
1360-
result || (result = ggml_backend_metal_buffer_type());
1361-
#endif
1362-
1363-
#ifdef GGML_USE_SYCL
1364-
result || (result = ggml_backend_sycl_buffer_type(params.gpu_device));
1365-
#endif
1366-
1367-
#ifdef GGML_USE_VULKAN
1368-
result || (result = ggml_backend_vk_buffer_type(params.gpu_device));
1369-
#endif
1291+
if (!params.use_gpu) {
1292+
return ggml_backend_cpu_buffer_type();
1293+
}
13701294

1371-
#ifdef GGML_USE_CANN
1372-
result || (result == ggml_backend_cann_buffer_type(params.gpu_device));
1373-
#endif
1295+
// if we have a GPU device - use it
1296+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
1297+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
1298+
switch (ggml_backend_dev_type(dev)) {
1299+
case GGML_BACKEND_DEVICE_TYPE_CPU:
1300+
case GGML_BACKEND_DEVICE_TYPE_ACCEL:
1301+
// skip CPU backends
1302+
break;
13741303

1375-
result || (result = ggml_backend_cpu_buffer_type());
1304+
case GGML_BACKEND_DEVICE_TYPE_GPU:
1305+
WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
1306+
return ggml_backend_dev_buffer_type(dev);
1307+
}
1308+
}
13761309

1377-
return result;
1310+
return ggml_backend_cpu_buffer_type();
13781311
}
13791312

13801313
// load the model from a ggml file
@@ -3668,8 +3601,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
36683601
WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
36693602
WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
36703603
WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
3671-
3672-
// TODO: temporary call to force backend registry initialization
3604+
WHISPER_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count());
36733605
WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());
36743606

36753607
whisper_context * ctx = new whisper_context;
@@ -7427,6 +7359,11 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
74277359
static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
74287360
(void) level;
74297361
(void) user_data;
7362+
#ifndef WHISPER_DEBUG
7363+
if (level == GGML_LOG_LEVEL_DEBUG) {
7364+
return;
7365+
}
7366+
#endif
74307367
fputs(text, stderr);
74317368
fflush(stderr);
74327369
}

0 commit comments

Comments (0)