Skip to content

Commit d12b4d2

Browse files
author
panhehe
committed
[runtime/xpu] Support the execution of non-streaming parsing on the Kunlun XPU card #1455
1 parent 89e8d0d commit d12b4d2

28 files changed

+3385
-6
lines changed

runtime/core/cmake/xpu.cmake

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# ANSI escape sequences used to colour message() output (e.g. ${BoldRed} ...
# ${ColourReset}). This block only defines them on non-Windows hosts; on
# Windows it leaves them unset, so references expand to empty strings.
if(NOT WIN32)
2+
string(ASCII 27 Esc)
3+
set(ColourReset "${Esc}[m")
4+
set(ColourBold "${Esc}[1m")
5+
set(Red "${Esc}[31m")
6+
set(Green "${Esc}[32m")
7+
set(Yellow "${Esc}[33m")
8+
set(Blue "${Esc}[34m")
9+
set(Magenta "${Esc}[35m")
10+
set(Cyan "${Esc}[36m")
11+
set(White "${Esc}[37m")
12+
set(BoldRed "${Esc}[1;31m")
13+
set(BoldGreen "${Esc}[1;32m")
14+
set(BoldYellow "${Esc}[1;33m")
15+
set(BoldBlue "${Esc}[1;34m")
16+
set(BoldMagenta "${Esc}[1;35m")
17+
set(BoldCyan "${Esc}[1;36m")
18+
set(BoldWhite "${Esc}[1;37m")
19+
endif()
20+
21+
# Kunlun XPU backend configuration. Requires the XPU_API_PATH environment
# variable; configuration aborts when it is missing. Adds the XPU include and
# library search paths directory-wide and defines USE_XPU for the sources.
if(XPU)
22+
set(RUNTIME_XPU_PATH ${CMAKE_CURRENT_SOURCE_DIR})
23+
message(STATUS "RUNTIME_XPU_PATH is ${RUNTIME_XPU_PATH} .\n")
24+
set(XPU_KUNLUN_PATH ${RUNTIME_XPU_PATH}/decoder/xpu_kunlun)
25+
if(NOT DEFINED ENV{XPU_API_PATH})
26+
message(FATAL_ERROR "${BoldRed}NO ENV{XPU_API_PATH} in your env. Please set XPU_API_PATH.${ColourReset}\n")
27+
else()
28+
set(XPU_API_PATH $ENV{XPU_API_PATH})
29+
# Purely informational, so report it as STATUS like the message above; a
# mode-less message() is a NOTICE and is written to stderr.
message(STATUS "set XPU_API_PATH from env_var. Val is $ENV{XPU_API_PATH}.")
30+
endif()
31+
32+
include_directories(${XPU_KUNLUN_PATH}/
33+
${XPU_API_PATH}/output/include ${XPU_API_PATH}/../runtime/include)
34+
link_directories(${XPU_API_PATH}/output/so/ ${XPU_API_PATH}/../runtime/output/so/)
35+
36+
add_definitions(-DUSE_XPU)
37+
endif()

runtime/core/decoder/CMakeLists.txt

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ set(decoder_srcs
77
ctc_endpoint.cc
88
)
99

10-
if(NOT TORCH AND NOT ONNX)
11-
message(FATAL_ERROR "Please build with TORCH or ONNX!!!")
10+
if(NOT TORCH AND NOT ONNX AND NOT XPU)
11+
message(FATAL_ERROR "Please build with TORCH or ONNX or XPU!!!")
1212
endif()
1313
if(TORCH)
1414
list(APPEND decoder_srcs torch_asr_model.cc)
@@ -17,8 +17,23 @@ if(ONNX)
1717
list(APPEND decoder_srcs onnx_asr_model.cc)
1818
endif()
1919

20+
if(XPU)
21+
list(APPEND decoder_srcs xpu_asr_model.cc)
22+
list(APPEND decoder_srcs ./xpu_kunlun/xpu_conformer.cpp)
23+
list(APPEND decoder_srcs ./xpu_kunlun/xpu_util.cpp)
24+
message(STATUS "xpu decoder_srcs is :: ${decoder_srcs} \n")
25+
# compile conformer_test
26+
add_subdirectory(xpu_kunlun)
27+
endif()
28+
2029
add_library(decoder STATIC ${decoder_srcs})
21-
target_link_libraries(decoder PUBLIC kaldi-decoder frontend post_processor utils)
30+
if(XPU)
31+
target_link_libraries(decoder PUBLIC kaldi-decoder frontend
32+
post_processor utils xpuapi xpurt)
33+
else()
34+
target_link_libraries(decoder PUBLIC kaldi-decoder frontend
35+
post_processor utils)
36+
endif()
2237

2338
if(ANDROID)
2439
target_link_libraries(decoder PUBLIC ${PYTORCH_LIBRARY} ${FBJNI_LIBRARY})

runtime/core/decoder/params.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
// See the License for the specific language governing permissions and
1414
// limitations under the License.
1515

16-
1716
#ifndef DECODER_PARAMS_H_
1817
#define DECODER_PARAMS_H_
1918

@@ -29,17 +28,24 @@
2928
#ifdef USE_TORCH
3029
#include "decoder/torch_asr_model.h"
3130
#endif
31+
#ifdef USE_XPU
32+
#include "decoder/xpu_asr_model.h"
33+
#endif
3234
#include "frontend/feature_pipeline.h"
3335
#include "post_processor/post_processor.h"
3436
#include "utils/flags.h"
3537
#include "utils/string.h"
3638

3739
DEFINE_int32(num_threads, 1, "num threads for ASR model");
40+
DEFINE_int32(device_id, 0, "set XPU DeviceID for ASR model");
3841

3942
// TorchAsrModel flags
4043
DEFINE_string(model_path, "", "pytorch exported model path");
4144
// OnnxAsrModel flags
4245
DEFINE_string(onnx_dir, "", "directory where the onnx model is saved");
46+
// XPUAsrModel flags
47+
DEFINE_string(xpu_model_dir, "",
48+
"directory where the XPU model and weights is saved");
4349

4450
// FeaturePipelineConfig flags
4551
DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
@@ -66,7 +72,8 @@ DEFINE_double(lattice_beam, 10.0, "lattice beam in ctc wfst search");
6672
DEFINE_double(acoustic_scale, 1.0, "acoustic scale for ctc wfst search");
6773
DEFINE_double(blank_skip_thresh, 1.0,
6874
"blank skip thresh for ctc wfst search, 1.0 means no skip");
69-
DEFINE_double(length_penalty, 0.0, "length penalty ctc wfst search, will not"
75+
// Help text is built from adjacent string literals, which the compiler joins
// verbatim; each fragment therefore ends with its own separator space.
DEFINE_double(length_penalty, 0.0,
76+
"length penalty ctc wfst search, will not "
7077
"apply on self-loop arc, for balancing the del/ins ratio, "
7178
"suggest set to -3.0");
7279
DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search");
@@ -130,7 +137,7 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
130137
#else
131138
LOG(FATAL) << "Please rebuild with cmake options '-DONNX=ON'.";
132139
#endif
133-
} else {
140+
} else if (!FLAGS_model_path.empty()) {
134141
#ifdef USE_TORCH
135142
LOG(INFO) << "Reading torch model " << FLAGS_model_path;
136143
TorchAsrModel::InitEngineThreads(FLAGS_num_threads);
@@ -140,6 +147,19 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
140147
#else
141148
LOG(FATAL) << "Please rebuild with cmake options '-DTORCH=ON'.";
142149
#endif
150+
} else if (!FLAGS_xpu_model_dir.empty()) {
151+
#ifdef USE_XPU
152+
LOG(INFO) << "Reading XPU WeNet model weight from " << FLAGS_xpu_model_dir;
153+
auto model = std::make_shared<XPUAsrModel>();
154+
model->SetEngineThreads(FLAGS_num_threads);
155+
model->SetDeviceId(FLAGS_device_id);
156+
model->Read(FLAGS_xpu_model_dir);
157+
resource->model = model;
158+
#else
159+
LOG(FATAL) << "Please rebuild with cmake options '-DXPU=ON'.";
160+
#endif
161+
} else {
162+
LOG(FATAL) << "Please set ONNX, TORCH or XPU model path!!!";
143163
}
144164

145165
LOG(INFO) << "Reading unit table " << FLAGS_unit_path;
@@ -186,6 +206,7 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
186206
post_process_opts.lowercase = FLAGS_lowercase;
187207
resource->post_processor =
188208
std::make_shared<PostProcessor>(std::move(post_process_opts));
209+
LOG(INFO) << "Finish set PostProcessOptions. \n";
189210
return resource;
190211
}
191212

0 commit comments

Comments
 (0)