
Commit 11b23ae

ispobocks and saienduri authored
Remove extra copy in deepseek forward absorb (#5578)
Co-authored-by: saienduri <[email protected]>
1 parent b9c87e7 commit 11b23ae

File tree

3 files changed: +18 −21 lines

.github/workflows/pr-test-amd.yml

Lines changed: 7 additions & 7 deletions
@@ -38,12 +38,12 @@ jobs:
           else
             DEVICE_FLAG="--device /dev/dri"
           fi
-          docker pull lmsysorg/sglang:v0.4.5-rocm630
+          docker pull lmsysorg/sglang:v0.4.5.post2-rocm630
           docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
             -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
             --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
             -w /sglang-checkout --name ci_sglang \
-            lmsysorg/sglang:v0.4.5-rocm630
+            lmsysorg/sglang:v0.4.5.post2-rocm630
 
       - name: Install dependencies
         run: |
@@ -82,12 +82,12 @@ jobs:
           else
             DEVICE_FLAG="--device /dev/dri"
           fi
-          docker pull lmsysorg/sglang:v0.4.5-rocm630
+          docker pull lmsysorg/sglang:v0.4.5.post2-rocm630
           docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
             -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
             --cap-add=SYS_PTRACE -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} --security-opt seccomp=unconfined \
             -w /sglang-checkout --name ci_sglang \
-            lmsysorg/sglang:v0.4.5-rocm630
+            lmsysorg/sglang:v0.4.5.post2-rocm630
 
       - name: Install dependencies
         run: |
@@ -120,12 +120,12 @@ jobs:
           else
             DEVICE_FLAG="--device /dev/dri"
           fi
-          docker pull lmsysorg/sglang:v0.4.5-rocm630
+          docker pull lmsysorg/sglang:v0.4.5.post2-rocm630
           docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
             -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
             --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
             -w /sglang-checkout --name ci_sglang \
-            lmsysorg/sglang:v0.4.5-rocm630
+            lmsysorg/sglang:v0.4.5.post2-rocm630
 
       - name: Install dependencies
         run: |
@@ -149,7 +149,7 @@ jobs:
   finish:
     if: always()
     needs: [
-      accuracy-test-1-gpu-amd, mla-test-1-gpu-amd
+      accuracy-test-1-gpu-amd, mla-test-1-gpu-amd, bench-test-2-gpu-amd
     ]
     runs-on: ubuntu-latest
     steps:

python/sglang/srt/layers/rotary_embedding.py

Lines changed: 2 additions & 1 deletion
@@ -665,6 +665,7 @@ def forward_native(
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
+        dtype = query.dtype
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
@@ -695,7 +696,7 @@ def forward_native(
         else:
             query = query_rot
             key = key_rot
-        return query, key
+        return query.to(dtype), key.to(dtype)
 
 
 class Llama3RotaryEmbedding(RotaryEmbedding):
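For context on the rotary_embedding.py change: forward_native keeps its cos/sin math in float32, which silently upcasts bfloat16/float16 queries and keys, so the patch records the incoming dtype and casts the outputs back before returning. A minimal standalone sketch of that pattern (the apply_rotary_native helper, shapes, and the split-half rotation here are illustrative assumptions, not the sglang implementation):

import torch

def apply_rotary_native(query: torch.Tensor, key: torch.Tensor,
                        cos: torch.Tensor, sin: torch.Tensor):
    # Remember the caller's dtype; cos/sin stay in float32 for accuracy,
    # so the intermediate math upcasts half-precision inputs.
    dtype = query.dtype
    q1, q2 = torch.chunk(query.float(), 2, dim=-1)
    k1, k2 = torch.chunk(key.float(), 2, dim=-1)
    q_out = torch.cat([q1 * cos - q2 * sin, q2 * cos + q1 * sin], dim=-1)
    k_out = torch.cat([k1 * cos - k2 * sin, k2 * cos + k1 * sin], dim=-1)
    # Cast back so downstream kernels see the original dtype (the point of the fix).
    return q_out.to(dtype), k_out.to(dtype)

q = torch.randn(2, 4, 64, dtype=torch.bfloat16)
k = torch.randn(2, 4, 64, dtype=torch.bfloat16)
cos = torch.randn(2, 1, 32)   # toy float32 rotary caches
sin = torch.randn(2, 1, 32)
q_out, k_out = apply_rotary_native(q, k, cos, sin)
assert q_out.dtype == torch.bfloat16 and k_out.dtype == torch.bfloat16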

python/sglang/srt/models/deepseek_v2.py

Lines changed: 9 additions & 13 deletions
@@ -682,10 +682,6 @@ def forward_absorb(
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        q_len = hidden_states.shape[0]
-        q_input = hidden_states.new_empty(
-            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
-        )
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
             q = self.q_a_layernorm(q)
@@ -729,20 +725,20 @@ def forward_absorb(
             )
         else:
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
-        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
+
+        q_nope_out = q_nope_out.transpose(0, 1)
 
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        v_input = latent_cache[..., : self.kv_lora_rank]
-        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
-        k_input = latent_cache.unsqueeze(1)
-        k_input[..., : self.kv_lora_rank] = v_input
-        k_pe = k_input[..., self.kv_lora_rank :]
+        k_nope = latent_cache[..., : self.kv_lora_rank]
+        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q_input[..., self.kv_lora_rank :] = q_pe
-        k_input[..., self.kv_lora_rank :] = k_pe
 
-        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        q = torch.cat([q_nope_out, q_pe], dim=-1)
+        k = torch.cat([k_nope, k_pe], dim=-1)
+
+        attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
         if self.use_deep_gemm_bmm:
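The forward_absorb refactor above drops the preallocated q_input/k_input buffers and their slice writes, building the MQA inputs with plain concatenations and reusing the normalized latent (k_nope) directly as the value tensor instead of materializing a separate v_input. A small standalone sketch of why the two query-building patterns agree (toy sizes; variable names mirror the diff but the snippet is illustrative, not the sglang code):

import torch

tokens, heads, kv_lora_rank, rope_dim = 8, 16, 512, 64
q_nope_out = torch.randn(tokens, heads, kv_lora_rank)
q_pe = torch.randn(tokens, heads, rope_dim)

# Old pattern: allocate an uninitialized buffer, then write each piece
# into a slice of it (an extra copy per piece).
q_input = q_nope_out.new_empty(tokens, heads, kv_lora_rank + rope_dim)
q_input[..., :kv_lora_rank] = q_nope_out
q_input[..., kv_lora_rank:] = q_pe

# New pattern: build the combined tensor with a single concatenation.
q = torch.cat([q_nope_out, q_pe], dim=-1)

# Identical contents either way; the diff additionally avoids copying
# latent_cache into k_input and then overwriting part of it on the k/v path.
assert torch.equal(q, q_input)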
