
Commit 9fdc6d6

Fix the lora adapter when lora path is none (#4799)
Co-authored-by: Beichen Ma <[email protected]>
1 parent 42a45df


3 files changed (+17, -14 lines)


python/sglang/srt/lora/lora_manager.py

Lines changed: 0 additions & 4 deletions

@@ -133,10 +133,6 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
-        # FIXME: Handle lora uid with None more safely
-        if cur_uids == set([None]):
-            return
-
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
         seg_lens = (
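
Why the early return can go away: once the buffer slot for the None uid is zeroed (see the mem_pool.py change below), a batch whose only LoRA uid is None can flow through the normal LoRA setup, and the adapter simply contributes nothing. A minimal sketch of that property, with made-up shapes rather than the SGLang buffers:

import torch

# Hypothetical sizes: hidden dim 16, LoRA rank 4, a single token.
x = torch.randn(16)
A = torch.zeros(4, 16)   # the zeroed A-buffer slot used for the "None" adapter
B = torch.randn(16, 4)   # whatever is left in the B buffer does not matter

# The LoRA delta is B @ (A @ x); with A zeroed it is exactly zero,
# so base_output + delta == base_output for the None adapter.
delta = B @ (A @ x)
assert torch.equal(delta, torch.zeros(16))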

python/sglang/srt/lora/mem_pool.py

Lines changed: 1 addition & 1 deletion

@@ -163,7 +163,7 @@ def load_lora_weight_to_buffer(
         if uid is None:
             for i in range(self.num_layer):
                 for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id] *= 0
+                    self.A_buffer[k][i][buffer_id] = 0
             return
 
         assert lora_adapter is not None
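
Note on the one-character change above: an in-place multiply only clears the slot if the stale values are finite, because under IEEE 754 both nan * 0 and inf * 0 evaluate to nan. Direct assignment zeroes the slot unconditionally. A small, self-contained illustration (independent of the SGLang buffers):

import torch

stale = torch.tensor([1.5, float("nan"), float("inf")])

slot_mul = stale.clone()
slot_mul *= 0            # nan * 0 -> nan, inf * 0 -> nan: the slot is not clean
print(slot_mul)          # tensor([0., nan, nan])

slot_set = stale.clone()
slot_set[:] = 0          # assignment clears the slot regardless of its old contents
print(slot_set)          # tensor([0., 0., 0.])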

test/srt/models/lora/test_lora.py

Lines changed: 16 additions & 9 deletions

@@ -96,6 +96,11 @@ def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
             srt_outputs = srt_runner.forward(
                 prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
             )
+            srt_outputs_lora_path_none = srt_runner.forward(
+                prompts,
+                max_new_tokens=max_new_tokens,
+                lora_paths=[None] * len(prompts),
+            )
 
         with HFRunner(
             base_path, torch_dtype=torch_dtype, model_type="generation"
@@ -169,18 +174,20 @@ def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
         print(f"{srt_outputs.output_strs=}")
         print(f"{hf_no_lora_outputs.output_strs=}")
         print(f"{srt_no_lora_outputs.output_strs=}")
+        print(f"{srt_outputs_lora_path_none.output_strs=}")
         for i in range(len(prompts)):
             assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
                 srt_outputs.output_strs[i].strip(" "),
                 hf_outputs.output_strs[i],
             )
-            # assert (
-            #     srt_no_lora_outputs.output_strs[i].strip(" ")
-            #     == hf_no_lora_outputs.output_strs[i]
-            # ), (
-            #     srt_no_lora_outputs.output_strs[i].strip(" "),
-            #     hf_no_lora_outputs.output_strs[i],
-            # )
+            assert (
+                srt_no_lora_outputs.output_strs[i].strip(" ")
+                == hf_no_lora_outputs.output_strs[i]
+            ), (
+                srt_no_lora_outputs.output_strs[i].strip(" "),
+                hf_no_lora_outputs.output_strs[i],
+            )
+        assert srt_outputs_lora_path_none == srt_no_lora_outputs
 
     def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
         print("=================== testing serving =======================")
@@ -257,7 +264,7 @@ def base_inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens
             srt_no_lora_logprobs = torch.Tensor(
                 srt_no_lora_outputs.top_input_logprobs[i]
             )
-            srt_logprobs = torch.uensor(srt_outputs.top_input_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
             print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs)))
 
             print(f"{srt_no_lora_outputs.output_strs=}")
@@ -280,7 +287,7 @@ def test_all(self):
         tp_size = 1
         max_new_tokens = 32
         self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
-        # self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
+        self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
         # self.base_inference(
         #     PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
         # )
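
The new test compares the two forward results as whole objects (srt_outputs_lora_path_none == srt_no_lora_outputs), which only acts as a field-by-field check if the runners' return type defines equality, for example as a dataclass. A minimal sketch of that pattern; ModelOutput and its field name are stand-ins, not necessarily the real return type:

from dataclasses import dataclass
from typing import List

@dataclass
class ModelOutput:
    # Stand-in for the test runners' return type; the field name is an assumption.
    output_strs: List[str]

# dataclasses auto-generate an __eq__ that compares all fields,
# so two outputs with identical strings compare equal.
a = ModelOutput(output_strs=["hello world"])
b = ModelOutput(output_strs=["hello world"])
assert a == b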
