@@ -58,6 +58,7 @@ class RequestFuncInput:
58
58
output_len : int
59
59
model : str
60
60
lora_name : str
61
+ image_data : str
61
62
extra_request_body : Dict [str , Any ]
62
63
63
64
@@ -347,6 +348,11 @@ async def async_request_sglang_generate(
347
348
"logprob_start_len" : - 1 ,
348
349
** request_func_input .extra_request_body ,
349
350
}
351
+
352
+ # Add image data if available
353
+ if request_func_input .image_data :
354
+ payload ["image_data" ] = request_func_input .image_data
355
+
350
356
headers = get_auth_headers ()
351
357
352
358
output = RequestFuncOutput ()
@@ -510,6 +516,13 @@ def get_dataset(args, tokenizer):
510
516
tokenizer = tokenizer ,
511
517
args = args ,
512
518
)
519
+ elif args .dataset_name == "mmmu" :
520
+ input_requests = sample_mmmu_requests (
521
+ num_requests = args .num_prompts ,
522
+ tokenizer = tokenizer ,
523
+ fixed_output_len = args .random_output_len ,
524
+ random_sample = True ,
525
+ )
513
526
else :
514
527
raise ValueError (f"Unknown dataset: { args .dataset_name } " )
515
528
return input_requests
@@ -597,6 +610,121 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
597
610
return filename
598
611
599
612
613
def sample_mmmu_requests(
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    random_sample: bool = True,
) -> List[Tuple[str, int, int]]:
    """
    Sample requests from the MMMU dataset using HuggingFace datasets.

    Each selected example's first image is base64-encoded into a data URL and
    embedded in the prompt as ``<image>...</image>`` markup, which the caller
    later splits back out into a separate ``image_data`` field.

    Args:
        num_requests: Number of requests to sample.
        tokenizer: Tokenizer to use for token counting.
        fixed_output_len: If provided, use this fixed output length for all requests.
        random_sample: Whether to randomly sample or take the first N.

    Returns:
        List of tuples (prompt, prompt_token_len, output_token_len).

    Raises:
        ImportError: If the ``datasets`` package is not installed.
        ValueError: If the MMMU dataset cannot be downloaded/loaded.
    """
    try:
        import base64
        import io

        from datasets import load_dataset
    except ImportError as e:
        # Chain the original error so the missing-module traceback survives.
        raise ImportError("Please install datasets: pip install datasets") from e

    print("Loading MMMU dataset from HuggingFace...")

    try:
        print("Attempting to load MMMU Math dataset...")
        mmmu_dataset = load_dataset("MMMU/MMMU", "Math", split="test")
        print(
            f"Successfully loaded MMMU Math dataset from HuggingFace with {len(mmmu_dataset)} examples"
        )
    except Exception as e:
        print(f"Failed to load MMMU Math dataset: {e}")
        raise ValueError(f"Failed to load MMMU dataset: {e}") from e

    # Sample from the dataset: random subset or first-N, or everything if the
    # dataset is smaller than the request count.
    if len(mmmu_dataset) > num_requests:
        if random_sample:
            indices = random.sample(range(len(mmmu_dataset)), num_requests)
            sample_dataset = mmmu_dataset.select(indices)
        else:
            # len(mmmu_dataset) > num_requests here, so no min() is needed.
            sample_dataset = mmmu_dataset.select(range(num_requests))
    else:
        print(f"Dataset has less than {num_requests} examples, using all examples")
        sample_dataset = mmmu_dataset

    print(f"Selected {len(sample_dataset)} examples for benchmarking")

    # Create prompts
    filtered_dataset = []

    for i, example in enumerate(sample_dataset):
        try:
            # Extract image_1; examples without a usable PIL image are skipped.
            image = example.get("image_1")

            if image is not None:
                if hasattr(image, "save"):
                    # Convert RGBA images to RGB before encoding (JPEG has no alpha).
                    if image.mode == "RGBA":
                        image = image.convert("RGB")

                    # Encode image to a base64 data URL.
                    buffered = io.BytesIO()
                    image.save(buffered, format="JPEG")
                    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
                    image_path = f"data:image/jpeg;base64,{img_str}"
                else:
                    continue

                # Extract the question
                question = example.get("question")

                # Create the prompt with image, question
                prompt = f"Question: {question}\n\nAnswer: "
                prompt = tokenizer.apply_chat_template(
                    [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image_url", "image_url": {"url": image_path}},
                                {"type": "text", "text": prompt},
                            ],
                        }
                    ],
                    add_generation_prompt=True,
                    tokenize=False,
                )
                # Wrap the image payload in <image>...</image> markup; the
                # benchmark loop later strips this back out with a regex.
                prompt = f"<image>{image_path}</image>{prompt}"

                # Calculate token lengths
                # Note: This is approximate since we're not rendering the actual image tokens
                # NOTE(review): this encodes the full base64 data URL, so prompt_len
                # is dominated by image-payload tokens rather than text tokens —
                # confirm whether the text-only length (plus the 512 estimate) was
                # intended before relying on reported input lengths.
                prompt_token_ids = tokenizer.encode(prompt)
                prompt_len = (
                    len(prompt_token_ids) + 512
                )  # Add estimate for image tokens

                output_len = fixed_output_len if fixed_output_len is not None else 256

                filtered_dataset.append((prompt, prompt_len, output_len))

        except Exception as e:
            # Best-effort per-example processing: log and keep going.
            print(f"Error processing example {i}: {e}")

    print(f"\nCreated {len(filtered_dataset)} MMMU prompts")
    return filtered_dataset
726
+
727
+
600
728
def sample_sharegpt_requests (
601
729
dataset_path : str ,
602
730
num_requests : int ,
@@ -1004,6 +1132,15 @@ async def limited_request_func(request_func_input, pbar):
1004
1132
else :
1005
1133
lora_name = None
1006
1134
1135
+ if "<image>" in test_prompt :
1136
+ import re
1137
+
1138
+ image_match = re .search (r"<image>(.*?)</image>(.*)" , test_prompt )
1139
+ image_data = image_match .group (1 ) if image_match else None
1140
+ test_prompt = image_match .group (2 ) if image_match else test_prompt
1141
+ else :
1142
+ image_data = None
1143
+
1007
1144
# Create the test input once
1008
1145
test_input = RequestFuncInput (
1009
1146
model = model_id ,
@@ -1012,6 +1149,7 @@ async def limited_request_func(request_func_input, pbar):
1012
1149
prompt_len = test_prompt_len ,
1013
1150
output_len = min (test_output_len , 32 ),
1014
1151
lora_name = lora_name ,
1152
+ image_data = image_data ,
1015
1153
extra_request_body = extra_request_body ,
1016
1154
)
1017
1155
@@ -1063,13 +1201,23 @@ async def limited_request_func(request_func_input, pbar):
1063
1201
else :
1064
1202
lora_name = None
1065
1203
1204
+ if "<image>" in prompt :
1205
+ import re
1206
+
1207
+ image_match = re .search (r"<image>(.*?)</image>(.*)" , prompt )
1208
+ image_data = image_match .group (1 ) if image_match else None
1209
+ prompt = image_match .group (2 ) if image_match else prompt
1210
+ else :
1211
+ image_data = None
1212
+
1066
1213
request_func_input = RequestFuncInput (
1067
1214
model = model_id ,
1068
1215
prompt = prompt ,
1069
1216
api_url = api_url ,
1070
1217
prompt_len = prompt_len ,
1071
1218
output_len = output_len ,
1072
1219
lora_name = lora_name ,
1220
+ image_data = image_data ,
1073
1221
extra_request_body = extra_request_body ,
1074
1222
)
1075
1223
tasks .append (
@@ -1444,7 +1592,7 @@ def __call__(self, parser, namespace, values, option_string=None):
1444
1592
"--dataset-name" ,
1445
1593
type = str ,
1446
1594
default = "sharegpt" ,
1447
- choices = ["sharegpt" , "random" , "random-ids" , "generated-shared-prefix" ],
1595
+ choices = ["sharegpt" , "random" , "random-ids" , "generated-shared-prefix" , "mmmu" ],
1448
1596
help = "Name of the dataset to benchmark on." ,
1449
1597
)
1450
1598
parser .add_argument (
0 commit comments