Skip to content

[Misc] Benchmarks for audio models #16505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions benchmarks/backend_request_func.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import io
import json
import os
import sys
Expand Down Expand Up @@ -32,6 +33,7 @@ class RequestFuncInput:
extra_body: Optional[dict] = None
multi_modal_content: Optional[dict] = None
ignore_eos: bool = False
language: Optional[str] = None


@dataclass
Expand Down Expand Up @@ -436,6 +438,110 @@ async def async_request_openai_chat_completions(
return output


async def async_request_openai_audio(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile
api_url = request_func_input.api_url
assert api_url.endswith(
("transcriptions", "translations"
)), "OpenAI Chat Completions API URL must end with 'transcriptions' "
"or `translations`."

async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"temperature": 0.0,
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"language": "en",
# Flattened due to multipart/form-data
"stream_include_usage": True,
"stream_continuous_usage_stats": True
}
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}

# Send audio file
def to_bytes(y, sr):
buffer = io.BytesIO()
soundfile.write(buffer, y, sr, format="WAV")
buffer.seek(0)
return buffer

with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
form = aiohttp.FormData()
form.add_field('file', f, content_type='audio/wav')
for key, value in payload.items():
form.add_field(key, str(value))

output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len

generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url,
data=form,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)

if choices := data.get("choices"):
content = choices[0]["delta"].get(
"content")
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft

# Decoding phase
else:
output.itl.append(
timestamp - most_recent_timestamp)

generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")

most_recent_timestamp = timestamp

output.generated_text = generated_text
output.success = True
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))

if pbar:
pbar.update(1)
return output


def get_model(pretrained_model_name_or_path: str) -> str:
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
from modelscope import snapshot_download
Expand Down Expand Up @@ -493,6 +599,7 @@ def get_tokenizer(
"deepspeed-mii": async_request_deepspeed_mii,
"openai": async_request_openai_completions,
"openai-chat": async_request_openai_chat_completions,
"openai-audio": async_request_openai_audio,
"tensorrt-llm": async_request_trt_llm,
"scalellm": async_request_openai_completions,
"sglang": async_request_openai_completions,
Expand Down
80 changes: 80 additions & 0 deletions benchmarks/benchmark_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class SampleRequest:

class BenchmarkDataset(ABC):
DEFAULT_SEED = 0
IS_MULTIMODAL = False

def __init__(
self,
Expand Down Expand Up @@ -621,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset):
SUPPORTED_DATASET_PATHS = {
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
}
IS_MULTIMODAL = True

def sample(self,
tokenizer: PreTrainedTokenizerBase,
Expand Down Expand Up @@ -685,6 +687,7 @@ class VisionArenaDataset(HuggingFaceDataset):
"lmarena-ai/vision-arena-bench-v0.1":
lambda x: x["turns"][0][0]["content"]
}
IS_MULTIMODAL = True

def sample(
self,
Expand Down Expand Up @@ -815,3 +818,80 @@ def sample(self,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests


# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------


class ASRDataset(HuggingFaceDataset):
"""
Dataset class for processing a ASR dataset for transcription.
Tested on the following set:

+----------------+----------------------------------------+--------------------------+-----------------------------+
| Dataset | Domain | Speaking Style | hf-subset |
+----------------+----------------------------------------+--------------------------+-----------------------------+
| TED-LIUM | TED talks | Oratory | release1, release2, release3|
| | | | release3-speaker-adaptation |
| VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... |
| LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" |
| GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test |
| SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test |
| AMI | Meetings | Spontaneous | ihm, sdm |
+----------------+----------------------------------------+--------------------------+-----------------------------+

""" # noqa: E501
SUPPORTED_DATASET_PATHS = {
"openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
"edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
}

DEFAULT_OUTPUT_LEN = 128
IS_MULTIMODAL = True

# TODO Whisper-specific. Abstract interface when more models are supported.
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\
"<|notimestamps|>"
skip_long_audios: bool = True

def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
**kwargs,
) -> list:
import librosa
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = []
skipped = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
break
audio = item["audio"]
y, sr = audio["array"], audio["sampling_rate"]
duration_s = librosa.get_duration(y=y, sr=sr)
# Whisper max supported duration
if self.skip_long_audios and duration_s > 30:
skipped += 1
continue

mm_content = {"audio": (y, sr)}
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
))
if skipped:
logger.warning("%d samples discarded from dataset due to" \
" their length being greater than" \
" what Whisper supports.", skipped)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
16 changes: 11 additions & 5 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser

from benchmark_dataset import (AIMODataset, BurstGPTDataset,
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
ConversationDataset, HuggingFaceDataset,
InstructCoderDataset, RandomDataset,
SampleRequest, ShareGPTDataset, SonnetDataset,
Expand Down Expand Up @@ -274,10 +274,6 @@ async def benchmark(
input_requests[0].expected_output_len, \
input_requests[0].multi_modal_data

if backend != "openai-chat" and test_mm_content is not None:
# multi-modal benchmark is only available on OpenAI Chat backend.
raise ValueError(
"Multi-modal content is only supported on 'openai-chat' backend.")
assert test_mm_content is None or isinstance(test_mm_content, dict)
test_input = RequestFuncInput(
model=model_id,
Expand Down Expand Up @@ -604,6 +600,9 @@ def main(args: argparse.Namespace):
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_class = AIMODataset
args.hf_split = "train"
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
dataset_class = ASRDataset
args.hf_split = "train"
else:
supported_datasets = set([
dataset_name for cls in HuggingFaceDataset.__subclasses__()
Expand All @@ -615,6 +614,13 @@ def main(args: argparse.Namespace):
f" from one of following: {supported_datasets}. "
"Please consider contributing if you would "
"like to add support for additional dataset formats.")

if (dataset_class.IS_MULTIMODAL and backend not in \
["openai-chat", "openai-audio"]):
# multi-modal benchmark is only available on OpenAI Chat backend.
raise ValueError(
"Multi-modal content is only supported on 'openai-chat' and " \
"'openai-audio' backend.")
input_requests = dataset_class(
dataset_path=args.dataset_path,
dataset_subset=args.hf_subset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def test_wer_correctness(model_name,
expected_wer,
n_examples=-1,
max_concurrent_request=None):
# TODO refactor to use `ASRDataset`
with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
dataset = load_hf_dataset(dataset_repo)

Expand Down