|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +""" |
| 4 | +Support attention backend for Cutlass MLA. |
| 5 | +
|
| 6 | +""" |
| 7 | + |
| 8 | +from dataclasses import dataclass |
| 9 | +from typing import TYPE_CHECKING, Optional, Union |
| 10 | + |
| 11 | +import torch |
| 12 | +import triton |
| 13 | + |
| 14 | +from sglang.global_config import global_config |
| 15 | +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend |
| 16 | +from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend |
| 17 | +from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton |
| 18 | +from sglang.srt.layers.dp_attention import get_attention_tp_size |
| 19 | +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode |
| 20 | +from sglang.srt.utils import is_cuda |
| 21 | + |
| 22 | +if TYPE_CHECKING: |
| 23 | + from sglang.srt.layers.radix_attention import RadixAttention |
| 24 | + from sglang.srt.model_executor.model_runner import ModelRunner |
| 25 | + from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput |
| 26 | + from sglang.srt.speculative.spec_info import SpecInfo |
| 27 | + |
| 28 | +_is_cuda = is_cuda() |
| 29 | +if _is_cuda: |
| 30 | + from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size |
| 31 | + |
| 32 | + |
| 33 | +# Cutlass MLA only supports pagesize=128 |
| 34 | +PAGE_SIZE = 128 |
| 35 | + |
| 36 | + |
| 37 | +@dataclass |
| 38 | +class CutlassMLADecodeMetadata: |
| 39 | + workspace: Optional[torch.Tensor] = None |
| 40 | + block_kv_indices: Optional[torch.Tensor] = None |
| 41 | + |
| 42 | + def __init__( |
| 43 | + self, |
| 44 | + workspace: Optional[torch.Tensor] = None, |
| 45 | + block_kv_indices: Optional[torch.Tensor] = None, |
| 46 | + ): |
| 47 | + self.workspace = workspace |
| 48 | + self.block_kv_indices = block_kv_indices |
| 49 | + |
| 50 | + |
| 51 | +class CutlassMLABackend(FlashInferMLAAttnBackend): |
| 52 | + """Cutlass attention kernels.""" |
| 53 | + |
| 54 | + def __init__( |
| 55 | + self, |
| 56 | + model_runner: ModelRunner, |
| 57 | + skip_prefill: bool = False, |
| 58 | + kv_indptr_buf: Optional[torch.Tensor] = None, |
| 59 | + kv_last_page_len_buf: Optional[torch.Tensor] = None, |
| 60 | + ): |
| 61 | + super().__init__( |
| 62 | + model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf |
| 63 | + ) |
| 64 | + |
| 65 | + self.num_q_heads = ( |
| 66 | + model_runner.model_config.num_attention_heads // get_attention_tp_size() |
| 67 | + ) |
| 68 | + self.num_kv_heads = model_runner.model_config.get_num_kv_heads( |
| 69 | + get_attention_tp_size() |
| 70 | + ) |
| 71 | + self.req_to_token = model_runner.req_to_token_pool.req_to_token |
| 72 | + self.num_local_heads = ( |
| 73 | + model_runner.model_config.num_attention_heads // get_attention_tp_size() |
| 74 | + ) |
| 75 | + self.forward_metadata: Union[CutlassMLADecodeMetadata] = None |
| 76 | + self.kv_lora_rank = model_runner.model_config.kv_lora_rank |
| 77 | + self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim |
| 78 | + self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim |
| 79 | + self.v_head_dim = model_runner.model_config.v_head_dim |
| 80 | + self.scaling = model_runner.model_config.scaling |
| 81 | + self.data_type = model_runner.kv_cache_dtype |
| 82 | + self.q_data_type = model_runner.dtype |
| 83 | + self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim |
| 84 | + |
| 85 | + def init_forward_metadata(self, forward_batch: ForwardBatch): |
| 86 | + |
| 87 | + bs = forward_batch.batch_size |
| 88 | + spec_info = forward_batch.spec_info |
| 89 | + if forward_batch.forward_mode.is_decode_or_idle(): |
| 90 | + if spec_info is None: |
| 91 | + max_seqlen_pad = triton.cdiv( |
| 92 | + forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE |
| 93 | + ) |
| 94 | + block_kv_indices = torch.full( |
| 95 | + (bs, max_seqlen_pad), |
| 96 | + -1, |
| 97 | + dtype=torch.int32, |
| 98 | + device=forward_batch.seq_lens.device, |
| 99 | + ) |
| 100 | + create_flashmla_kv_indices_triton[(bs,)]( |
| 101 | + self.req_to_token, |
| 102 | + forward_batch.req_pool_indices, |
| 103 | + forward_batch.seq_lens, |
| 104 | + None, |
| 105 | + block_kv_indices, |
| 106 | + self.req_to_token.stride(0), |
| 107 | + max_seqlen_pad, |
| 108 | + PAGE_SIZE, |
| 109 | + ) |
| 110 | + workspace_size = cutlass_mla_get_workspace_size( |
| 111 | + max_seqlen_pad * PAGE_SIZE, bs |
| 112 | + ) |
| 113 | + workspace = torch.empty( |
| 114 | + workspace_size, device="cuda", dtype=torch.uint8 |
| 115 | + ) |
| 116 | + self.forward_metadata = CutlassMLADecodeMetadata( |
| 117 | + workspace, |
| 118 | + block_kv_indices, |
| 119 | + ) |
| 120 | + else: |
| 121 | + super().init_forward_metadata(forward_batch) |
| 122 | + else: |
| 123 | + super().init_forward_metadata(forward_batch) |
| 124 | + |
| 125 | + def init_cuda_graph_state( |
| 126 | + self, |
| 127 | + max_bs: int, |
| 128 | + block_kv_indices: Optional[torch.Tensor] = None, |
| 129 | + ): |
| 130 | + if block_kv_indices is None: |
| 131 | + cuda_graph_kv_indices = torch.full( |
| 132 | + (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE), |
| 133 | + 1, |
| 134 | + dtype=torch.int32, |
| 135 | + device="cuda", |
| 136 | + ) |
| 137 | + else: |
| 138 | + cuda_graph_kv_indices = block_kv_indices |
| 139 | + |
| 140 | + workspace_size = cutlass_mla_get_workspace_size( |
| 141 | + cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs |
| 142 | + ) |
| 143 | + self.cuda_graph_mla_workspace = torch.empty( |
| 144 | + workspace_size, device="cuda", dtype=torch.uint8 |
| 145 | + ) |
| 146 | + self.cuda_graph_kv_indices = cuda_graph_kv_indices |
| 147 | + |
| 148 | + def init_forward_metadata_capture_cuda_graph( |
| 149 | + self, |
| 150 | + bs: int, |
| 151 | + num_tokens: int, |
| 152 | + req_pool_indices: torch.Tensor, |
| 153 | + seq_lens: torch.Tensor, |
| 154 | + encoder_lens: Optional[torch.Tensor], |
| 155 | + forward_mode: ForwardMode, |
| 156 | + spec_info: Optional[SpecInfo], |
| 157 | + ): |
| 158 | + if forward_mode.is_decode_or_idle(): |
| 159 | + if spec_info is None: |
| 160 | + max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE) |
| 161 | + |
| 162 | + create_flashmla_kv_indices_triton[(bs,)]( |
| 163 | + self.req_to_token, |
| 164 | + req_pool_indices, |
| 165 | + seq_lens, |
| 166 | + None, |
| 167 | + self.cuda_graph_kv_indices, |
| 168 | + self.req_to_token.stride(0), |
| 169 | + self.cuda_graph_kv_indices.stride(0), |
| 170 | + PAGE_SIZE, |
| 171 | + ) |
| 172 | + workspace_size = cutlass_mla_get_workspace_size( |
| 173 | + max_seqlen_pad * PAGE_SIZE, bs |
| 174 | + ) |
| 175 | + self.cuda_graph_mla_workspace = torch.empty( |
| 176 | + workspace_size, device="cuda", dtype=torch.uint8 |
| 177 | + ) |
| 178 | + self.forward_metadata = CutlassMLADecodeMetadata( |
| 179 | + self.cuda_graph_mla_workspace, |
| 180 | + self.cuda_graph_kv_indices[:bs, :max_seqlen_pad], |
| 181 | + ) |
| 182 | + else: |
| 183 | + super().init_forward_metadata_capture_cuda_graph( |
| 184 | + bs, |
| 185 | + num_tokens, |
| 186 | + req_pool_indices, |
| 187 | + seq_lens, |
| 188 | + encoder_lens, |
| 189 | + forward_mode, |
| 190 | + spec_info, |
| 191 | + ) |
| 192 | + |
| 193 | + def init_forward_metadata_replay_cuda_graph( |
| 194 | + self, |
| 195 | + bs: int, |
| 196 | + req_pool_indices: torch.Tensor, |
| 197 | + seq_lens: torch.Tensor, |
| 198 | + seq_lens_sum: int, |
| 199 | + encoder_lens: Optional[torch.Tensor], |
| 200 | + forward_mode: ForwardMode, |
| 201 | + spec_info: Optional[SpecInfo], |
| 202 | + seq_lens_cpu: Optional[torch.Tensor], |
| 203 | + ): |
| 204 | + |
| 205 | + if forward_mode.is_decode_or_idle(): |
| 206 | + assert seq_lens_cpu is not None |
| 207 | + seq_lens = seq_lens[:bs] |
| 208 | + seq_lens_cpu = seq_lens_cpu[:bs] |
| 209 | + max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE) |
| 210 | + create_flashmla_kv_indices_triton[(bs,)]( |
| 211 | + self.req_to_token, |
| 212 | + req_pool_indices[:bs], |
| 213 | + seq_lens, |
| 214 | + None, |
| 215 | + self.cuda_graph_kv_indices, |
| 216 | + self.req_to_token.stride(0), |
| 217 | + self.cuda_graph_kv_indices.stride(0), |
| 218 | + PAGE_SIZE, |
| 219 | + ) |
| 220 | + workspace_size = cutlass_mla_get_workspace_size( |
| 221 | + max_seqlen_pad * PAGE_SIZE, bs |
| 222 | + ) |
| 223 | + self.cuda_graph_mla_workspace = torch.empty( |
| 224 | + workspace_size, device="cuda", dtype=torch.uint8 |
| 225 | + ) |
| 226 | + self.forward_metadata.workspace = self.cuda_graph_mla_workspace |
| 227 | + self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[ |
| 228 | + :bs, :max_seqlen_pad |
| 229 | + ] |
| 230 | + else: |
| 231 | + super().init_forward_metadata_replay_cuda_graph( |
| 232 | + bs, |
| 233 | + req_pool_indices, |
| 234 | + seq_lens, |
| 235 | + seq_lens_sum, |
| 236 | + encoder_lens, |
| 237 | + forward_mode, |
| 238 | + spec_info, |
| 239 | + seq_lens_cpu, |
| 240 | + ) |
| 241 | + |
| 242 | + def get_cuda_graph_seq_len_fill_value(self): |
| 243 | + return 1 |
| 244 | + |
| 245 | + def forward_decode( |
| 246 | + self, |
| 247 | + q: torch.Tensor, |
| 248 | + k: torch.Tensor, |
| 249 | + v: torch.Tensor, |
| 250 | + layer: RadixAttention, |
| 251 | + forward_batch: ForwardBatch, |
| 252 | + save_kv_cache: bool = True, |
| 253 | + ): |
| 254 | + cache_loc = forward_batch.out_cache_loc |
| 255 | + |
| 256 | + if k is not None: |
| 257 | + assert v is not None |
| 258 | + if save_kv_cache: |
| 259 | + forward_batch.token_to_kv_pool.set_kv_buffer( |
| 260 | + layer, |
| 261 | + cache_loc, |
| 262 | + k, |
| 263 | + v, |
| 264 | + ) |
| 265 | + bs = forward_batch.batch_size |
| 266 | + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) |
| 267 | + |
| 268 | + reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim) |
| 269 | + |
| 270 | + o = cutlass_mla_decode( |
| 271 | + q_nope_and_q_pe=reshape_q, |
| 272 | + kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim), |
| 273 | + seq_lens=forward_batch.seq_lens.to(torch.int32), |
| 274 | + page_table=self.forward_metadata.block_kv_indices, |
| 275 | + workspace=self.forward_metadata.workspace, |
| 276 | + ) |
| 277 | + |
| 278 | + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) |
0 commit comments