Commit 4e8e983

zhyncs authored and xwu-intel committed
feat: support loogle eval (sgl-project#6190)
1 parent 26651d6 commit 4e8e983

3 files changed: +158 -1 lines changed

python/sglang/README.md

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,6 @@
 # Code Structures

+- `eval`: The evaluation utilities.
 - `lang`: The frontend language.
 - `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
 - `test`: The test utilities.
@@ -11,6 +12,5 @@
 - `check_env.py`: Check the environment variables and dependencies.
 - `global_config.py`: The global configs and constants.
 - `launch_server.py`: The entry point for launching the local server.
-- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
 - `utils.py`: Common utilities.
 - `version.py`: Version info.
python/sglang/llama3_eval.py → python/sglang/eval/llama3_eval.py

File renamed without changes.

python/sglang/eval/loogle_eval.py

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
import argparse
import asyncio
import os
import pickle
from pathlib import Path
from typing import List

import openai
import torch
from bert_score import BERTScorer
from datasets import load_dataset
from tqdm import tqdm


def get_client(api_url: str) -> openai.AsyncOpenAI:
    # The OpenAI client requires an API key; local servers accept a placeholder.
    if os.getenv("OPENAI_API_KEY") is None:
        os.environ["OPENAI_API_KEY"] = "EMPTY"
    return openai.AsyncOpenAI(base_url=api_url)


def get_dataset():
    return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")


async def fetch_response(
    client: openai.AsyncOpenAI,
    context: str,
    question: str,
    semaphore: asyncio.Semaphore,
    index: int,
    model: str,
    output_dir: Path,
):
    # Responses are cached one pickle per example, so interrupted runs resume
    # by skipping indices that already have a cache file.
    output_file = output_dir / f"response_{index}.pkl"
    if output_file.exists():
        return

    prompt = (
        "Please answer the question based on the long texts below.\n"
        f"{context}\n"
        f"Question: {question}\n"
        "Answer:"
    )
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    async with semaphore:
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                max_tokens=512,
            )
        except openai.BadRequestError as e:
            # Persist the error so analyse() can skip this example later.
            with open(output_file, "wb") as f:
                pickle.dump({"error": str(e)}, f)
            return

    with open(output_file, "wb") as f:
        pickle.dump(response, f)


async def benchmark(args):
    dataset = get_dataset()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    client = get_client(args.api_url)
    semaphore = asyncio.Semaphore(args.max_concurrency)

    tasks: List[asyncio.Task] = []
    for idx, ex in enumerate(dataset):
        tasks.append(
            asyncio.create_task(
                fetch_response(
                    client,
                    ex["context"],
                    ex["question"],
                    semaphore,
                    idx,
                    args.model,
                    output_dir,
                )
            )
        )

    for task in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
    ):
        await task


def analyse(args):
    dataset = get_dataset()
    output_dir = Path(args.output_dir)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = BERTScorer(lang="en", device=device)

    hyps: List[str] = []
    refs: List[str] = []
    for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
        pkl_file = output_dir / f"response_{idx}.pkl"
        if not pkl_file.exists():
            raise FileNotFoundError(pkl_file)

        with open(pkl_file, "rb") as f:
            response = pickle.load(f)
        # Requests that failed with BadRequestError were cached as error dicts.
        if isinstance(response, dict) and "error" in response:
            continue

        hyps.append(response.choices[0].message.content.strip())
        refs.append(ex["answer"])

    if not hyps:
        print("No valid responses to score!")
        return

    # Score in batches to bound GPU memory usage.
    batch_size = 64
    all_f1: List[float] = []
    for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
        h_batch = hyps[i : i + batch_size]
        r_batch = refs[i : i + batch_size]
        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
        all_f1.extend([float(x) for x in f1_scores])

    avg = sum(all_f1) / len(all_f1)
    print(f"Average BERTScore (F1): {avg:.2%}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run benchmark and evaluation in one go."
    )
    parser.add_argument(
        "--api-url",
        default="http://127.0.0.1:30000/v1",
        help="OpenAI-compatible API base URL",
    )
    parser.add_argument(
        "--model",
        default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        help="Model name or ID",
    )
    parser.add_argument(
        "--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
    )
    parser.add_argument(
        "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
    )
    args = parser.parse_args()

    asyncio.run(benchmark(args))
    analyse(args)
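
The script is deliberately two-phase: `benchmark()` fans out one request per LooGLE example through a semaphore-bounded `AsyncOpenAI` client and pickles each response, then `analyse()` scores the cached answers against the gold answers with BERTScore F1. As a sanity check of the metric, here is a minimal sketch of what a single `scorer.score` call computes; the strings are toy placeholders, not drawn from the dataset:

# Minimal sketch: BERTScore F1 for one toy hypothesis/reference pair.
# The strings are illustrative placeholders, not LooGLE data.
from bert_score import BERTScorer

scorer = BERTScorer(lang="en")
_, _, f1 = scorer.score(
    ["Paris has long been the capital of France."],  # hypothesis (model answer)
    ["The capital of France is Paris."],             # reference (gold answer)
)
print(f"F1: {float(f1[0]):.2%}")

Because BERTScore matches contextual token embeddings rather than exact n-grams, paraphrased answers like the pair above still score highly, which suits LooGLE's free-form long-dependency QA better than exact match would.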
