
Commit ac9dd2c

Qiaolin-Yu and Beichen-Ma authored and committed
Feat: support cuda graph for LoRA (sgl-project#4115)
Co-authored-by: Beichen Ma <[email protected]>
1 parent 7be88d1 commit ac9dd2c

13 files changed (+367, −56 lines)


benchmark/lora/launch_server.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ def launch_server(args):
     for i in range(NUM_LORAS):
         lora_name = f"lora{i}"
         cmd += f"{lora_name}={lora_path} "
-    cmd += f"--disable-radix --disable-cuda-graph "
+    cmd += f"--disable-radix "
     cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
     cmd += f"--max-running-requests {args.max_running_requests} "
     cmd += f"--lora-backend {args.lora_backend} "

docs/backend/lora.ipynb

Lines changed: 3 additions & 3 deletions
@@ -77,7 +77,7 @@
 "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
 " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
 " --max-loras-per-batch 1 --lora-backend triton \\\n",
-" --disable-cuda-graph --disable-radix-cache\n",
+" --disable-radix-cache\n",
 "\"\"\"\n",
 ")\n",
 "\n",
@@ -136,7 +136,7 @@
 " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
 " lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n",
 " --max-loras-per-batch 2 --lora-backend triton \\\n",
-" --disable-cuda-graph --disable-radix-cache\n",
+" --disable-radix-cache\n",
 "\"\"\"\n",
 ")\n",
 "\n",
@@ -182,7 +182,7 @@
 "source": [
 "## Future Works\n",
 "\n",
-"The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Currently Cuda graph and radix attention are not incompatible with LoRA and must be manually disabled. Other features, including Unified Paging, Cutlass backend, and dynamic loading/unloadingm, are still under development."
+"The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Currently radix attention is incompatible with LoRA and must be manually disabled. Other features, including Unified Paging, Cutlass backend, and dynamic loading/unloadingm, are still under development."
 ]
 }
 ],

docs/backend/server_arguments.md

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 
 | Arguments | Description | Defaults |
 |----------|-------------|---------|
-| `lora_paths` | List of adapters to apply to your model. Each batch element uses the proper LoRA adapter. `cuda_graph` and `radix_attention` are not supported with this, so they must be disabled manually. See related [issues](https://github.com/sgl-project/sglang/issues/2929). | None |
+| `lora_paths` | List of adapters to apply to your model. Each batch element uses the proper LoRA adapter. `radix_attention` is not supported with this, so it must be disabled manually. See related [issues](https://github.com/sgl-project/sglang/issues/2929). | None |
 | `max_loras_per_batch` | Maximum number of LoRAs allowed in a running batch, including the base model. | `8` |
 | `lora_backend` | Backend used to run GEMM kernels for LoRA modules. Can be `triton` or `flashinfer`. | `triton` |
 

python/sglang/srt/lora/layers.py

Lines changed: 35 additions & 9 deletions
@@ -136,11 +136,19 @@ def set_lora_info(
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
         if self.lora_backend.fuse_stacked_lora_b:
-            # TODO: avoid using contiguous() in GPU.
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
-            self.B_buffer_gate_up = torch.cat(
-                (B_buffer[0], B_buffer[1]), dim=-2
-            ).contiguous()
+            if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None:
+                self.B_buffer_gate_up = torch.empty(
+                    (
+                        B_buffer[0].shape[0],
+                        2 * B_buffer[0].shape[1],
+                        B_buffer[0].shape[2],
+                    ),
+                    dtype=B_buffer[0].dtype,
+                    device=B_buffer[0].device,
+                )
+            self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
+            self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
         else:
             self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
 
@@ -171,7 +179,7 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
 
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def init__(
+    def __init__(
         self,
         base_layer: QKVParallelLinear,
         lora_backend: BaseLoRABackend,
@@ -194,12 +202,30 @@ def set_lora_info(
             output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
 
             # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            self.B_buffer_qkv = torch.cat(
-                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
-            ).contiguous()
+            if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None:
+                self.B_buffer_qkv = torch.empty(
+                    (
+                        B_buffer_q[0].shape[0],
+                        output_dim_q + 2 * output_dim_kv,
+                        B_buffer_q[0].shape[2],
+                    ),
+                    dtype=B_buffer_q[0].dtype,
+                    device=B_buffer_q[0].device,
+                )
+            self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
+            self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
+                B_buffer_kv[0]
+            )
+            self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
+                B_buffer_kv[1]
+            )
 
             # Offsets of q/k/v in output dimension
-            self.output_offset = torch.tensor(
+            if not hasattr(self, "output_offset") or self.output_offset is None:
+                self.output_offset = torch.empty(
+                    4, dtype=torch.int32, device=B_buffer_q.device
+                )
+            self.output_offset[:4] = torch.tensor(
                 [
                     0,
                     output_dim_q,
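
The motivation for these changes: CUDA graph replay re-executes kernels against the exact memory addresses recorded at capture time, so the stacked LoRA-B buffers must live at stable addresses. torch.cat(...).contiguous() allocates a fresh tensor on every call, whereas allocating the buffer once and refreshing it with copy_ keeps the pointer fixed. A minimal standalone sketch of the pattern (the class and names below are illustrative, not sglang's API):

import torch

class StackedLoRAB:
    """Stack two LoRA-B slices into one buffer that is allocated exactly once."""

    def __init__(self):
        self.B_stacked = None  # created lazily, then reused so its address never changes

    def set_weights(self, b0: torch.Tensor, b1: torch.Tensor) -> torch.Tensor:
        num_lora, out_dim, rank = b0.shape
        if self.B_stacked is None:
            # One-time allocation: a CUDA graph captured against this tensor
            # stays valid across later weight updates.
            self.B_stacked = torch.empty(
                (num_lora, 2 * out_dim, rank), dtype=b0.dtype, device=b0.device
            )
        # In-place copies instead of torch.cat, which would allocate new storage.
        self.B_stacked[:, :out_dim, :].copy_(b0)
        self.B_stacked[:, out_dim:, :].copy_(b1)
        return self.B_stacked

b0, b1 = torch.randn(2, 8, 4), torch.randn(2, 8, 4)
layer = StackedLoRAB()
buf = layer.set_weights(b0, b1)
assert torch.equal(buf[:, :8, :], b0) and torch.equal(buf[:, 8:, :], b1)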

python/sglang/srt/lora/lora_manager.py

Lines changed: 82 additions & 31 deletions
@@ -72,6 +72,23 @@ def __init__(
         self.init_loras()
         self.init_lora_memory_pool()
 
+    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
+        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
+        with torch.device("cuda"):
+            self.cuda_graph_batch_info = LoRABatchInfo(
+                bs=self.max_bs_in_cuda_graph,
+                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(
+                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
+                ),
+                max_len=0,
+                weight_indices=torch.zeros(
+                    self.max_bs_in_cuda_graph, dtype=torch.int32
+                ),
+                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
+                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
+            )
+
     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
@@ -140,39 +157,73 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
         if cur_uids == set([None]):
             return
 
-        # set up batch info shared by all lora moruldes
+        # set up batch info shared by all lora modules
        bs = forward_batch.batch_size
-        seg_lens = (
-            forward_batch.extend_seq_lens
-            if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs, device=self.device)
-        )
-        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
-        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-        max_len = int(torch.max(seg_lens))
-        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-        lora_ranks = torch.empty(
-            (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-        )
-        scalings = torch.empty(
-            (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-        )
-        for i, lora_path in enumerate(forward_batch.lora_paths):
-            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-            lora = self.loras[lora_path]
-            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-            scalings[weight_indices[i]] = lora.scaling
-
-        batch_info = LoRABatchInfo(
-            bs=bs,
-            seg_lens=seg_lens,
-            seg_indptr=seg_indptr,
-            max_len=max_len,
-            weight_indices=weight_indices,
-            lora_ranks=lora_ranks,
-            scalings=scalings,
-        )
+        if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
+            # Do in-place updates when CUDA graph is enabled. Note that
+            # if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
+            # will also use these preallocated buffers, no matter whether
+            # the batch can use CUDA graph or not.
+            self.cuda_graph_batch_info.bs = bs
+            if forward_batch.forward_mode.is_extend():
+                self.cuda_graph_batch_info.seg_lens[:bs].copy_(
+                    forward_batch.extend_seq_lens
+                )
+            else:
+                self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            torch.cumsum(
+                self.cuda_graph_batch_info.seg_lens[:bs],
+                dim=0,
+                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+            )
+            self.cuda_graph_batch_info.max_len = int(
+                torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
+            )
+
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                self.cuda_graph_batch_info.weight_indices[i] = (
+                    self.memory_pool.get_buffer_id(lora_path)
+                )
+                lora = self.loras[lora_path]
+                self.cuda_graph_batch_info.lora_ranks[
+                    self.cuda_graph_batch_info.weight_indices[i]
+                ] = lora.config.hf_config["r"]
+                self.cuda_graph_batch_info.scalings[
+                    self.cuda_graph_batch_info.weight_indices[i]
+                ] = lora.scaling
+            batch_info = self.cuda_graph_batch_info
+        else:
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+            max_len = int(torch.max(seg_lens))
+            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+            lora_ranks = torch.empty(
+                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+            )
+            scalings = torch.empty(
+                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+            )
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                lora = self.loras[lora_path]
+                lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                scalings[weight_indices[i]] = lora.scaling
+            batch_info = LoRABatchInfo(
+                bs=bs,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                max_len=max_len,
+                weight_indices=weight_indices,
+                lora_ranks=lora_ranks,
+                scalings=scalings,
+            )
         self.lora_backend.set_batch_info(batch_info)
 
         # call set_lora_info for each lora modules
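
The preallocated cuda_graph_batch_info serves the same purpose at the batch-metadata level: tensors read inside the captured region must keep their addresses, so per-batch values are written into fixed buffers with copy_/fill_ and cumsum(..., out=...) instead of being rebuilt. A minimal, self-contained illustration of why in-place updates are what make graph replay see fresh data (generic PyTorch usage, not sglang code; requires a CUDA device):

import torch

# Static buffers: the graph is captured against these exact addresses.
static_seg_lens = torch.zeros(4, dtype=torch.int32, device="cuda")
static_seg_indptr = torch.zeros(5, dtype=torch.int32, device="cuda")

# Warm up on a side stream before capture, as PyTorch recommends.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    torch.cumsum(static_seg_lens, dim=0, out=static_seg_indptr[1:])
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Whatever runs here is replayed later against the same memory.
    torch.cumsum(static_seg_lens, dim=0, out=static_seg_indptr[1:])

# Replay sees the new lengths only because the buffer is updated in place;
# rebinding static_seg_lens to a new tensor would leave the graph reading stale memory.
static_seg_lens.copy_(torch.tensor([3, 1, 2, 5], dtype=torch.int32, device="cuda"))
graph.replay()
print(static_seg_indptr)  # expected: [0, 3, 4, 6, 11]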

python/sglang/srt/model_executor/cuda_graph_runner.py

Lines changed: 14 additions & 0 deletions
@@ -220,6 +220,9 @@ def __init__(self, model_runner: ModelRunner):
         if self.enable_torch_compile:
             set_torch_compile_config()
 
+        if self.model_runner.server_args.lora_paths is not None:
+            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
+
         # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
@@ -403,6 +406,13 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
         self.capture_hidden_mode = (
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )
+        if self.model_runner.server_args.lora_paths is not None:
+            # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
+            # different logic to handle lora, so we need to set `lora_paths` to a list of non-None
+            # values if lora is enabled.
+            lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
+        else:
+            lora_paths = None
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -424,8 +434,12 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
+            lora_paths=lora_paths,
         )
 
+        if lora_paths is not None:
+            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
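
Two details worth noting in the capture path: the persistent LoRA batch info is sized to self.max_bs when the runner is constructed, and the dummy batch used for capture assigns some real adapter name to every slot so that the LoRA (non-None) branch is the code path recorded into the graph. A tiny illustration of the placeholder selection (the adapter mapping below is an assumed example, not sglang's configuration):

# Assumed example of server_args.lora_paths: adapter name -> adapter path.
server_lora_paths = {"lora0": "algoprog/fact-generation-llama-3.1-8b-instruct-lora"}
bs = 4  # capture batch size

# Same trick as in the diff: next(iter(...)) picks an arbitrary configured adapter
# name and repeats it for every request in the dummy capture batch.
lora_paths = (
    [next(iter(server_lora_paths))] * bs if server_lora_paths is not None else None
)
print(lora_paths)  # ['lora0', 'lora0', 'lora0', 'lora0']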

python/sglang/srt/server_args.py

Lines changed: 0 additions & 1 deletion
@@ -1242,7 +1242,6 @@ def check_server_args(self):
         assert (
             self.max_loras_per_batch > 0
             # FIXME
-            and (self.lora_paths is None or self.disable_cuda_graph)
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"
         assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
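
After this removal, the only remaining LoRA restriction enforced by check_server_args is the radix cache; CUDA graph may stay enabled. A hedged, standalone sketch of the resulting check (MiniArgs and its message are stand-ins defined here for illustration, not the real ServerArgs class):

from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class MiniArgs:
    lora_paths: Optional[Dict[str, str]] = None
    disable_radix_cache: bool = False
    disable_cuda_graph: bool = False
    max_loras_per_batch: int = 8

def check(args: MiniArgs) -> None:
    # Mirrors the shape of the assertion after the change: no cuda-graph condition anymore.
    assert (
        args.max_loras_per_batch > 0
        and (args.lora_paths is None or args.disable_radix_cache)
    ), "LoRA currently still requires the radix cache to be disabled"

check(MiniArgs(lora_paths={"lora0": "/path/to/adapter"}, disable_radix_cache=True))  # passes
# check(MiniArgs(lora_paths={"lora0": "/path/to/adapter"}))  # would raise: radix cache still on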

test/srt/models/lora/test_lora_backend.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
     DEFAULT_PROMPTS,
     TORCH_DTYPES,
     LoRAModelCase,
-    run_batch_lora_test,
+    run_lora_test_one_by_one,
 )
 
 from sglang.test.test_utils import CustomTestCase, is_in_ci
@@ -42,7 +42,7 @@ def _run_backend_on_model_cases(self, model_cases: List[LoRAModelCase]):
             )
             for torch_dtype in TORCH_DTYPES:
                 for backend in BACKENDS:
-                    run_batch_lora_test(
+                    run_lora_test_one_by_one(
                         prompts,
                         model_case,
                         torch_dtype,
