Commit 6903945

fzyzcjy authored and Layssy committed
Support tuning DeepEP configs (sgl-project#6742)
1 parent c9aab1c commit 6903945

File tree

2 files changed: +694 -0 lines changed
Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
# ADAPTED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/utils.py

import os
import sys
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist


def init_dist(local_rank: int, num_local_ranks: int, args):
    ip = args.master_addr
    port = args.master_port
    num_nodes = args.nnodes
    node_rank = args.node_rank
    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{ip}:{port}",
        world_size=num_nodes * num_local_ranks,
        rank=node_rank * num_local_ranks + local_rank,
    )
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.cuda.set_device(local_rank)

    return (
        dist.get_rank(),
        dist.get_world_size(),
        dist.new_group(list(range(num_local_ranks * num_nodes))),
    )
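
# --- Editor's note: illustrative launch sketch, not part of the committed file. ---
# init_dist() is meant to be called once per local rank; a typical single-node
# launch spawns one process per GPU. `args` is assumed to carry master_addr,
# master_port, nnodes and node_rank (e.g. parsed with argparse); the names
# below are hypothetical.
#
#   import torch.multiprocessing as mp
#
#   def _worker(local_rank, num_local_ranks, args):
#       rank, world_size, group = init_dist(local_rank, num_local_ranks, args)
#       # ... run the DeepEP tuning / benchmark body on this rank ...
#
#   if __name__ == "__main__":
#       num_local_ranks = 8
#       mp.spawn(_worker, args=(num_local_ranks, args), nprocs=num_local_ranks)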


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    # Normalized difference: sum((x - y)^2) / sum(x^2 + y^2) after a +1 shift,
    # computed in double precision; returns 0 when the tensors match exactly.
    x, y = x.double() + 1, y.double() + 1
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return (1 - sim).item()


def per_token_cast_to_fp8(x: torch.Tensor):
    # Quantize each 128-element group of every token to FP8 (E4M3, max 448.0),
    # returning the FP8 tensor and the per-group scales needed to cast back.
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
        m, n
    ), (x_amax / 448.0).view(m, -1)


def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
    x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
    x_scales = x_scales.view(x_fp8.size(0), -1, 1)
    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
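
# --- Editor's note: illustrative round-trip sketch, not part of the committed file. ---
# per_token_cast_to_fp8 / per_token_cast_back form a quantize/dequantize pair
# over 128-element groups; calc_diff gives a quick check of the reconstruction
# error. Assumes a CUDA device with float8_e4m3fn support; sizes are hypothetical.
#
#   x = torch.randn(16, 7168, dtype=torch.bfloat16, device="cuda")
#   x_fp8, x_scales = per_token_cast_to_fp8(x)
#   x_back = per_token_cast_back(x_fp8, x_scales)
#   assert calc_diff(x, x_back) < 1e-3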


def inplace_unique(x: torch.Tensor, num_slots: int):
    # Per-row in-place dedup: each row of x keeps its unique non-negative
    # entries (in descending order) and is padded with -1; values must be < num_slots.
    assert x.dim() == 2
    mask = x < 0
    x_padded = x.masked_fill(mask, num_slots)
    bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
    bin_count = bin_count[:, :num_slots]
    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
    x[:, :].fill_(-1)
    valid_len = min(num_slots, x.size(1))
    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]


def create_grouped_scores(
    scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
):
    num_tokens, num_experts = scores.shape
    scores = scores.view(num_tokens, num_groups, -1)
    mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
    mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
    return (scores * mask).view(num_tokens, num_experts)
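
# --- Editor's note: illustrative usage sketch, not part of the committed file. ---
# create_grouped_scores zeroes every expert score outside the selected groups,
# mimicking group-limited gating: pick the top-k groups per token, then route
# only among experts of those groups. Shapes below are hypothetical.
#
#   num_tokens, num_experts, num_groups, topk_groups = 128, 256, 8, 4
#   scores = torch.rand(num_tokens, num_experts, device="cuda")
#   group_scores = scores.view(num_tokens, num_groups, -1).amax(dim=-1)
#   group_idx = torch.topk(group_scores, k=topk_groups, dim=-1).indices
#   masked_scores = create_grouped_scores(scores, group_idx, num_groups)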


def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Flush L2
    cache.zero_()

    # Testing
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    for i in range(num_tests):
        # Record
        start_events[i].record()
        fn()
        end_events[i].record()
        if post_fn is not None:
            post_fn()
    torch.cuda.synchronize()

    times = np.array(
        [s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]
    )[1:]
    return np.average(times), np.min(times), np.max(times)
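
# --- Editor's note: illustrative usage sketch, not part of the committed file. ---
# bench() times a callable with CUDA events after warmup and an L2 flush and
# returns (avg, min, max) in seconds. Assumes a CUDA device is available; the
# tensors below are hypothetical.
#
#   a = torch.randn(4096, 4096, device="cuda")
#   b = torch.randn(4096, 4096, device="cuda")
#   avg_t, min_t, max_t = bench(lambda: a @ b)
#   print(f"matmul: avg {avg_t * 1e6:.1f} us, min {min_t * 1e6:.1f} us")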


class empty_suppress:
    def __enter__(self):
        return self

    def __exit__(self, *_):
        pass


class suppress_stdout_stderr:
    # Redirects stdout/stderr at the file-descriptor level so that output from
    # C++/CUDA libraries (not just Python prints) is silenced inside the block.
    def __enter__(self):
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()


def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: Optional[str] = None,
    barrier_comm_profiling: bool = False,
):
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
        ) as prof:
            for i in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    lhs @ rhs
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
                for _ in range(num_tests):
                    fn()
                prof.step()

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tupled = isinstance(kernel_names, tuple)
    prof_lines = (
        prof.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    for name in kernel_names:
        assert (
            sum([name in line for line in prof_lines]) == 1
        ), f"Errors of the kernel {name} in the profiling table"

    # Save chrome traces
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                for unit, scale in units.items():
                    if unit in time_str:
                        kernel_times.append(float(time_str.replace(unit, "")) / scale)
                        break
                break
    return tuple(kernel_times) if is_tupled else kernel_times[0]
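
# --- Editor's note: illustrative usage sketch, not part of the committed file. ---
# bench_kineto() profiles fn under torch.profiler (Kineto) and returns the
# average GPU time, in seconds, of the kernels whose names contain the given
# substring(s); each substring must match exactly one row of the profiler
# table. The kernel-name substring and tensors below are hypothetical.
#
#   a = torch.randn(4096, 4096, device="cuda")
#   b = torch.randn(4096, 4096, device="cuda")
#   t = bench_kineto(
#       lambda: a @ b,
#       kernel_names="gemm",
#       suppress_kineto_output=True,
#       trace_path="matmul_trace.json",
#   )
#   print(f"gemm kernel time: {t * 1e6:.1f} us")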


def hash_tensor(t: torch.Tensor):
    return t.view(torch.int64).sum().item()

0 commit comments