Commit 07f23dc

lifuhuang authored and xwu-intel committed
Fix incorrect LoRA weight loading for fused gate_up_proj (sgl-project#6734)
1 parent b1b084d commit 07f23dc

File tree: 4 files changed, +29 -14 lines

python/sglang/srt/conversation.py

Lines changed: 2 additions & 2 deletions
@@ -680,8 +680,8 @@ def generate_chat_conv(
 register_conv_template(
     Conversation(
         name="phi-4-mm",
-        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
-        system_template="<|system|>{system_message}<|end|>",
+        system_message="",
+        system_template="{system_message}",
         roles=("<|user|>", "<|assistant|>"),
         sep_style=SeparatorStyle.NO_COLON_SINGLE,
         sep="<|end|>",
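With an empty system message and a bare "{system_message}" template, the rendered phi-4-mm prompt no longer opens with a <|system|> block. A rough, standalone sketch of how a single turn might render under this template (illustrative only; the message text and rendering loop are simplified stand-ins, not the Conversation class itself):

    # Illustrative sketch of the phi-4-mm prompt after this change.
    system_message = ""
    system_template = "{system_message}"
    sep = "<|end|>"

    prompt = system_template.format(system_message=system_message)
    for role, msg in [("<|user|>", "Describe the image."), ("<|assistant|>", None)]:
        # Completed turns get their text plus the separator; the pending
        # assistant turn contributes only the role tag.
        prompt += role + (msg + sep if msg else "")

    print(prompt)  # <|user|>Describe the image.<|end|><|assistant|>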

python/sglang/srt/lora/lora.py

Lines changed: 9 additions & 1 deletion
@@ -209,4 +209,12 @@ def normalize_gate_up_proj(
                 gate_up_name = weight_name
                 if "lora_A" in weight_name:
                     weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
-                # else: "lora_B" is already stacked, no operations is needed.
+                else:
+                    output_dim = weights[gate_up_name].shape[0] // 2
+                    weights[gate_up_name] = torch.stack(
+                        [
+                            weights[gate_up_name][:output_dim, :],
+                            weights[gate_up_name][output_dim:, :],
+                        ],
+                        dim=0,
+                    )
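The lora_B branch now splits the fused gate_up weight in half and stacks the two halves along a new leading dimension, while lora_A (shared by both projections) is simply repeated. A minimal standalone sketch of the two shape transforms, with purely illustrative sizes:

    import torch

    # Illustrative sizes, not taken from any real checkpoint.
    rank, output_dim, input_dim = 16, 4096, 4096

    # Fused lora_B arrives as one (2 * output_dim, rank) matrix; the fix
    # restacks it into (2, output_dim, rank), separating the gate and up
    # halves along a leading dimension.
    lora_b = torch.randn(2 * output_dim, rank)
    half = lora_b.shape[0] // 2
    stacked = torch.stack([lora_b[:half, :], lora_b[half:, :]], dim=0)
    assert stacked.shape == (2, output_dim, rank)

    # lora_A is shared by both projections, so it is repeated along dim 0:
    # (rank, input_dim) -> (2 * rank, input_dim).
    lora_a = torch.randn(rank, input_dim)
    assert lora_a.repeat(2, 1).shape == (2 * rank, input_dim)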

python/sglang/srt/models/idefics2.py

Lines changed: 16 additions & 9 deletions
@@ -296,23 +296,30 @@ def get_input_embeddings(self) -> nn.Embedding:
     def compute_cu_seqlens(
         self,
         tgt_sizes: Optional[torch.Tensor] = None,
-        atch_attention_mask: Optional[torch.BoolTensor] = None,
+        input_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # shape: (batch_size,)
         if tgt_sizes is not None:
-            patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+            seqlen = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+        elif input_embeds is not None:
+            seqlen = torch.full(
+                size=(input_embeds.shape[0],),
+                fill_value=input_embeds.shape[1],
+                dtype=torch.int32,
+                device=input_embeds.device,
+            )
         else:
-            patch_len = atch_attention_mask[:, :, 0].sum(dim=1) * atch_attention_mask[
-                :, 0, :
-            ].sum(dim=1)
+            raise ValueError(
+                "Either `tgt_sizes` or `input_embeds` must be provided to compute cu_seqlens."
+            )

         cu_seqlens = torch.cat(
             [
-                torch.tensor([0], device=patch_len.device, dtype=torch.int32),
-                torch.cumsum(patch_len, dim=0, dtype=torch.int32),
+                torch.tensor([0], device=seqlen.device, dtype=torch.int32),
+                torch.cumsum(seqlen, dim=0, dtype=torch.int32),
             ],
             dim=0,
-        ).to(patch_len.device)
+        ).to(seqlen.device)
         return cu_seqlens

     def forward(

@@ -326,7 +333,7 @@ def forward(
             patch_attention_mask=patch_attention_mask,
             tgt_sizes=tgt_sizes,
         )
-        cu_seqlens = self.compute_cu_seqlens(tgt_sizes, patch_attention_mask)
+        cu_seqlens = self.compute_cu_seqlens(tgt_sizes, hidden_states)
         encoder_outputs = self.encoder(
             hidden_states,
             cu_seqlens=cu_seqlens,
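The new input_embeds path derives one sequence length per batch element straight from the embedding tensor, so the cumulative sequence lengths collapse to evenly spaced offsets. A small standalone check of that behaviour (shapes are illustrative):

    import torch

    # Hypothetical embeddings of shape (batch_size, seq_len, hidden).
    input_embeds = torch.zeros(3, 5, 8)

    # One length per batch element, all equal to seq_len.
    seqlen = torch.full(
        size=(input_embeds.shape[0],),
        fill_value=input_embeds.shape[1],
        dtype=torch.int32,
    )
    # Prefix with 0 and take the running sum to get cumulative offsets.
    cu_seqlens = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            torch.cumsum(seqlen, dim=0, dtype=torch.int32),
        ],
        dim=0,
    )
    print(cu_seqlens)  # -> [0, 5, 10, 15]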

python/sglang/srt/models/phi4mm.py

Lines changed: 2 additions & 2 deletions
@@ -451,8 +451,8 @@ def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

-    def should_apply_lora(self, module_name: str) -> Optional[str]:
-        return self.lora_pattern.match(module_name)
+    def should_apply_lora(self, module_name: str) -> bool:
+        return bool(self.lora_pattern.match(module_name))

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
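should_apply_lora previously returned the raw re.Match (or None) even though callers treat the result as a boolean; wrapping it in bool() makes the return value match the new annotation. A quick illustration with a hypothetical pattern (the model's actual lora_pattern is defined elsewhere in phi4mm.py):

    import re

    # Hypothetical stand-in for the model's lora_pattern.
    lora_pattern = re.compile(r".*(qkv_proj|o_proj|gate_up_proj|down_proj)$")

    # re.Pattern.match returns a Match object or None; bool() converts it
    # to the True/False value the annotation promises.
    print(bool(lora_pattern.match("model.layers.0.mlp.gate_up_proj")))  # True
    print(bool(lora_pattern.match("model.embed_tokens")))               # False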
