13
13
// See the License for the specific language governing permissions and
14
14
// limitations under the License.
15
15
16
-
17
16
#ifndef DECODER_PARAMS_H_
18
17
#define DECODER_PARAMS_H_
19
18
29
28
#ifdef USE_TORCH
30
29
#include " decoder/torch_asr_model.h"
31
30
#endif
31
+ #ifdef USE_XPU
32
+ #include " decoder/xpu_asr_model.h"
33
+ #endif
32
34
#include " frontend/feature_pipeline.h"
33
35
#include " post_processor/post_processor.h"
34
36
#include " utils/flags.h"
35
37
#include " utils/string.h"
36
38
37
39
DEFINE_int32 (num_threads, 1 , " num threads for ASR model" );
40
+ DEFINE_int32 (device_id, 0 , " set XPU DeviceID for ASR model" );
38
41
39
42
// TorchAsrModel flags
40
43
DEFINE_string (model_path, " " , " pytorch exported model path" );
41
44
// OnnxAsrModel flags
42
45
DEFINE_string (onnx_dir, " " , " directory where the onnx model is saved" );
46
+ // XPUAsrModel flags
47
+ DEFINE_string (xpu_model_dir, " " ,
48
+ " directory where the XPU model and weights is saved" );
43
49
44
50
// FeaturePipelineConfig flags
45
51
DEFINE_int32 (num_bins, 80 , " num mel bins for fbank feature" );
@@ -66,7 +72,8 @@ DEFINE_double(lattice_beam, 10.0, "lattice beam in ctc wfst search");
66
72
DEFINE_double (acoustic_scale, 1.0 , " acoustic scale for ctc wfst search" );
67
73
DEFINE_double (blank_skip_thresh, 1.0 ,
68
74
" blank skip thresh for ctc wfst search, 1.0 means no skip" );
69
- DEFINE_double (length_penalty, 0.0 , " length penalty ctc wfst search, will not"
75
+ DEFINE_double (length_penalty, 0.0 ,
76
+ " length penalty ctc wfst search, will not"
70
77
" apply on self-loop arc, for balancing the del/ins ratio, "
71
78
" suggest set to -3.0" );
72
79
DEFINE_int32 (nbest, 10 , " nbest for ctc wfst or prefix search" );
@@ -130,7 +137,7 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
130
137
#else
131
138
LOG (FATAL) << " Please rebuild with cmake options '-DONNX=ON'." ;
132
139
#endif
133
- } else {
140
+ } else if (!FLAGS_model_path. empty ()) {
134
141
#ifdef USE_TORCH
135
142
LOG (INFO) << " Reading torch model " << FLAGS_model_path;
136
143
TorchAsrModel::InitEngineThreads (FLAGS_num_threads);
@@ -140,6 +147,19 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
140
147
#else
141
148
LOG (FATAL) << " Please rebuild with cmake options '-DTORCH=ON'." ;
142
149
#endif
150
+ } else if (!FLAGS_xpu_model_dir.empty ()) {
151
+ #ifdef USE_XPU
152
+ LOG (INFO) << " Reading XPU WeNet model weight from " << FLAGS_xpu_model_dir;
153
+ auto model = std::make_shared<XPUAsrModel>();
154
+ model->SetEngineThreads (FLAGS_num_threads);
155
+ model->SetDeviceId (FLAGS_device_id);
156
+ model->Read (FLAGS_xpu_model_dir);
157
+ resource->model = model;
158
+ #else
159
+ LOG (FATAL) << " Please rebuild with cmake options '-DXPU=ON'." ;
160
+ #endif
161
+ } else {
162
+ LOG (FATAL) << " Please set ONNX, TORCH or XPU model path!!!" ;
143
163
}
144
164
145
165
LOG (INFO) << " Reading unit table " << FLAGS_unit_path;
@@ -186,6 +206,7 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
186
206
post_process_opts.lowercase = FLAGS_lowercase;
187
207
resource->post_processor =
188
208
std::make_shared<PostProcessor>(std::move (post_process_opts));
209
+ LOG (INFO) << " Finish set PostProcessOptions. \n " ;
189
210
return resource;
190
211
}
191
212
0 commit comments