|
16 | 16 | import json
|
17 | 17 | import logging
|
18 | 18 | import os
|
| 19 | +from contextlib import contextmanager |
19 | 20 | from typing import Any, Dict, List, Optional, Tuple
|
20 | 21 |
|
21 | 22 | import torch
|
@@ -59,7 +60,10 @@ def deep_gemm_fp8_fp8_bf16_nt(
|
59 | 60 | Bs: torch.Tensor,
|
60 | 61 | C: torch.Tensor,
|
61 | 62 | ) -> None:
|
62 |
| - deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) |
| 63 | + M, K = A.shape |
| 64 | + N, _ = B.shape |
| 65 | + with _log_jit_build(M, N, K): |
| 66 | + deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) |
63 | 67 |
|
64 | 68 | def deep_gemm_fp8_fp8_bf16_nt_fake(
|
65 | 69 | A: torch.Tensor,
|
@@ -708,6 +712,25 @@ def get_w8a8_block_fp8_configs(
|
708 | 712 | return None
|
709 | 713 |
|
710 | 714 |
|
@contextmanager
def _log_jit_build(M: int, N: int, K: int):
    """Warn when DeepGEMM is about to JIT-compile a kernel for this shape.

    Temporarily monkey-patches ``RuntimeCache.__getitem__`` so that a cache
    miss (a ``None`` return, which precedes JIT code generation) logs a
    warning with the GEMM dimensions.

    Args:
        M: Number of rows of the left operand.
        N: Number of rows of the (transposed) right operand.
        K: Shared inner dimension.
    """
    from deep_gemm.jit.runtime import RuntimeCache

    origin_func = RuntimeCache.__getitem__

    def __patched_func(self, *args, **kwargs):
        ret = origin_func(self, *args, **kwargs)
        if ret is None:
            # Cache miss -> DeepGEMM will JIT-compile; warn so users
            # understand the one-time stall.
            logger.warning(
                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
            )
        return ret

    RuntimeCache.__getitem__ = __patched_func
    try:
        yield
    finally:
        # Restore even if the wrapped GEMM raises, so the patch never
        # leaks past this context manager.
        RuntimeCache.__getitem__ = origin_func
| 732 | + |
| 733 | + |
711 | 734 | def w8a8_block_fp8_matmul(
|
712 | 735 | A: torch.Tensor,
|
713 | 736 | B: torch.Tensor,
|
@@ -782,7 +805,8 @@ def grid(META):
|
782 | 805 | if supports_custom_op():
|
783 | 806 | torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
784 | 807 | else:
|
785 |
| - deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) |
| 808 | + with _log_jit_build(M, N, K): |
| 809 | + deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) |
786 | 810 | else:
|
787 | 811 | kernel = (
|
788 | 812 | _w8a8_block_fp8_matmul_unrolledx4
|
|
0 commit comments