
Commit 691686e

sundar24295s authored and tarinkk committed
Perform Batch Tokenization. (sgl-project#5141)
1 parent e6176fb commit 691686e
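
The change is exercised by the two new benchmark scripts shown below: the first sends batches of prompts to a running SGLang server's /generate endpoint and reports per-request and per-prompt latency, while the second isolates the HuggingFace tokenizer itself and compares one encode() call per prompt against a single batched tokenizer() call. A minimal sketch of that comparison (assuming a fast HuggingFace tokenizer; the model ID here is only illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
prompts = ["first prompt", "second prompt", "third prompt"]

# Sequential: one encode() call per prompt, looping in Python.
sequential_ids = [tok.encode(p) for p in prompts]

# Batched: a single call over the whole list; fast (Rust-backed) tokenizers
# can tokenize the batch without a per-prompt Python loop.
batched_ids = tok(prompts)["input_ids"]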

File tree: 4 files changed (+429 / −25 lines)
Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
import concurrent.futures
import os
import random
import time
from concurrent.futures import ProcessPoolExecutor
from statistics import mean

import requests
from tqdm import tqdm
from transformers import AutoTokenizer

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint

###############################################################################
# CONFIG
###############################################################################
ENDPOINT_URL = "http://127.0.0.1:30000"
TOKENIZER_DIR = "/models/meta-llama/Llama-3.2-3B"

# Benchmark configurations
NUM_REQUESTS = 10  # Total number of requests (each with BATCH_SIZE prompts)
NUM_TOKENS = 32000  # Tokens per prompt
BATCH_SIZE = 8  # Number of prompts per request
GEN_TOKENS = 0  # Tokens to generate per prompt


###############################################################################
# REQUEST GENERATION (in parallel)
###############################################################################
def generate_random_prompt(index, tokenizer_dir, num_tokens):
    """Generate a single random prompt with specified token count."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    vocab_size = tokenizer.vocab_size

    def generate_random_text(num_toks):
        random_token_ids = [random.randint(0, vocab_size - 1) for _ in range(num_toks)]
        return tokenizer.decode(random_token_ids, clean_up_tokenization_spaces=True)

    random_text = generate_random_text(num_tokens)
    return f"Prompt {index}: {random_text}"


def prepare_all_prompts(num_requests, batch_size, num_tokens, tokenizer_dir):
    """Generate prompts for all requests in parallel."""
    total_prompts = num_requests * batch_size
    all_prompts = [None] * total_prompts
    max_workers = min(os.cpu_count() or 1, total_prompts)

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(generate_random_prompt, i, tokenizer_dir, num_tokens)
            for i in range(total_prompts)
        ]
        for future in tqdm(
            concurrent.futures.as_completed(futures),
            total=total_prompts,
            desc="Generating prompts",
        ):
            index = futures.index(future)
            all_prompts[index] = future.result()

    batched_prompts = [
        all_prompts[i * batch_size : (i + 1) * batch_size] for i in range(num_requests)
    ]

    print(
        f"Generated {total_prompts} prompts with {num_tokens} tokens each, grouped into {num_requests} requests of {batch_size} prompts.\n"
    )
    return batched_prompts


###############################################################################
# HTTP CALLS
###############################################################################
def send_batch_request(endpoint, prompts, gen_tokens, request_id):
    """Send a batch of prompts to the /generate endpoint synchronously."""
    sampling_params = {
        "max_new_tokens": gen_tokens,
        "temperature": 0.7,
        "stop": "\n",
    }
    data = {"text": prompts, "sampling_params": sampling_params}

    start_time = time.time()
    try:
        response = requests.post(
            endpoint.base_url + "/generate", json=data, timeout=3600
        )
        if response.status_code != 200:
            error = response.json()
            raise RuntimeError(f"Request {request_id} failed: {error}")
        result = response.json()
        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
        avg_per_prompt = elapsed_time / len(prompts) if prompts else 0
        return request_id, elapsed_time, avg_per_prompt, True, len(prompts)
    except Exception as e:
        print(f"[Request] Error for request {request_id}: {e}")
        return request_id, 0, 0, False, len(prompts)


def run_benchmark(endpoint, batched_prompts, batch_size, gen_tokens):
    """Run the benchmark sequentially."""
    results = []
    num_requests = len(batched_prompts)

    # Record start time for total latency
    benchmark_start_time = time.time()

    for i, batch_prompts in enumerate(batched_prompts):
        request_id = i + 1
        assert (
            len(batch_prompts) == batch_size
        ), f"Request {request_id} should have {batch_size} prompts, got {len(batch_prompts)}"

        print(
            f"[Request] Sending request {request_id}/{num_requests} with {len(batch_prompts)} prompts at {int(time.time()*1000)}"
        )
        result = send_batch_request(endpoint, batch_prompts, gen_tokens, request_id)
        results.append(result)

    # Calculate total latency
    total_latency = (time.time() - benchmark_start_time) * 1000  # Convert to ms

    return results, total_latency


###############################################################################
# RESULTS
###############################################################################
def process_results(results, total_latency, num_requests):
    """Process and display benchmark results."""
    total_time = 0
    successful_requests = 0
    failed_requests = 0
    request_latencies = []
    per_prompt_latencies = []
    total_prompts = 0

    for request_id, elapsed_time, avg_per_prompt, success, batch_size in results:
        if success:
            successful_requests += 1
            total_prompts += batch_size
            request_latencies.append(elapsed_time)
            per_prompt_latencies.append(avg_per_prompt)
            total_time += elapsed_time / 1000  # Convert to seconds
        else:
            failed_requests += 1

    avg_request_latency = mean(request_latencies) if request_latencies else 0
    avg_per_prompt_latency = mean(per_prompt_latencies) if per_prompt_latencies else 0
    throughput = total_prompts / total_time if total_time > 0 else 0

    print("\nBenchmark Summary:")
    print(f" Total requests sent: {len(results)}")
    print(f" Total prompts sent: {total_prompts}")
    print(f" Successful requests: {successful_requests}")
    print(f" Failed requests: {failed_requests}")
    print(f" Total latency (all requests): {total_latency:.2f} ms")
    print(f" Avg per request latency: {avg_request_latency:.2f} ms")
    print(f" Avg per prompt latency: {avg_per_prompt_latency:.2f} ms")
    print(f" Throughput: {throughput:.2f} prompts/second\n")


###############################################################################
# MAIN
###############################################################################
def main():
    # Initialize endpoint
    endpoint = RuntimeEndpoint(ENDPOINT_URL)

    # Generate prompts
    batched_prompts = prepare_all_prompts(
        NUM_REQUESTS, BATCH_SIZE, NUM_TOKENS, TOKENIZER_DIR
    )

    # Flush cache before benchmark
    # endpoint.flush_cache()

    # Run benchmark
    print(
        f"Starting benchmark: NUM_TOKENS={NUM_TOKENS}, BATCH_SIZE={BATCH_SIZE}, NUM_REQUESTS={NUM_REQUESTS}\n"
    )
    results, total_latency = run_benchmark(
        endpoint, batched_prompts, BATCH_SIZE, GEN_TOKENS
    )

    # Process and display results
    process_results(results, total_latency, NUM_REQUESTS)


if __name__ == "__main__":
    random.seed(0)
    main()
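
Running this script assumes an SGLang server is already serving the model at ENDPOINT_URL (for example, one started with python -m sglang.launch_server --model-path meta-llama/Llama-3.2-3B --port 30000), and that TOKENIZER_DIR points at the same model's tokenizer so the generated prompts have the intended token counts. With GEN_TOKENS = 0, each request is dominated by tokenization and prefill rather than decoding.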
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
import random
import time
from statistics import mean

from transformers import AutoTokenizer

# CONFIG
TOKENIZER_DIR = (
    "/shared/public/sharing/fait360brew/training/models/meta-llama/Llama-3.2-3B"
)
NUM_TOKENS = 20000  # Each prompt should contain this many tokens
BATCH_SIZES = [1, 2, 4, 8]  # Test different batch sizes
NUM_RUNS = 5  # Number of runs for each batch size to get reliable measurements


def generate_random_prompts(num_prompts, num_tokens, tokenizer):
    """Generate random prompts with specified token count."""
    vocab_size = tokenizer.vocab_size
    all_prompts = []

    print(f"Generating {num_prompts} random prompts with {num_tokens} tokens each...")
    for i in range(num_prompts):
        # Generate random token IDs - this directly gives us the exact token count
        random_token_ids = [
            random.randint(0, vocab_size - 1) for _ in range(num_tokens)
        ]
        random_text = tokenizer.decode(
            random_token_ids, clean_up_tokenization_spaces=True
        )

        prompt = f"Prompt {i}: {random_text}"
        tokens = tokenizer.encode(prompt)
        print(f" Prompt {i}: {len(tokens)} tokens")
        all_prompts.append(prompt)

    return all_prompts


def benchmark_sequential_vs_batch(prompts, batch_size, tokenizer):
    """Compare sequential vs batch tokenization for a given batch size."""

    # Sequential tokenization using encode()
    sequential_times = []
    for run in range(NUM_RUNS):
        batch_prompts = prompts[:batch_size]  # Use same prompts for fair comparison

        start_time = time.time()
        for prompt in batch_prompts:
            tokens = tokenizer.encode(prompt)
        sequential_time = (time.time() - start_time) * 1000
        sequential_times.append(sequential_time)

    # Batch tokenization using tokenizer()
    batch_times = []
    for run in range(NUM_RUNS):
        batch_prompts = prompts[:batch_size]  # Use same prompts for fair comparison

        start_time = time.time()
        tokens = tokenizer(batch_prompts)
        batch_time = (time.time() - start_time) * 1000
        batch_times.append(batch_time)

    return {
        "batch_size": batch_size,
        "avg_sequential_ms": mean(sequential_times),
        "avg_batch_ms": mean(batch_times),
        "speedup_factor": (
            mean(sequential_times) / mean(batch_times) if mean(batch_times) > 0 else 0
        ),
        "sequential_runs": sequential_times,
        "batch_runs": batch_times,
    }


def main():
    print("Tokenizer Benchmark: Sequential vs Batch Processing")
    print("-" * 60)
    print(f"Tokenizer: {TOKENIZER_DIR}")
    print(f"Tokens per prompt: {NUM_TOKENS}")
    print(f"Number of runs per batch size: {NUM_RUNS}")
    print("-" * 60)

    # Load tokenizer once for all operations
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)

    # The largest batch size determines how many prompts we need
    max_batch_size = max(BATCH_SIZES)
    all_prompts = generate_random_prompts(max_batch_size, NUM_TOKENS, tokenizer)

    results = []
    print("\nRunning benchmark...")

    for batch_size in BATCH_SIZES:
        print(f"\nBenchmarking batch size: {batch_size}")
        result = benchmark_sequential_vs_batch(all_prompts, batch_size, tokenizer)
        results.append(result)

        print(f" Sequential tokenization (encode):")
        for i, run_time in enumerate(result["sequential_runs"]):
            print(f" Run {i+1}: {run_time:.2f} ms")
        print(f" Average: {result['avg_sequential_ms']:.2f} ms")

        print(f" Batch tokenization (tokenizer):")
        for i, run_time in enumerate(result["batch_runs"]):
            print(f" Run {i+1}: {run_time:.2f} ms")
        print(f" Average: {result['avg_batch_ms']:.2f} ms")

        print(f" Speedup factor: {result['speedup_factor']:.2f}x")

    print("\n" + "=" * 60)
    print("SUMMARY OF RESULTS")
    print("=" * 60)
    print(
        f"{'Batch Size':<10} {'Sequential (ms)':<18} {'Batch (ms)':<18} {'Speedup':<10}"
    )
    print("-" * 60)

    for result in results:
        print(
            f"{result['batch_size']:<10} {result['avg_sequential_ms']:.2f} ms{' ' * 8} {result['avg_batch_ms']:.2f} ms{' ' * 8} {result['speedup_factor']:.2f}x"
        )


if __name__ == "__main__":
    random.seed(0)
    main()
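
The speedup from the single tokenizer() call over the whole list comes from the Rust-backed "fast" tokenizer implementation, which can tokenize a batch in parallel (controllable via the TOKENIZERS_PARALLELISM environment variable); a slow Python tokenizer simply loops over the list internally, so the two paths should take comparable time. A quick sanity check, sketched here reusing the script's TOKENIZER_DIR:

from transformers import AutoTokenizer

# The batch speedup relies on a Rust-backed "fast" tokenizer.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
print(tokenizer.is_fast)  # expected to be True (PreTrainedTokenizerFast)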
