diar_streaming_sortformer_4spk-v2.1 ONNX export #15536
Description
Describe the bug
I've run into a couple of issues while attempting to run ONNX inference on diar_streaming_sortformer_4spk-v2.1.onnx.
- When I attempt to export the model with `streaming_export()`, it complains about the `chunk` dimension. The model expects a feature dimension of 128, while the example input defined inside `streaming_export()` sets it to 80 (this might be a mix-up with how the model processes audio in 80 ms frames), causing the export to fail. I was able to fix it by patching dim -1 of `chunk` to be `self.cfg.preprocessor.features`.
- I run into the following error while attempting to run the ONNX model exported this way:
```
RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'/Reshape' Status Message: /Users/cloudtest/vss/_work/1/s/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:47 onnxruntime::ReshapeHelper::ReshapeHelper(const TensorShape &, TensorShapeVector &, bool) input_shape_size == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{0,512}, requested shape:{1,512,512}
```
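For context, the Reshape failure is a pure element-count mismatch: a zero-length input (0 × 512 = 0 elements) can never be viewed as (1, 512, 512), which suggests an empty tensor (e.g. a batch item with length 0) is reaching that node. A minimal sketch of the check ONNX Runtime performs (the helper name is illustrative, not onnxruntime API):

```python
from math import prod

def can_reshape(input_shape, target_shape):
    # ONNX Reshape only succeeds when the total element counts match
    return prod(input_shape) == prod(target_shape)

# The failing case from the traceback: an empty (0, 512) tensor
# cannot be reshaped to (1, 512, 512).
print(can_reshape((0, 512), (1, 512, 512)))    # False
print(can_reshape((512, 512), (1, 512, 512)))  # True
```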
Steps/Code to reproduce bug
- ONNX export code
```python
chunk_right_context = 60
spkcache_update_period = 344

model.streaming_mode = True
model.sortformer_modules.chunk_len = chunk_length  # chunk_length is set earlier; its value is not shown here
model.sortformer_modules.chunk_right_context = chunk_right_context
model.sortformer_modules.spkcache_update_period = spkcache_update_period
model.async_streaming = True

B = 4
n_mels = model.cfg.preprocessor.features
T = 220  # feature frames in this chunk
chunk = torch.randn(B, T, n_mels, device=model.device, dtype=torch.float32)
chunk_lengths = torch.full((B,), T, device=model.device, dtype=torch.int64)
spkcache = torch.randn(B, 188, 512, device=model.device, dtype=torch.float32)
spkcache_lengths = torch.tensor([40, 188, 0, 68], device=model.device, dtype=torch.int64)
fifo = torch.randn(B, 188, 512, device=model.device, dtype=torch.float32)
fifo_lengths = torch.tensor([50, 88, 0, 90], device=model.device, dtype=torch.int64)
input_example = (chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths)

model.export(
    "diar_streaming_sortformer_4spk-v2.1.onnx",
    input_example=input_example,
    dynamic_axes={
        "chunk": {0: "batch_size", 1: "chunk_length"},
        "chunk_lengths": {0: "batch_size"},
        "spkcache": {0: "batch_size", 1: "spk_cache_length"},
        "spkcache_lengths": {0: "batch_size"},
        "fifo": {0: "batch_size", 1: "fifo_length"},
        "fifo_lengths": {0: "batch_size"},
    },
)
```
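As a sanity check before exporting (a hedged, NeMo-independent sketch, not part of the NeMo API), one can verify that every entry in `dynamic_axes` refers to an axis that actually exists in the corresponding example tensor; a mismatch between the example input's shape and what the model expects is exactly the kind of issue that surfaced with the feature dimension above:

```python
# Hedged sketch: validate a dynamic_axes spec against example input shapes.
# The shapes mirror the input_example above (n_mels = 128 per this report).
example_shapes = {
    "chunk": (4, 220, 128),
    "chunk_lengths": (4,),
    "spkcache": (4, 188, 512),
    "spkcache_lengths": (4,),
    "fifo": (4, 188, 512),
    "fifo_lengths": (4,),
}

dynamic_axes = {
    "chunk": {0: "batch_size", 1: "chunk_length"},
    "chunk_lengths": {0: "batch_size"},
    "spkcache": {0: "batch_size", 1: "spk_cache_length"},
    "spkcache_lengths": {0: "batch_size"},
    "fifo": {0: "batch_size", 1: "fifo_length"},
    "fifo_lengths": {0: "batch_size"},
}

def check_dynamic_axes(shapes, axes):
    # Every dynamic axis index must be a valid dimension of its tensor
    for name, ax_map in axes.items():
        rank = len(shapes[name])
        for ax in ax_map:
            if not 0 <= ax < rank:
                raise ValueError(f"{name}: axis {ax} out of range for rank {rank}")
    return True

print(check_dynamic_axes(example_shapes, dynamic_axes))  # True
```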
- Run ONNX inference
```python
import math

import numpy as np
import torch

for chunk_id, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in streaming_loader:
    spkcache_np = streaming_state.spkcache.cpu().numpy().astype(np.float32)
    fifo_np = streaming_state.fifo.cpu().numpy().astype(np.float32)
    spkcache_lengths_np = streaming_state.spkcache_lengths.cpu().numpy()
    fifo_lengths_np = streaming_state.fifo_lengths.cpu().numpy()
    input_data = {
        "chunk": chunk_feat_seq_t.cpu().numpy(),
        "chunk_lengths": feat_lengths.cpu().numpy(),
        "spkcache": spkcache_np,
        "spkcache_lengths": spkcache_lengths_np,
        "fifo": fifo_np,
        "fifo_lengths": fifo_lengths_np,
    }
    # Run with ONNX
    spkcache_fifo_chunk_preds, chunk_pre_encode_embs, chunk_pre_encode_lengths = sess.run(None, input_data)
    # Update the states
    streaming_state, chunk_preds = model.sortformer_modules.streaming_update_async(
        streaming_state=streaming_state,
        chunk=torch.tensor(chunk_pre_encode_embs),
        chunk_lengths=torch.tensor(chunk_pre_encode_lengths),
        preds=torch.tensor(spkcache_fifo_chunk_preds),
        lc=round(left_offset / model.encoder.subsampling_factor),
        rc=math.ceil(right_offset / model.encoder.subsampling_factor),
    )
```
Expected behavior
Successful ONNX graph execution
Environment overview (please complete the following information)
- Environment location: Mac M4 Air (16 GB), Python 3.12, CPU
- Method of NeMo install: `uv add nemo_toolkit[asr]`
Environment details
If NVIDIA docker image is used you don't need to specify these.
Otherwise, please provide:
- OS version: macOS Tahoe 26.3.1
- PyTorch version: 2.10
- Python version: 3.12