
Commit 70817a7

Authored by Fridge003 and Ying1123
[Feature] Define backends and add Triton backend for Lora (#3161)
Co-authored-by: Ying Sheng <[email protected]>
1 parent 7b5a374 commit 70817a7

File tree: 18 files changed, +1129 −135 lines

benchmark/lora/launch_server.py

Lines changed: 10 additions & 4 deletions
@@ -1,10 +1,10 @@
 import argparse
 import os
 
-NUM_LORAS = 8
+NUM_LORAS = 4
 LORA_PATH = {
-    "base": "mistralai/Mistral-7B-Instruct-v0.3",
-    "lora": "/home/ying/test_lora",
+    "base": "meta-llama/Llama-2-7b-hf",
+    "lora": "winddude/wizardLM-LlaMA-LoRA-7B",
 }
 
 
@@ -21,7 +21,8 @@ def launch_server(args):
         cmd += f"{lora_name}={lora_path} "
     cmd += f"--disable-radix --disable-cuda-graph "
     cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
-    cmd += f"--max-running-requests {args.max_running_requests}"
+    cmd += f"--max-running-requests {args.max_running_requests} "
+    cmd += f"--lora-backend {args.lora_backend}"
     print(cmd)
     os.system(cmd)
 
@@ -42,6 +43,11 @@ def launch_server(args):
         type=int,
         default=8,
     )
+    parser.add_argument(
+        "--lora-backend",
+        type=str,
+        default="triton",
+    )
     args = parser.parse_args()
 
     launch_server(args)
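For context, a minimal sketch of invoking this launcher with the new option is shown below. It only mirrors the arguments defined above; running it from the benchmark/lora directory and the exact flag values are assumptions, not part of the diff.

```python
# Sketch: run the LoRA benchmark launcher with the Triton backend selected.
# Assumes the current working directory is benchmark/lora.
import subprocess

subprocess.run(
    [
        "python",
        "launch_server.py",
        "--max-loras-per-batch", "8",
        "--lora-backend", "triton",  # or "flashinfer"
    ],
    check=True,
)
```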

benchmark/lora/lora_bench.py

Lines changed: 5 additions & 0 deletions
@@ -183,6 +183,7 @@ async def benchmark(
         api_url=api_url,
         prompt_len=test_prompt_len,
         output_len=test_output_len,
+        lora_name="dummy",  # the lora_name argument will not be used
         extra_request_body=extra_request_body,
     )
     test_output = await request_func(request_func_input=test_input)
@@ -206,6 +207,7 @@ async def benchmark(
             api_url=api_url,
             prompt_len=prompt_len,
             output_len=output_len,
+            lora_name="dummy",
             extra_request_body=extra_request_body,
         )
         tasks.append(
@@ -255,6 +257,9 @@ async def benchmark(
             "Output token throughput (tok/s):", metrics.output_throughput
         )
     )
+    print(
+        "{:<40} {:<10.2f}".format("Total throughput (tok/s):", metrics.total_throughput)
+    )
     print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
     print(
         "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)

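The new print statement relies on a `total_throughput` field of the benchmark's metrics object. Conceptually this is the combined prompt-plus-completion token rate; the sketch below only illustrates that idea, with hypothetical variable names rather than the benchmark's actual internals.

```python
# Illustrative sketch: a combined token throughput metric.
# total_input_tokens, total_output_tokens, and duration_s are hypothetical names.
def total_throughput(total_input_tokens: int, total_output_tokens: int, duration_s: float) -> float:
    """Tokens processed per second across both prompts and completions."""
    return (total_input_tokens + total_output_tokens) / duration_s


print("{:<40} {:<10.2f}".format("Total throughput (tok/s):", total_throughput(1000, 2000, 3.0)))
```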
docs/backend/server_arguments.md

Lines changed: 1 addition & 0 deletions
@@ -124,6 +124,7 @@ Please consult the documentation below to learn more about the parameters you ma
 
 * `lora_paths`: You may provide a list of adapters to your model. Each batch element will get a model response with the corresponding lora adapter applied. Currently `cuda_graph` and `radix_attention` are not supported with this option, so you need to disable them manually. We are still working through these [issues](https://github.com/sgl-project/sglang/issues/2929).
 * `max_loras_per_batch`: Maximum number of LoRAs in a running batch, including the base model.
+* `lora_backend`: The backend for running GEMM kernels of LoRA modules; one of `triton` or `flashinfer`. Defaults to `triton`.
 
 ## Kernel backend
 
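As a usage illustration of the flags documented above, a server launch with a LoRA adapter and an explicit backend might be assembled as in the sketch below. The model and adapter names are placeholders, and the flag spellings mirror those assembled by benchmark/lora/launch_server.py; consult the server documentation for the authoritative names.

```python
# Sketch: assemble a launch command with LoRA flags (placeholder model/adapter names).
import os

cmd = (
    "python3 -m sglang.launch_server "
    "--model-path meta-llama/Llama-2-7b-hf "  # base model (placeholder)
    "--lora-paths lora0=winddude/wizardLM-LlaMA-LoRA-7B "  # adapter (placeholder)
    "--disable-radix --disable-cuda-graph "
    "--max-loras-per-batch 4 "
    "--lora-backend triton"  # or: flashinfer
)
print(cmd)
os.system(cmd)
```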
python/sglang/srt/lora/backend/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+from .base_backend import BaseLoraBackend
+from .flashinfer_backend import FlashInferLoraBackend
+from .triton_backend import TritonLoraBackend
+
+__all__ = [
+    "FlashInferLoraBackend",
+    "TritonLoraBackend",
+]
python/sglang/srt/lora/backend/base_backend.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+from typing import Tuple, Union
+
+import torch
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+def get_fuse_output_scaling_add_from_name(name: str) -> bool:
+    mapping = {
+        "triton": True,
+        "flashinfer": False,
+    }
+    return mapping.get(name, False)
+
+
+def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
+    mapping = {
+        "triton": True,
+        "flashinfer": False,
+    }
+    return mapping.get(name, False)
+
+
+class BaseLoraBackend:
+    """Base class for different Lora backends.
+    Each backend has its own implementation of Lora kernels.
+
+    Args:
+        name: name of backend
+        batch_info: information of current batch for use
+        fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
+            and the operation of scaling and adding will be fused into kernel
+    """
+
+    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+        self.name = name
+        self.batch_info = batch_info
+        self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
+        self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
+
+    def run_lora_a_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+        """Run segment Gemm of lora a modules with current backend.
+        The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
+                usually input_dim is much larger than r
+        Returns:
+            result with shape (s, r)
+        """
+        pass
+
+    def run_lora_b_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+        """Run segment Gemm of lora b modules with current backend.
+        The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
+
+        Args:
+            x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
+            weights: a set of lora weights with shape (num_lora, output_dim, r)
+                usually output_dim is much larger than r
+        Returns:
+            result with shape (s, output_dim)
+        """
+        pass
+
+    def run_qkv_lora(
+        self,
+        x: torch.Tensor,
+        qkv_lora_a: torch.Tensor,
+        qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+        """Run the lora pass for QKV Layer.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
+            qkv_lora_b: lora_b module for qkv.
+                If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r)
+                If passed in as a tuple of two tensors, it contains:
+                    a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
+                    and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
+        Returns:
+            result with shape (s, output_dim_q + 2 * output_dim_kv)
+        """
+        pass
+
+    def set_batch_info(self, batch_info: LoraBatchInfo):
+        self.batch_info = batch_info
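To make the interface concrete, here is a minimal, unoptimized reference backend written against the shapes documented in these docstrings. It is a sketch only and not part of the commit: it assumes `LoraBatchInfo` exposes `bs`, `seg_indptr`, and `weight_indices` (the fields the backends in this commit read), and it loops over segments in plain PyTorch instead of calling a fused kernel.

```python
import torch

from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo


class NaiveTorchLoraBackend(BaseLoraBackend):
    """Reference implementation: one plain matmul per segment (slow, for clarity only)."""

    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
        super().__init__(name, batch_info)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        # x: (s, input_dim), weights: (num_lora, r, input_dim) -> (s, r)
        return self._segment_matmul(x, weights)

    def run_lora_b_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        # x: (s, r), weights: (num_lora, output_dim, r) -> (s, output_dim)
        return self._segment_matmul(x, weights)

    def _segment_matmul(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        output = x.new_empty((x.shape[0], weights.shape[1]))
        for i in range(self.batch_info.bs):
            start = int(self.batch_info.seg_indptr[i])
            end = int(self.batch_info.seg_indptr[i + 1])
            lora_idx = int(self.batch_info.weight_indices[i])
            # Each segment (one request) is multiplied by its own adapter's weights.
            output[start:end] = x[start:end] @ weights[lora_idx].T
        return output
```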
python/sglang/srt/lora/backend/flashinfer_backend.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+from typing import Tuple
+
+import torch
+from flashinfer import SegmentGEMMWrapper
+
+from sglang.srt.lora.backend import BaseLoraBackend
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+class FlashInferLoraBackend(BaseLoraBackend):
+
+    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+        super().__init__(name, batch_info)
+
+        # Set up SGemm Wrapper from flashinfer
+        # FIXME wait for flashinfer segment gemm update
+        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
+        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
+
+    def run_lora_a_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+
+        return self.segment_gemm.run(
+            x=x,
+            weights=weights,
+            batch_size=self.batch_info.bs,
+            weight_column_major=True,
+            seg_indptr=self.batch_info.seg_indptr,
+            weight_indices=self.batch_info.weight_indices,
+        )
+
+    def run_lora_b_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+
+        return self.segment_gemm.run(
+            x=x,
+            weights=weights,
+            batch_size=self.batch_info.bs,
+            weight_column_major=True,
+            seg_indptr=self.batch_info.seg_indptr,
+            weight_indices=self.batch_info.weight_indices,
+        )
+
+    def run_qkv_lora(
+        self,
+        x: torch.Tensor,
+        qkv_lora_a: torch.Tensor,
+        qkv_lora_b: Tuple[torch.Tensor],
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        # Shape of lora_a_output: (s, 3 * r)
+        lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
+
+        q_lora_b, kv_lora_b = qkv_lora_b
+        lora_rank = kv_lora_b.shape[-1]
+        output_dim_q = q_lora_b.shape[-2]
+        output_dim_kv = kv_lora_b.shape[-2]
+        lora_output = torch.empty(
+            (x.shape[0], output_dim_q + 2 * output_dim_kv),
+            device=x.device,
+            dtype=x.dtype,
+        )
+
+        # q
+        lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
+        )
+
+        # kv
+        lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
+            self.run_lora_b_sgemm(
+                x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
+                weights=kv_lora_b[0],
+            )
+        )
+
+        lora_output[
+            :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
+        ] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
+            weights=kv_lora_b[1],
+        )
+
+        return lora_output
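For readers checking the tuple layout consumed by `run_qkv_lora` above, the sketch below spells out the expected shapes with made-up dimensions; the hidden sizes, rank, and token count are arbitrary assumptions, and no kernel is actually invoked.

```python
import torch

s = 16                      # total tokens across the batch (assumed)
input_dim = 4096            # model hidden size (assumed)
r = 8                       # LoRA rank (assumed)
num_lora = 4                # adapters resident in the batch (assumed)
output_dim_q, output_dim_kv = 4096, 1024  # q and k/v projection sizes (assumed)

x = torch.zeros(s, input_dim, dtype=torch.float16)
qkv_lora_a = torch.zeros(num_lora, 3 * r, input_dim, dtype=torch.float16)

# Tuple form of qkv_lora_b used by FlashInferLoraBackend:
q_lora_b = torch.zeros(1, num_lora, output_dim_q, r, dtype=torch.float16)    # q only
kv_lora_b = torch.zeros(2, num_lora, output_dim_kv, r, dtype=torch.float16)  # k and v stacked

# run_qkv_lora(x, qkv_lora_a, (q_lora_b, kv_lora_b)) returns a tensor of this shape:
expected_shape = (s, output_dim_q + 2 * output_dim_kv)
```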
python/sglang/srt/lora/backend/triton_backend.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+import torch
+
+from sglang.srt.lora.backend import BaseLoraBackend
+from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.triton_ops import (
+    qkv_lora_b_fwd,
+    sgemm_lora_a_fwd,
+    sgemm_lora_b_fwd,
+)
+
+
+class TritonLoraBackend(BaseLoraBackend):
+
+    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+        super().__init__(name, batch_info)
+
+    def run_lora_a_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+        return sgemm_lora_a_fwd(x, weights, self.batch_info)
+
+    def run_lora_b_sgemm(
+        self,
+        x: torch.Tensor,
+        weights: torch.Tensor,
+        base_output: torch.Tensor = None,
+        scaling: float = 1.0,
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)
+
+    def run_qkv_lora(
+        self,
+        x: torch.Tensor,
+        qkv_lora_a: torch.Tensor,
+        qkv_lora_b: torch.Tensor,
+        output_offset: torch.Tensor,
+        max_qkv_out_dim: int,
+        base_output: torch.Tensor = None,
+        scaling: float = 1.0,
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+
+        # x: (s, input_dim)
+        # qkv_lora_a: (num_lora, 3 * r, input_dim)
+        # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+        assert isinstance(qkv_lora_b, torch.Tensor)
+
+        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
+        lora_output = qkv_lora_b_fwd(
+            lora_a_output,
+            qkv_lora_b,
+            self.batch_info,
+            output_offset,
+            max_qkv_out_dim,
+            base_output,
+            scaling,
+        )
+        return lora_output
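This excerpt does not include the dispatch site that turns the `--lora-backend` string into one of these classes. A name-based selection consistent with the flag could look like the sketch below; the helper name `get_backend_from_name` is illustrative, not necessarily what sglang uses.

```python
from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend


def get_backend_from_name(name: str):
    """Illustrative sketch: map the --lora-backend value to a backend class."""
    backends = {
        "triton": TritonLoraBackend,
        "flashinfer": FlashInferLoraBackend,
    }
    if name not in backends:
        raise ValueError(f"Invalid LoRA backend: {name}")
    return backends[name]


# Example: backend = get_backend_from_name("triton")("triton")
```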
