Skip to content

Commit 3171b05

Browse files
Qiaolin-Yulifuhuang
authored andcommitted
Fix flaky issues of lora and add multi batch tests (sgl-project#5957)
1 parent 0630548 commit 3171b05

File tree

4 files changed

+205
-96
lines changed

4 files changed

+205
-96
lines changed

python/sglang/srt/lora/lora_manager.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -156,18 +156,15 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
156156
# set up batch info shared by all lora modules
157157
bs = forward_batch.batch_size
158158

159-
if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
160-
# Do in-place updates when CUDA graph is enabled. Note that
161-
# if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
162-
# will also use these preallocated buffers, no matter whether
163-
# the batch can use CUDA graph or not.
159+
if (
160+
hasattr(self, "max_bs_in_cuda_graph")
161+
and bs <= self.max_bs_in_cuda_graph
162+
and forward_batch.forward_mode.is_cuda_graph()
163+
):
164+
# Do in-place updates when CUDA graph is enabled and the batch forward mode
165+
# could use CUDA graph.
164166
self.cuda_graph_batch_info.bs = bs
165-
if forward_batch.forward_mode.is_extend():
166-
self.cuda_graph_batch_info.seg_lens[:bs].copy_(
167-
forward_batch.extend_seq_lens
168-
)
169-
else:
170-
self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
167+
self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
171168
torch.cumsum(
172169
self.cuda_graph_batch_info.seg_lens[:bs],
173170
dim=0,
@@ -201,10 +198,10 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
201198
max_len = int(torch.max(seg_lens))
202199
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
203200

204-
lora_ranks = torch.empty(
201+
lora_ranks = torch.zeros(
205202
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
206203
)
207-
scalings = torch.empty(
204+
scalings = torch.zeros(
208205
(self.max_loras_per_batch,), dtype=torch.float, device="cuda"
209206
)
210207
for i, lora_path in enumerate(forward_batch.lora_paths):

test/srt/models/lora/test_lora.py

Lines changed: 158 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -13,65 +13,176 @@
1313
# ==============================================================================
1414

1515
import multiprocessing as mp
16+
import os
17+
import random
1618
import unittest
19+
from typing import List
1720

18-
from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase, run_lora_test_by_batch
21+
from utils import (
22+
ALL_OTHER_MULTI_LORA_MODELS,
23+
CI_MULTI_LORA_MODELS,
24+
TORCH_DTYPES,
25+
LoRAModelCase,
26+
)
1927

20-
from sglang.test.test_utils import CustomTestCase
28+
from sglang.test.runners import HFRunner, SRTRunner
29+
from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
2130

22-
PROMPTS = [
31+
TEST_MULTIPLE_BATCH_PROMPTS = [
2332
"""
24-
### Instruction:
25-
Write a poem about the transformers Python library.
26-
Mention the word "large language models" in that poem.
27-
### Response:
28-
The Transformers are large language models,
29-
They're used to make predictions on text.
30-
""",
33+
### Instruction:
34+
Tell me about llamas and alpacas
35+
### Response:
36+
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
37+
### Question 2:
38+
What do you know about llamas?
39+
### Answer:
40+
""",
41+
"""
42+
### Instruction:
43+
Write a poem about the transformers Python library.
44+
Mention the word "large language models" in that poem.
45+
### Response:
46+
The Transformers are large language models,
47+
They're used to make predictions on text.
48+
""",
3149
"AI is a field of computer science focused on",
32-
]
33-
34-
LORA_MODELS_WITH_NONE = [
35-
LoRAModelCase(
36-
base="meta-llama/Llama-3.1-8B-Instruct",
37-
adaptors=[
38-
LoRAAdaptor(
39-
name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
40-
),
41-
LoRAAdaptor(
42-
name=None,
43-
),
44-
],
45-
max_loras_per_batch=2,
46-
),
47-
LoRAModelCase(
48-
base="meta-llama/Llama-3.1-8B-Instruct",
49-
adaptors=[
50-
LoRAAdaptor(
51-
name=None,
52-
),
53-
LoRAAdaptor(
54-
name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
55-
),
56-
],
57-
max_loras_per_batch=2,
58-
),
50+
"Computer science is the study of",
51+
"Write a short story.",
52+
"What are the main components of a computer?",
5953
]
6054

6155

6256
class TestLoRA(CustomTestCase):
63-
def test_lora_batch_with_none(self):
64-
for model_case in LORA_MODELS_WITH_NONE:
65-
prompts = PROMPTS
57+
58+
def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
59+
for model_case in model_cases:
6660
for torch_dtype in TORCH_DTYPES:
67-
run_lora_test_by_batch(
68-
prompts,
69-
model_case,
70-
torch_dtype,
71-
max_new_tokens=32,
72-
backend="triton",
73-
test_tag="test_lora_batch_with_none",
61+
max_new_tokens = 32
62+
backend = "triton"
63+
base_path = model_case.base
64+
lora_adapter_paths = [a.name for a in model_case.adaptors]
65+
assert len(lora_adapter_paths) >= 2
66+
67+
batches = [
68+
(
69+
[
70+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
71+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
72+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
73+
],
74+
[
75+
None,
76+
lora_adapter_paths[0],
77+
lora_adapter_paths[1],
78+
],
79+
),
80+
(
81+
[
82+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
83+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
84+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
85+
],
86+
[
87+
lora_adapter_paths[0],
88+
None,
89+
lora_adapter_paths[1],
90+
],
91+
),
92+
(
93+
[
94+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
95+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
96+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
97+
],
98+
[lora_adapter_paths[0], lora_adapter_paths[1], None],
99+
),
100+
(
101+
[
102+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
103+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
104+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
105+
],
106+
[None, lora_adapter_paths[1], None],
107+
),
108+
(
109+
[
110+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
111+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
112+
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
113+
],
114+
[None, None, None],
115+
),
116+
]
117+
118+
print(
119+
f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
120+
)
121+
122+
# Initialize runners
123+
srt_runner = SRTRunner(
124+
base_path,
125+
torch_dtype=torch_dtype,
126+
model_type="generation",
127+
lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
128+
max_loras_per_batch=len(lora_adapter_paths) + 1,
129+
lora_backend=backend,
130+
disable_radix_cache=True,
74131
)
132+
hf_runner = HFRunner(
133+
base_path, torch_dtype=torch_dtype, model_type="generation"
134+
)
135+
136+
with srt_runner, hf_runner:
137+
for i, (prompts, lora_paths) in enumerate(batches):
138+
print(
139+
f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}"
140+
)
141+
142+
srt_outputs = srt_runner.batch_forward(
143+
prompts,
144+
max_new_tokens=max_new_tokens,
145+
lora_paths=lora_paths,
146+
)
147+
148+
hf_outputs = hf_runner.forward(
149+
prompts,
150+
max_new_tokens=max_new_tokens,
151+
lora_paths=lora_paths,
152+
)
153+
154+
print("SRT outputs:", [s for s in srt_outputs.output_strs])
155+
print("HF outputs:", [s for s in hf_outputs.output_strs])
156+
157+
for srt_out, hf_out in zip(
158+
srt_outputs.output_strs, hf_outputs.output_strs
159+
):
160+
srt_str = srt_out.strip()
161+
hf_str = hf_out.strip()
162+
rouge_tol = model_case.rouge_l_tolerance
163+
rouge_score = calculate_rouge_l([srt_str], [hf_str])[0]
164+
if rouge_score < rouge_tol:
165+
raise AssertionError(
166+
f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
167+
f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
168+
)
169+
170+
print(f"--- Batch {i+1} Comparison Passed --- ")
171+
172+
def test_ci_lora_models(self):
173+
self._run_lora_multiple_batch_on_model_cases(CI_MULTI_LORA_MODELS)
174+
175+
def test_all_lora_models(self):
176+
if is_in_ci():
177+
return
178+
179+
filtered_models = []
180+
for model_case in ALL_OTHER_MULTI_LORA_MODELS:
181+
if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
182+
continue
183+
filtered_models.append(model_case)
184+
185+
self._run_lora_multiple_batch_on_model_cases(filtered_models)
75186

76187

77188
if __name__ == "__main__":

test/srt/models/lora/test_multi_lora_backend.py

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,50 +18,16 @@
1818
from typing import List
1919

2020
from utils import (
21+
ALL_OTHER_MULTI_LORA_MODELS,
2122
BACKENDS,
23+
CI_MULTI_LORA_MODELS,
2224
TORCH_DTYPES,
23-
LoRAAdaptor,
2425
LoRAModelCase,
2526
run_lora_test_one_by_one,
2627
)
2728

2829
from sglang.test.test_utils import CustomTestCase, is_in_ci
2930

30-
CI_MULTI_LORA_MODELS = [
31-
# multi-rank case
32-
LoRAModelCase(
33-
base="meta-llama/Llama-2-7b-hf",
34-
adaptors=[
35-
LoRAAdaptor(
36-
name="winddude/wizardLM-LlaMA-LoRA-7B",
37-
prefill_tolerance=1e-1,
38-
),
39-
LoRAAdaptor(
40-
name="RuterNorway/Llama-2-7b-chat-norwegian-LoRa",
41-
prefill_tolerance=3e-1,
42-
),
43-
],
44-
max_loras_per_batch=2,
45-
),
46-
]
47-
48-
ALL_OTHER_MULTI_LORA_MODELS = [
49-
LoRAModelCase(
50-
base="meta-llama/Llama-3.1-8B-Instruct",
51-
adaptors=[
52-
LoRAAdaptor(
53-
name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
54-
prefill_tolerance=1e-1,
55-
),
56-
LoRAAdaptor(
57-
name="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
58-
prefill_tolerance=1e-1,
59-
),
60-
],
61-
max_loras_per_batch=2,
62-
),
63-
]
64-
6531
# All prompts are used at once in a batch.
6632
PROMPTS = [
6733
"AI is a field of computer science focused on",

test/srt/models/lora/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,41 @@ def __post_init__(self):
9393
),
9494
]
9595

96+
CI_MULTI_LORA_MODELS = [
97+
# multi-rank case
98+
LoRAModelCase(
99+
base="meta-llama/Llama-2-7b-hf",
100+
adaptors=[
101+
LoRAAdaptor(
102+
name="winddude/wizardLM-LlaMA-LoRA-7B",
103+
prefill_tolerance=1e-1,
104+
),
105+
LoRAAdaptor(
106+
name="RuterNorway/Llama-2-7b-chat-norwegian-LoRa",
107+
prefill_tolerance=3e-1,
108+
),
109+
],
110+
max_loras_per_batch=2,
111+
),
112+
]
113+
114+
ALL_OTHER_MULTI_LORA_MODELS = [
115+
LoRAModelCase(
116+
base="meta-llama/Llama-3.1-8B-Instruct",
117+
adaptors=[
118+
LoRAAdaptor(
119+
name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
120+
prefill_tolerance=1e-1,
121+
),
122+
LoRAAdaptor(
123+
name="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
124+
prefill_tolerance=1e-1,
125+
),
126+
],
127+
max_loras_per_batch=2,
128+
),
129+
]
130+
96131

97132
def run_lora_test_one_by_one(
98133
prompts: List[str],

0 commit comments

Comments
 (0)