update deepgemm version #3606

Merged 3 commits on Jun 4, 2025
8 changes: 7 additions & 1 deletion lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
@@ -69,6 +69,7 @@ def build(in_features: int, out_features: int, block_size: int = 128, bias: bool
logger.debug('build with DeepGemmLinearBlockedF8Impl')
return DeepGemmLinearBlockedF8Impl(in_features, out_features, block_size, dtype)
except: # noqa
logger.warning('Failed to import deep_gemm, LinearBlockedF8 fallback to triton implementation.')
return TritonLinearBlockedF8Impl(in_features, out_features, block_size, dtype)


@@ -89,6 +90,8 @@ def __init__(self, in_features: int, out_features: int, block_size: int, out_dty

def warmup(self, warmup_meta: WarmupMeta):
"""warmup."""
import random

from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout
device = 'cuda'
max_num_tokens = warmup_meta.max_num_tokens
@@ -100,7 +103,10 @@ def warmup(self, warmup_meta: WarmupMeta):
scale = torch.empty(((n + block_size - 1) // block_size, (k + block_size - 1) // block_size),
dtype=torch.float32,
device=device)
for m in range(alignment, range_end, alignment):
# shuffle ranges so ranks might compile different kernels concurrently.
ranges = list(range(alignment, range_end, alignment))
random.shuffle(ranges)
for m in ranges:
inputs = torch.empty(m, k, dtype=self.out_dtype, device=device)
input_quant, input_scale = quant_fp8_tma(inputs, self.block_size, dtype=weight.dtype)
deep_gemm_fp8(input_quant, input_scale, weight, scale, out_dtype=inputs.dtype)
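
The hunk above randomizes the order of the warmed-up m sizes. A minimal, self-contained sketch of that idea follows; the function name and the build_and_run_gemm callback are hypothetical stand-ins, not part of the patch:

    import random

    def warmup_m_sizes(alignment: int, range_end: int, build_and_run_gemm) -> None:
        # A sequential warmup would compile kernels for m = alignment, 2*alignment, ...
        sizes = list(range(alignment, range_end, alignment))
        # Shuffle so that, during multi-rank startup, ranks tend to JIT-compile
        # different m sizes at the same time instead of all queuing on the same build.
        random.shuffle(sizes)
        for m in sizes:
            build_and_run_gemm(m)
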
6 changes: 4 additions & 2 deletions lmdeploy/pytorch/backends/cuda/warmup_manager.py
@@ -37,9 +37,11 @@ def warmup(self, warmup_meta: WarmupMeta):
"""Warmup meta."""
if len(self._warmup_calls) == 0:
return

import random
logger.info('Warming up ops.')
for key, func in self._warmup_calls.items():
funcs = list(self._warmup_calls.values())
random.shuffle(funcs)
for func in funcs:
func(warmup_meta)


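
The same shuffling is applied one level up, to the registered warmup callbacks. A rough, self-contained sketch of such a registry is shown below; the class and method names are illustrative, not the actual WarmupManager API:

    import random
    from typing import Callable, Dict

    class WarmupRegistry:
        def __init__(self) -> None:
            self._calls: Dict[str, Callable[[], None]] = {}

        def register(self, key: str, func: Callable[[], None]) -> None:
            self._calls[key] = func

        def warmup(self) -> None:
            if not self._calls:
                return
            funcs = list(self._calls.values())
            # Randomize the callback order so each rank warms up ops in a
            # different sequence.
            random.shuffle(funcs)
            for func in funcs:
                func()
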
11 changes: 8 additions & 3 deletions lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py
@@ -285,17 +285,22 @@ def grid(META):
@contextmanager
def _log_jit_build(M: int, N: int, K: int):
from deep_gemm.jit.runtime import RuntimeCache
origin_func = RuntimeCache.__getitem__

if hasattr(RuntimeCache, 'get'):
func_name = 'get'
else:
func_name = '__getitem__'
origin_func = getattr(RuntimeCache, func_name)

def __patched_func(self, *args, **kwargs):
ret = origin_func(self, *args, **kwargs)
if ret is None:
logger.warning(f'DeepGemm build <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please waiting.')
return ret

RuntimeCache.__getitem__ = __patched_func
setattr(RuntimeCache, func_name, __patched_func)
yield
RuntimeCache.__getitem__ = origin_func
setattr(RuntimeCache, func_name, origin_func)


def deep_gemm_fp8(A: Tensor,
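
The compatibility shim above patches whichever cache-lookup method the installed deep_gemm exposes (RuntimeCache.get on newer releases, RuntimeCache.__getitem__ on older ones). A generic sketch of that pattern is below; the helper name is hypothetical, and it adds a try/finally so the original method is restored even if the wrapped code raises:

    from contextlib import contextmanager

    @contextmanager
    def patch_first_existing(cls, candidate_names, make_wrapper):
        # Pick the first method name the class actually defines.
        name = next(n for n in candidate_names if hasattr(cls, n))
        original = getattr(cls, name)
        setattr(cls, name, make_wrapper(original))
        try:
            yield
        finally:
            # Always restore the original, even if the body raised.
            setattr(cls, name, original)

    # Usage matching the diff (wrapper_factory builds the logging wrapper):
    # with patch_first_existing(RuntimeCache, ('get', '__getitem__'), wrapper_factory):
    #     ...
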