Skip to content

Commit f75a111

Browse files
NickLucche authored and dbyoung18 committed
[Misc] Benchmarks for audio models (vllm-project#16505)
Signed-off-by: NickLucche <[email protected]>
1 parent f9cf3ae commit f75a111

File tree

4 files changed

+199
-5
lines changed

4 files changed

+199
-5
lines changed

benchmarks/backend_request_func.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import io
34
import json
45
import os
56
import sys
@@ -32,6 +33,7 @@ class RequestFuncInput:
3233
extra_body: Optional[dict] = None
3334
multi_modal_content: Optional[dict] = None
3435
ignore_eos: bool = False
36+
language: Optional[str] = None
3537

3638

3739
@dataclass
@@ -436,6 +438,110 @@ async def async_request_openai_chat_completions(
436438
return output
437439

438440

441+
async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style audio endpoint.

    The audio found under ``request_func_input.multi_modal_content['audio']``
    is encoded as an in-memory WAV file and uploaded as multipart/form-data
    together with the (flattened) sampling parameters. The streamed response
    is parsed chunk by chunk to collect the generated text and timings.

    Args:
        request_func_input: Request description. ``api_url`` must end with
            ``transcriptions`` or ``translations``. ``extra_body``, when set,
            overrides/extends the default form fields.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. On HTTP 200 it carries ``generated_text``,
        ``ttft``, ``itl``, ``latency`` and, when the server reports usage,
        ``output_tokens``. Otherwise ``success`` is False and ``error`` holds
        the HTTP reason or the formatted traceback.
    """
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile

    api_url = request_func_input.api_url
    # The message is parenthesized so that both string fragments are part of
    # the assert; a trailing bare string on its own line would be a no-op.
    assert api_url.endswith(("transcriptions", "translations")), (
        "OpenAI Audio API URL must end with 'transcriptions' "
        "or 'translations'.")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model":
            request_func_input.model_name or request_func_input.model,
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            # Honor the per-request language when given; "en" stays the
            # default so existing callers behave as before.
            "language": request_func_input.language or "en",
            # Flattened due to multipart/form-data
            "stream_include_usage": True,
            "stream_continuous_usage_stats": True,
        }
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        # Send audio file
        def to_bytes(y, sr):
            # Serialize the samples as WAV into memory and rewind so the
            # upload starts from the first byte.
            buffer = io.BytesIO()
            soundfile.write(buffer, y, sr, format="WAV")
            buffer.seek(0)
            return buffer

        with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
            form = aiohttp.FormData()
            form.add_field('file', f, content_type='audio/wav')
            # Form fields must be strings, hence the str() on every value.
            for key, value in payload.items():
                form.add_field(key, str(value))

            output = RequestFuncOutput()
            output.prompt_len = request_func_input.prompt_len

            generated_text = ""
            ttft = 0.0
            st = time.perf_counter()
            most_recent_timestamp = st
            try:
                async with session.post(url=api_url,
                                        data=form,
                                        headers=headers) as response:
                    if response.status == 200:
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            chunk = chunk_bytes.decode("utf-8").removeprefix(
                                "data: ")
                            if chunk != "[DONE]":
                                timestamp = time.perf_counter()
                                data = json.loads(chunk)

                                if choices := data.get("choices"):
                                    content = choices[0]["delta"].get(
                                        "content")
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft

                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp)

                                    generated_text += content or ""
                                elif usage := data.get("usage"):
                                    output.output_tokens = usage.get(
                                        "completion_tokens")

                                most_recent_timestamp = timestamp

                        output.generated_text = generated_text
                        output.success = True
                        output.latency = most_recent_timestamp - st
                    else:
                        output.error = response.reason or ""
                        output.success = False
            except Exception:
                # Record the failure on the output instead of raising.
                output.success = False
                exc_info = sys.exc_info()
                output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output
543+
544+
439545
def get_model(pretrained_model_name_or_path: str) -> str:
440546
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
441547
from modelscope import snapshot_download
@@ -493,6 +599,7 @@ def get_tokenizer(
493599
"deepspeed-mii": async_request_deepspeed_mii,
494600
"openai": async_request_openai_completions,
495601
"openai-chat": async_request_openai_chat_completions,
602+
"openai-audio": async_request_openai_audio,
496603
"tensorrt-llm": async_request_trt_llm,
497604
"scalellm": async_request_openai_completions,
498605
"sglang": async_request_openai_completions,

benchmarks/benchmark_dataset.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class SampleRequest:
6464

6565
class BenchmarkDataset(ABC):
6666
DEFAULT_SEED = 0
67+
IS_MULTIMODAL = False
6768

6869
def __init__(
6970
self,
@@ -621,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset):
621622
SUPPORTED_DATASET_PATHS = {
622623
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
623624
}
625+
IS_MULTIMODAL = True
624626

625627
def sample(self,
626628
tokenizer: PreTrainedTokenizerBase,
@@ -685,6 +687,7 @@ class VisionArenaDataset(HuggingFaceDataset):
685687
"lmarena-ai/vision-arena-bench-v0.1":
686688
lambda x: x["turns"][0][0]["content"]
687689
}
690+
IS_MULTIMODAL = True
688691

689692
def sample(
690693
self,
@@ -815,3 +818,80 @@ def sample(self,
815818
))
816819
self.maybe_oversample_requests(sampled_requests, num_requests)
817820
return sampled_requests
821+
822+
823+
# -----------------------------------------------------------------------------
824+
# ASR Dataset Implementation
825+
# -----------------------------------------------------------------------------
826+
827+
828+
class ASRDataset(HuggingFaceDataset):
    """
    Dataset class for processing an ASR dataset for transcription.
    Tested on the following set:

    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | Dataset        | Domain                                 | Speaking Style           | hf-subset                   |
    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | TED-LIUM       | TED talks                              | Oratory                  | release1, release2, release3|
    |                |                                        |                          | release3-speaker-adaptation |
    | VoxPopuli      | European Parliament                    | Oratory                  | en, de, it, fr,  ...        |
    | LibriSpeech    | Audiobook                              | Narrated                 | "LIUM/tedlium"              |
    | GigaSpeech     | Audiobook, podcast, YouTube            | Narrated, spontaneous    | xs, s, m, l, xl, dev, test  |
    | SPGISpeech     | Financial meetings                     | Oratory, spontaneous     | S, M, L, dev, test          |
    | AMI            | Meetings                               | Spontaneous              | ihm, sdm                    |
    +----------------+----------------------------------------+--------------------------+-----------------------------+

    """  # noqa: E501
    # NOTE(review): the hf-subset listed for LibriSpeech above ("LIUM/tedlium")
    # is one of the dataset paths below rather than a subset name - confirm.
    SUPPORTED_DATASET_PATHS = {
        "openslr/librispeech_asr",
        "facebook/voxpopuli",
        "LIUM/tedlium",
        "edinburghcstr/ami",
        "speechcolab/gigaspeech",
        "kensho/spgispeech",
    }

    DEFAULT_OUTPUT_LEN = 128
    IS_MULTIMODAL = True

    # TODO Whisper-specific. Abstract interface when more models are supported.
    TRANSCRIPTION_PREAMBLE = ("<|startoftranscript|><|en|><|transcribe|>"
                              "<|notimestamps|>")
    # When True, clips longer than 30 seconds are dropped while sampling.
    skip_long_audios: bool = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` transcription requests from the data.

        Every request uses the fixed transcription preamble as its prompt and
        carries the clip as ``{"audio": (samples, sampling_rate)}``. When
        ``output_len`` is None, ``DEFAULT_OUTPUT_LEN`` is used.
        """
        import librosa

        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN

        preamble = ASRDataset.TRANSCRIPTION_PREAMBLE
        preamble_len = len(tokenizer(preamble).input_ids)

        requests = []
        num_skipped = 0
        for row in self.data:
            if len(requests) >= num_requests:
                break

            # Read the "audio" entry once and reuse it for both fields.
            clip = row["audio"]
            samples = clip["array"]
            sample_rate = clip["sampling_rate"]
            duration_s = librosa.get_duration(y=samples, sr=sample_rate)

            # Whisper max supported duration
            if self.skip_long_audios and duration_s > 30:
                num_skipped += 1
                continue

            request = SampleRequest(
                prompt=preamble,
                prompt_len=preamble_len,
                expected_output_len=output_len,
                multi_modal_data={"audio": (samples, sample_rate)},
            )
            requests.append(request)

        if num_skipped:
            logger.warning(
                "%d samples discarded from dataset due to"
                " their length being greater than"
                " what Whisper supports.", num_skipped)

        self.maybe_oversample_requests(requests, num_requests)
        return requests

benchmarks/benchmark_serving.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
except ImportError:
5151
from argparse import ArgumentParser as FlexibleArgumentParser
5252

53-
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
53+
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
5454
ConversationDataset, HuggingFaceDataset,
5555
InstructCoderDataset, RandomDataset,
5656
SampleRequest, ShareGPTDataset, SonnetDataset,
@@ -274,10 +274,6 @@ async def benchmark(
274274
input_requests[0].expected_output_len, \
275275
input_requests[0].multi_modal_data
276276

277-
if backend != "openai-chat" and test_mm_content is not None:
278-
# multi-modal benchmark is only available on OpenAI Chat backend.
279-
raise ValueError(
280-
"Multi-modal content is only supported on 'openai-chat' backend.")
281277
assert test_mm_content is None or isinstance(test_mm_content, dict)
282278
test_input = RequestFuncInput(
283279
model=model_id,
@@ -604,6 +600,9 @@ def main(args: argparse.Namespace):
604600
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
605601
dataset_class = AIMODataset
606602
args.hf_split = "train"
603+
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
604+
dataset_class = ASRDataset
605+
args.hf_split = "train"
607606
else:
608607
supported_datasets = set([
609608
dataset_name for cls in HuggingFaceDataset.__subclasses__()
@@ -615,6 +614,13 @@ def main(args: argparse.Namespace):
615614
f" from one of following: {supported_datasets}. "
616615
"Please consider contributing if you would "
617616
"like to add support for additional dataset formats.")
617+
618+
if (dataset_class.IS_MULTIMODAL and backend not in \
619+
["openai-chat", "openai-audio"]):
620+
# multi-modal benchmark is only available on the OpenAI Chat and OpenAI Audio backends.
621+
raise ValueError(
622+
"Multi-modal content is only supported on 'openai-chat' and " \
623+
"'openai-audio' backend.")
618624
input_requests = dataset_class(
619625
dataset_path=args.dataset_path,
620626
dataset_subset=args.hf_subset,

tests/entrypoints/openai/correctness/test_transcription_api_correctness.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def test_wer_correctness(model_name,
150150
expected_wer,
151151
n_examples=-1,
152152
max_concurrent_request=None):
153+
# TODO refactor to use `ASRDataset`
153154
with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
154155
dataset = load_hf_dataset(dataset_repo)
155156

0 commit comments

Comments
 (0)