@@ -490,14 +490,15 @@ def get_dataset(args, tokenizer):
490
490
prompt_suffix = args .prompt_suffix ,
491
491
apply_chat_template = args .apply_chat_template ,
492
492
)
493
- elif args .dataset_name == "random" :
493
+ elif args .dataset_name . startswith ( "random" ) :
494
494
input_requests = sample_random_requests (
495
495
input_len = args .random_input_len ,
496
496
output_len = args .random_output_len ,
497
497
num_prompts = args .num_prompts ,
498
498
range_ratio = args .random_range_ratio ,
499
499
tokenizer = tokenizer ,
500
500
dataset_path = args .dataset_path ,
501
+ random_sample = args .dataset_name == "random" ,
501
502
)
502
503
elif args .dataset_name == "generated-shared-prefix" :
503
504
input_requests = sample_generated_shared_prefix_requests (
@@ -687,6 +688,7 @@ def sample_random_requests(
687
688
range_ratio : float ,
688
689
tokenizer : PreTrainedTokenizerBase ,
689
690
dataset_path : str ,
691
+ random_sample : bool = True ,
690
692
) -> List [Tuple [str , int , int ]]:
691
693
692
694
input_lens = np .random .randint (
@@ -700,11 +702,15 @@ def sample_random_requests(
700
702
size = num_prompts ,
701
703
)
702
704
703
- if True :
705
+ if random_sample :
704
706
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
705
707
706
708
# Download sharegpt if necessary
707
709
if not os .path .isfile (dataset_path ):
710
+ print (
711
+ "If you do not want to randomly sample from a dataset,"
712
+ " please use --dataset-name random-ids."
713
+ )
708
714
dataset_path = download_and_cache_file (SHAREGPT_URL )
709
715
710
716
# Load the dataset.
@@ -1223,7 +1229,7 @@ async def limited_request_func(request_func_input, pbar):
1223
1229
output_file_name = args .output_file
1224
1230
else :
1225
1231
now = datetime .now ().strftime ("%m%d" )
1226
- if args .dataset_name == "random" :
1232
+ if args .dataset_name . startswith ( "random" ) :
1227
1233
output_file_name = f"{ args .backend } _{ now } _{ args .num_prompts } _{ args .random_input_len } _{ args .random_output_len } .jsonl"
1228
1234
else :
1229
1235
output_file_name = f"{ args .backend } _{ now } _{ args .num_prompts } _sharegpt.jsonl"
@@ -1442,7 +1448,7 @@ def __call__(self, parser, namespace, values, option_string=None):
1442
1448
"--dataset-name" ,
1443
1449
type = str ,
1444
1450
default = "sharegpt" ,
1445
- choices = ["sharegpt" , "random" , "generated-shared-prefix" ],
1451
+ choices = ["sharegpt" , "random" , "random-ids" , " generated-shared-prefix" ],
1446
1452
help = "Name of the dataset to benchmark on." ,
1447
1453
)
1448
1454
parser .add_argument (
0 commit comments