Commit dbe41a7

Add torchao quant (int4/int8/fp8) to llama models (sgl-project#1341)
Co-authored-by: Lianmin Zheng <[email protected]>
1 parent dce4e26 commit dbe41a7

10 files changed: 151 additions, 12 deletions

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
        "packaging", "pillow", "psutil", "pydantic", "python-multipart",
-       "torch", "uvicorn", "uvloop", "zmq",
+       "torch", "torchao", "uvicorn", "uvloop", "zmq",
        "vllm==0.5.5", "outlines>=0.0.44"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
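The srt extra now pulls in torchao at install time. As a quick sanity check (a sketch, not part of the commit), the symbols the new utility module relies on can be imported directly:

import torchao  # the new dependency added to the srt extra
from torchao.quantization import (  # same symbols the new torchao_utils.py imports below
    int4_weight_only,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    quantize_,
)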
python/sglang/srt/layers/torchao_utils.py

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+"""
+Common utilities for torchao.
+"""
+
+import torch
+from torchao.quantization import (
+    int4_weight_only,
+    int8_dynamic_activation_int8_weight,
+    int8_weight_only,
+    quantize_,
+)
+
+
+def torchao_quantize_param_data(param, torchao_config):
+    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    dummy_linear.weight = param
+    if "int8wo" in torchao_config:
+        quantize_(dummy_linear, int8_weight_only())
+    elif "int8dq" in torchao_config:
+        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+    elif "int4wo" in torchao_config:
+        group_size = int(torchao_config.split("-")[-1])
+        assert group_size in [
+            32,
+            64,
+            128,
+            256,
+        ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
+        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+    elif "fp8wo" in torchao_config:
+        from torchao.quantization import float8_weight_only
+
+        # this requires newer hardware
+        # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
+        quantize_(dummy_linear, float8_weight_only())
+    return dummy_linear.weight
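A minimal usage sketch for the helper above, assuming torchao is installed and a CUDA device is available; the shape and config string here are illustrative, not taken from the commit:

import torch

from sglang.srt.layers.torchao_utils import torchao_quantize_param_data

# An illustrative 2-D projection weight, like the ones load_weights hands over.
weight = torch.nn.Parameter(
    torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda"),
    requires_grad=False,
)

# "int4wo-128" means int4 weight-only quantization with group size 128,
# the same config-string format the new --torchao-config flag accepts.
quantized = torchao_quantize_param_data(weight, "int4wo-128")
print(type(quantized))  # an nn.Parameter wrapping a torchao quantized tensor subclass

Note that the dispatch is substring-based ("int8wo" in torchao_config, etc.), so a non-empty but unrecognized config string falls through every branch and returns the weight unquantized.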

python/sglang/srt/model_executor/model_runner.py

Lines changed: 1 addition & 0 deletions

@@ -97,6 +97,7 @@ def __init__(
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
+                "torchao_config": server_args.torchao_config,
             }
         )
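This one-line change is the only plumbing the runner needs: the flag is published into global_server_args_dict, and model code reads it back at construction time, as the llama.py change below does:

from sglang.srt.managers.schedule_batch import global_server_args_dict

# The default is "" (falsy), which leaves quantization disabled.
torchao_config = global_server_args_dict["torchao_config"]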

python/sglang/srt/models/llama.py

Lines changed: 22 additions & 0 deletions
@@ -42,6 +42,8 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -299,6 +301,7 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

@@ -361,6 +364,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

+            if self.torchao_config:
+                if name.endswith("proj.weight") and param.ndim == 2:
+                    params_dict[name] = torchao_quantize_param_data(
+                        param, self.torchao_config
+                    )
+
+        if self.torchao_config:
+            # quantizing the loaded, stacked params, e.g. "...qkv_proj"
+            stacked_params = set(entry[0] for entry in stacked_params_mapping)
+            for param_suffix in stacked_params:
+                for name in params_dict:
+                    if param_suffix in name:
+                        param = params_dict[name]
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+
+        self.load_state_dict(params_dict, assign=True)


 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
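The first quantization pass runs inside the weight-loading loop and covers plain *_proj weights as they stream in; the second pass runs after the loop because stacked parameters such as qkv_proj are assembled from several checkpoint shards and are only complete once loading finishes. For reference, a sketch of the stacked_params_mapping shape the code relies on, in this file's (param_name, shard_name, shard_id) convention (the exact entries here are an assumption, not part of the diff):

# (param_name, shard_name, shard_id)
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

The final load_state_dict(params_dict, assign=True) swaps the quantized tensors into the module directly rather than copying them into the existing unquantized parameters, which would otherwise undo the dtype change.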

python/sglang/srt/server_args.py

Lines changed: 8 additions & 1 deletion
@@ -95,6 +95,7 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
+    torchao_config: str = ""
     enable_p2p_check: bool = False
     enable_mla: bool = False
     triton_attention_reduce_in_fp32: bool = False

@@ -443,7 +444,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
-            help="Optimize the model with torch.compile, experimental feature.",
+            help="Optimize the model with torch.compile. Experimental feature.",
+        )
+        parser.add_argument(
+            "--torchao-config",
+            type=str,
+            default=ServerArgs.torchao_config,
+            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
         )
         parser.add_argument(
             "--enable-p2p-check",

test/srt/test_eval_accuracy_mini.py

Lines changed: 2 additions & 2 deletions
@@ -29,12 +29,12 @@ def test_mmlu(self):
             base_url=self.base_url,
             model=self.model,
             eval_name="mmlu",
-            num_examples=32,
+            num_examples=64,
             num_threads=32,
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.6
+        assert metrics["score"] >= 0.65


 if __name__ == "__main__":

test/srt/test_moe_eval_accuracy_large.py

Lines changed: 3 additions & 3 deletions
@@ -42,7 +42,7 @@ def test_mmlu(self):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.62, f"{metrics}"
+        assert metrics["score"] >= 0.625, f"{metrics}"

     def test_human_eval(self):
         args = SimpleNamespace(

@@ -54,7 +54,7 @@ def test_human_eval(self):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.42, f"{metrics}"
+        assert metrics["score"] >= 0.425, f"{metrics}"

     def test_mgsm_en(self):
         args = SimpleNamespace(

@@ -66,7 +66,7 @@ def test_mgsm_en(self):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.62, f"{metrics}"
+        assert metrics["score"] >= 0.625, f"{metrics}"


 if __name__ == "__main__":

test/srt/test_torch_compile.py

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@ def setUpClass(cls):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-torch-compile", "--disable-radix-cache"],
+            other_args=["--enable-torch-compile"],
         )

     @classmethod

@@ -34,12 +34,12 @@ def test_mmlu(self):
             base_url=self.base_url,
             model=self.model,
             eval_name="mmlu",
-            num_examples=32,
+            num_examples=64,
             num_threads=32,
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.6
+        assert metrics["score"] >= 0.65

     def run_decode(self, max_new_tokens):
         response = requests.post(

test/srt/test_torchao.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestTorchAO(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--torchao-config", "int4wo-128"],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.65
+
+    def run_decode(self, max_new_tokens):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": max_new_tokens,
+                },
+                "ignore_eos": True,
+            },
+        )
+        return response.json()
+
+    def test_throughput(self):
+        import time
+
+        max_tokens = 256
+
+        tic = time.time()
+        res = self.run_decode(max_tokens)
+        tok = time.time()
+        print(res["text"])
+        throughput = max_tokens / (tok - tic)
+        print(f"Throughput: {throughput} tokens/s")
+        assert throughput >= 210
+
+
+if __name__ == "__main__":
+    unittest.main()
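Note that the throughput floor (210 tokens/s for a single 256-token decode) is tuned to the hardware this CI runs on; on other GPUs the MMLU assertion is the portable part of this test.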

test/srt/test_triton_attn_backend.py

Lines changed: 2 additions & 2 deletions
@@ -32,12 +32,12 @@ def test_mmlu(self):
             base_url=self.base_url,
             model=self.model,
             eval_name="mmlu",
-            num_examples=32,
+            num_examples=64,
             num_threads=32,
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.6
+        assert metrics["score"] >= 0.65


 if __name__ == "__main__":
