
diar_streaming_sortformer_4spk-v2.1 ONNX export #15536

@wisdomtooth546


Describe the bug

I've run into a couple of issues while trying to run ONNX inference with diar_streaming_sortformer_4spk-v2.1.onnx.

  1. When I try to export the model with streaming_export(), the export fails on the chunk dimension. The model expects a feature dimension of 128, but the example input defined inside streaming_export() sets it to 80 (the 80 might refer to the model processing audio in 80 ms frames). I was able to fix this by patching dim -1 of the chunk example input to self.cfg.preprocessor.features.
  2. I get the following error when running the ONNX model exported this way:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'/Reshape' Status Message: /Users/cloudtest/vss/_work/1/s/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:47 onnxruntime::ReshapeHelper::ReshapeHelper(const TensorShape &, TensorShapeVector &, bool) input_shape_size == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{0,512}, requested shape:{1,512,512}
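For reference, the Reshape failure above is a plain element-count mismatch: an input of shape {0,512} holds 0 × 512 = 0 elements, while the requested shape {1,512,512} needs 1 × 512 × 512 = 262,144. A minimal NumPy sketch of the same condition (NumPy is used only for illustration here; this is not ONNX Runtime's own code):

```python
import numpy as np


def can_reshape(input_shape, requested_shape):
    """Return True if a tensor of input_shape can be reshaped to requested_shape.

    Mirrors the element-count check named in the error (input_shape_size == size).
    It ignores ONNX Reshape's special 0 and -1 entries, which the requested
    shape in this error does not use.
    """
    return int(np.prod(input_shape)) == int(np.prod(requested_shape))


input_shape = (0, 512)           # input shape reported by ONNX Runtime
requested_shape = (1, 512, 512)  # shape the Reshape node asked for

print(np.prod(input_shape), np.prod(requested_shape))  # 0 262144
print(can_reshape(input_shape, requested_shape))       # False

# NumPy rejects the same pair of shapes with a ValueError.
try:
    np.zeros(input_shape, dtype=np.float32).reshape(requested_shape)
except ValueError as err:
    print("ValueError:", err)
```

In other words, some tensor reaching '/Reshape' has a zero-sized leading dimension at inference time, while the graph expects 512 rows there.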

Steps/Code to reproduce bug

  1. ONNX export code
```python
import torch

# `model` is the loaded diar_streaming_sortformer_4spk-v2.1 model (loading code not shown).
# `chunk_length` is also defined elsewhere; its value is not given in this report.
chunk_right_context = 60
spkcache_update_period = 344

# Configure the model for asynchronous streaming before export.
model.streaming_mode = True
model.sortformer_modules.chunk_len = chunk_length
model.sortformer_modules.chunk_right_context = chunk_right_context
model.sortformer_modules.spkcache_update_period = spkcache_update_period
model.async_streaming = True

B = 4
n_mels = model.cfg.preprocessor.features
T = 220  # feature frames in this chunk

chunk = torch.randn(B, T, n_mels, device=model.device, dtype=torch.float32)
chunk_lengths = torch.full((B,), T, device=model.device, dtype=torch.int64)

spkcache = torch.randn(B, 188, 512, device=model.device, dtype=torch.float32)
spkcache_lengths = torch.tensor([40, 188, 0, 68], device=model.device, dtype=torch.int64)

fifo = torch.randn(B, 188, 512, device=model.device, dtype=torch.float32)
fifo_lengths = torch.tensor([50, 88, 0, 90], device=model.device, dtype=torch.int64)

input_example = (chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths)

model.export(
    "diar_streaming_sortformer_4spk-v2.1.onnx",
    input_example=input_example,
    dynamic_axes={
        "chunk": {0: "batch_size", 1: "chunk_length"},
        "chunk_lengths": {0: "batch_size"},
        "spkcache": {0: "batch_size", 1: "spk_cache_length"},
        "spkcache_lengths": {0: "batch_size"},
        "fifo": {0: "batch_size", 1: "fifo_length"},
        "fifo_lengths": {0: "batch_size"},
    },
)
```
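The example inputs in the export call fix everything that dynamic_axes does not mark as dynamic: the dtypes (float32 tensors, int64 lengths), the ranks, and the trailing dimensions (the feature count, 128 per point 1 of this report, and the 512-wide spkcache/fifo embeddings). A hypothetical helper like the one below (not part of NeMo or ONNX Runtime) can check an inference feed against that signature before calling sess.run; the expected values are read off the export call above and are assumptions if the model config differs.

```python
import numpy as np

# Hypothetical signature, read off the export call above:
# input name -> (dtype, rank, {axis: fixed size}) for axes NOT listed in dynamic_axes.
N_MELS = 128   # model.cfg.preprocessor.features, per point 1 of this report
EMB_DIM = 512  # trailing dim of the spkcache/fifo example tensors

EXPECTED_INPUTS = {
    "chunk": (np.float32, 3, {2: N_MELS}),
    "chunk_lengths": (np.int64, 1, {}),
    "spkcache": (np.float32, 3, {2: EMB_DIM}),
    "spkcache_lengths": (np.int64, 1, {}),
    "fifo": (np.float32, 3, {2: EMB_DIM}),
    "fifo_lengths": (np.int64, 1, {}),
}


def check_feed(feed):
    """Return human-readable mismatches between an ONNX feed dict and EXPECTED_INPUTS."""
    problems = []
    for name, (dtype, rank, fixed_dims) in EXPECTED_INPUTS.items():
        if name not in feed:
            problems.append(f"{name}: missing")
            continue
        arr = feed[name]
        if arr.dtype != dtype:
            problems.append(f"{name}: dtype {arr.dtype}, expected {np.dtype(dtype)}")
        if arr.ndim != rank:
            problems.append(f"{name}: rank {arr.ndim}, expected {rank}")
            continue
        for axis, size in fixed_dims.items():
            if arr.shape[axis] != size:
                problems.append(f"{name}: dim {axis} is {arr.shape[axis]}, expected {size}")
    # Every input shares the "batch_size" axis 0, so those sizes must agree.
    batch_sizes = {n: a.shape[0] for n, a in feed.items() if n in EXPECTED_INPUTS and a.ndim >= 1}
    if len(set(batch_sizes.values())) > 1:
        problems.append(f"inconsistent batch_size: {batch_sizes}")
    return problems
```

For example, calling check_feed(input_data) just before sess.run(None, input_data) would flag an 80-wide chunk or int32 length arrays up front, rather than only once sess.run fails.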
  2. Run ONNX inference

```python
import math

import numpy as np
import torch

# `sess` (an onnxruntime.InferenceSession for the exported model), `streaming_loader`,
# `model`, and the initial `streaming_state` are created beforehand (setup not shown).
for chunk_id, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in streaming_loader:

    spkcache_np = streaming_state.spkcache.cpu().numpy().astype(np.float32)
    fifo_np = streaming_state.fifo.cpu().numpy().astype(np.float32)

    spkcache_lengths_np = streaming_state.spkcache_lengths.cpu().numpy()
    fifo_lengths_np = streaming_state.fifo_lengths.cpu().numpy()

    input_data = {
        "chunk": chunk_feat_seq_t.cpu().numpy(),
        "chunk_lengths": feat_lengths.cpu().numpy(),
        "spkcache": spkcache_np,
        "spkcache_lengths": spkcache_lengths_np,
        "fifo": fifo_np,
        "fifo_lengths": fifo_lengths_np,
    }
    # Run with ONNX
    spkcache_fifo_chunk_preds, chunk_pre_encode_embs, chunk_pre_encode_lengths = sess.run(None, input_data)

    # Update the streaming state with the ONNX outputs
    streaming_state, chunk_preds = model.sortformer_modules.streaming_update_async(
        streaming_state=streaming_state,
        chunk=torch.tensor(chunk_pre_encode_embs),
        chunk_lengths=torch.tensor(chunk_pre_encode_lengths),
        preds=torch.tensor(spkcache_fifo_chunk_preds),
        lc=round(left_offset / model.encoder.subsampling_factor),
        rc=math.ceil(right_offset / model.encoder.subsampling_factor),
    )
```
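The lc/rc arguments convert the loader's left/right offsets (in feature frames) into encoder frames. A small sketch of that arithmetic, assuming a subsampling factor of 8 (an illustrative value only; the real one comes from model.encoder.subsampling_factor). Note that Python's round() rounds halves to the nearest even integer, so exact .5 quotients do not always round up, while math.ceil always does:

```python
import math


def context_frames(left_offset, right_offset, subsampling_factor=8):
    """Convert feature-frame offsets to encoder-frame context sizes.

    Mirrors the lc/rc expressions in the loop above; the default
    subsampling_factor of 8 is an illustrative assumption.
    """
    lc = round(left_offset / subsampling_factor)       # round-half-to-even
    rc = math.ceil(right_offset / subsampling_factor)  # always rounds up
    return lc, rc


print(context_frames(4, 4))    # (0, 1): 0.5 rounds to 0, ceil(0.5) is 1
print(context_frames(12, 12))  # (2, 2): 1.5 rounds to 2
print(context_frames(20, 20))  # (2, 3): 2.5 rounds to 2, ceil(2.5) is 3
print(context_frames(7, 9))    # (1, 2)
```

If conventional round-half-up behaviour were wanted instead, math.floor(x + 0.5) would provide it; whether the streaming update expects one or the other is not established here.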

Expected behavior

Successful ONNX graph execution

Environment overview (please complete the following information)

  • Environment location: MacBook Air (M4, 16 GB), Python 3.12, CPU
  • Method of NeMo install: uv add nemo_toolkit[asr]

Environment details

If NVIDIA docker image is used you don't need to specify these.
Otherwise, please provide:

  • OS version macOS Tahoe 26.3.1
  • PyTorch version 2.10
  • Python version 3.12
