diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 37760407bf..27543a734c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1472,8 +1472,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): w_kc, w_vc = w.unflatten( 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) - self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) - self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + self_attn_w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn_w_vc = w_vc.contiguous().transpose(1, 2) + if self_attn.w_kc is not None: + self_attn.w_kc.copy_(self_attn_w_kc) + self_attn.w_vc.copy_(self_attn_w_vc) + else: + self_attn.w_kc = self_attn_w_kc + self_attn.w_vc = self_attn_w_vc if ( hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None