From aa6b750af2a3e943a522297edd94588c5d6f8ec5 Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 25 Apr 2025 18:58:20 +0000 Subject: [PATCH 1/6] add mtbench to serving bench --- benchmarks/benchmark_dataset.py | 45 +++++++++++++++++++++++++++++++++ benchmarks/benchmark_serving.py | 6 ++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index ccbc6c022f1..2c04a4539a1 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -771,6 +771,51 @@ def sample(self, return sampled_requests +# ----------------------------------------------------------------------------- +# MT-Bench Dataset Implementation +# ----------------------------------------------------------------------------- + + +class MTBenchDataset(HuggingFaceDataset): + """ + MT-Bench Dataset. + https://huggingface.co/datasets/philschmid/mt-bench + + We create a single turn dataset for MT-Bench. + This is similar to Spec decoding benchmark setup in vLLM + https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 # noqa: E501 + """ + + DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM + SUPPORTED_DATASET_PATHS = { + "philschmid/mt-bench", + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item['turns'][0] + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + # ----------------------------------------------------------------------------- # AIMO Dataset Implementation # ----------------------------------------------------------------------------- diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index da124e1a81b..901f803e150 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -52,7 +52,8 @@ from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset, ConversationDataset, HuggingFaceDataset, - InstructCoderDataset, RandomDataset, + InstructCoderDataset, MTBenchDataset, + RandomDataset, SampleRequest, ShareGPTDataset, SonnetDataset, VisionArenaDataset) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json @@ -595,6 +596,9 @@ def main(args: argparse.Namespace): elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: dataset_class = InstructCoderDataset args.hf_split = "train" + elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: + dataset_class = MTBenchDataset + args.hf_split = "train" elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: dataset_class = ConversationDataset elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: From b527e39dd0424406099368f50e9d648154d25e54 Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 25 Apr 2025 21:21:29 +0000 Subject: [PATCH 2/6] add chat template for EAGLE --- benchmarks/benchmark_dataset.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 2c04a4539a1..12d295b15e0 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -805,6 +805,13 @@ def sample(self, if len(sampled_requests) >= num_requests: break prompt = item['turns'][0] + + # apply template + prompt=tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], tokenize=False) + prompt_len = len(tokenizer(prompt).input_ids) sampled_requests.append( SampleRequest( From fe706ae59e97ef71560e0fb1f00cbcc2a7063458 Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 25 Apr 2025 21:23:58 +0000 Subject: [PATCH 3/6] add add_generation_prompt=True --- benchmarks/benchmark_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 12d295b15e0..aabae2c8218 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -810,7 +810,7 @@ def sample(self, prompt=tokenizer.apply_chat_template([{ "role": "user", "content": prompt - }], tokenize=False) + }], add_generation_prompt=True, tokenize=False) prompt_len = len(tokenizer(prompt).input_ids) sampled_requests.append( From 07bfe6674d77a3e5b08cfaeadb4b859933e5b100 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 28 Apr 2025 11:33:23 -0400 Subject: [PATCH 4/6] Update benchmarks/benchmark_dataset.py Co-authored-by: Woosuk Kwon --- benchmarks/benchmark_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index aabae2c8218..d08cad95df3 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -783,7 +783,7 @@ class MTBenchDataset(HuggingFaceDataset): We create a single turn dataset for MT-Bench. This is similar to Spec decoding benchmark setup in vLLM - https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 # noqa: E501 + https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 # noqa: E501 """ DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM From b4c3793551cfe44580a512da5490014f61ce19d1 Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 28 Apr 2025 22:03:59 +0000 Subject: [PATCH 5/6] fix linter --- benchmarks/benchmark_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index d08cad95df3..7d38ea68dd5 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -783,8 +783,8 @@ class MTBenchDataset(HuggingFaceDataset): We create a single turn dataset for MT-Bench. This is similar to Spec decoding benchmark setup in vLLM - https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 # noqa: E501 - """ + https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 + """ # noqa: E501 DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM SUPPORTED_DATASET_PATHS = { From 8a43105976e261e0a14c81d4a81f59921c552093 Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 28 Apr 2025 22:19:04 +0000 Subject: [PATCH 6/6] pre-commit --- benchmarks/benchmark_dataset.py | 8 +++++--- benchmarks/benchmark_serving.py | 7 +++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 7d38ea68dd5..9c614baf1f0 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -805,12 +805,14 @@ def sample(self, if len(sampled_requests) >= num_requests: break prompt = item['turns'][0] - + # apply template - prompt=tokenizer.apply_chat_template([{ + prompt = tokenizer.apply_chat_template([{ "role": "user", "content": prompt - }], add_generation_prompt=True, tokenize=False) + }], + add_generation_prompt=True, + tokenize=False) prompt_len = len(tokenizer(prompt).input_ids) sampled_requests.append( diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 901f803e150..c236d64261d 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -52,10 +52,9 @@ from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset, ConversationDataset, HuggingFaceDataset, - InstructCoderDataset, MTBenchDataset, - RandomDataset, - SampleRequest, ShareGPTDataset, SonnetDataset, - VisionArenaDataset) + InstructCoderDataset, MTBenchDataset, + RandomDataset, SampleRequest, ShareGPTDataset, + SonnetDataset, VisionArenaDataset) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json MILLISECONDS_TO_SECONDS_CONVERSION = 1000