Skip to content

[ROCm][Misc] Follow-ups for Skinny Gemms on ROCm. #17011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,9 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[
0] == 1 and qinput.shape[1] % 16 == 0:
from vllm.platforms.rocm import on_mi250_mi300
if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count())
else:
Expand Down Expand Up @@ -371,7 +372,7 @@ def apply(

return w8a8_scaled_mm_func(qinput=qinput,
weight=weight,
out_dtype=input.dtype,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
Expand Down
7 changes: 4 additions & 3 deletions vllm/model_executor/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
def rocm_unquantized_gemm(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None):
from vllm.platforms.rocm import on_mi250_mi300
k = weight.shape[1]
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and \
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \
x.dtype in [torch.float16, torch.bfloat16] \
and k % 8 == 0 and bias is None)

Expand All @@ -83,11 +84,11 @@ def rocm_unquantized_gemm(x: torch.Tensor,
m = weight.shape[0]
cu_count = current_platform.get_cu_count()

if m > 8 and n < 4:
if m > 8 and 0 < n < 4:
out = ops.wvSplitK(weight, x_view, cu_count)
return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192:
out = ops.LLMM1(weight, x_view, out, 4)
out = ops.LLMM1(weight, x_view, 4)
return out.view(*x.shape[:-1], weight.shape[0])
return torch.nn.functional.linear(x, weight, bias)

Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/vocab_parallel_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
Expand Down Expand Up @@ -40,7 +41,7 @@ def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return F.linear(x, layer.weight, bias)
return dispatch_unquantized_gemm()(x, layer.weight, bias)

def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor:
Expand Down
16 changes: 8 additions & 8 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,22 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return device_id


def on_mi250_mi300() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])


@cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int,
sliding_window: int) -> bool:

GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_NAVI = "gfx1" in GPU_ARCH
ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])

# rocm custom page attention not support on navi (gfx1*)
# rocm custom page attention not support on gfx1*
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
return (ON_MI250_MI300 and not ON_NAVI
and (not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
Expand Down