@@ -8,46 +8,22 @@
 
 import copy
 import logging
-import statistics
 import threading
 import time
-from typing import Callable, List, Optional, Tuple
+from typing import List, Tuple
 
 import torch
-from fbgemm_gpu.tbe.utils import b_indices, TBERequest  # noqa: F401
 
 logging.basicConfig(level=logging.DEBUG)
 
 
-def warmup(
-    request: TBERequest,
-    warmup_ms: int,
-    warmup_runs: int,
-    func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
-    bwd_only: bool = False,
-    grad: Optional[torch.Tensor] = None,
-) -> None:
-    indices, offsets, weights = request.unpack_3()
-    if warmup_ms:
-        start_time_ms = time.time() * 1000
-        while time.time() * 1000 - start_time_ms < warmup_ms:
-            out = func(indices, offsets, weights)
-            if bwd_only:
-                out.backward(grad)
-    else:
-        for _ in range(warmup_runs):
-            out = func(indices, offsets, weights)
-            if bwd_only:
-                out.backward(grad)
-
-
 def benchmark_torch_function(  # noqa: C901
     # pyre-fixme[2]: Parameter must be annotated.
     f,
     # pyre-fixme[2]: Parameter must be annotated.
     args,
     # pyre-fixme[2]: Parameter must be annotated.
-    kwargs={},
+    kwargs={},  # noqa: B006
     flush_gpu_cache_size_mb: int = 40,
     iters: int = 10,
     num_warmups: int = 2,
@@ -153,316 +129,3 @@ def forward(idx: int) -> None:
 
     # pyre-fixme[61]: `output` is undefined, or not always defined.
     return float(elapsed_time), output
-
-
-def benchmark_requests(
-    requests: List[TBERequest],
-    func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
-    flush_gpu_cache_size_mb: int = 0,
-    check_median: bool = False,
-    num_warmups: int = 0,
-    bwd_only: bool = False,
-    grad: Optional[torch.Tensor] = None,
-    # Used to label benchmark iterations differently in nsys profile result
-    # so that we can compare performance of two different models, for example.
-    # If an empty string is provided, it has no effect.
-    nvtx_range: str = "",
-    # Can be used to clear the model's stats after warmup, for example.
-    callback_after_warmup: Optional[Callable[[], None]] = None,
-    periodic_logs: bool = False,
-    warmup_ms: Optional[int] = None,
-    iters: int = -1,
-) -> float:
-    times = []
-    # Run at least one warmup iteration to avoid the long cudaLaunchKernel time
-    # for the first kernel if warmup_ms > 0.
-    # warmup_ms is prioritized over num_warmups.
-
-    if warmup_ms is None:
-        num_warmups = num_warmups + 1 if num_warmups >= 0 else 1
-
-    # Warm up the GPU before profiling
-    warmup(
-        requests[0],
-        # pyre-ignore[6]
-        warmup_ms,
-        num_warmups,
-        lambda indices, offsets, per_sample_weights: func(
-            indices,
-            offsets,
-            per_sample_weights,
-        ),
-        bwd_only=bwd_only,
-        grad=grad,
-    )
-
-    if callback_after_warmup is not None:
-        callback_after_warmup()
-
-    num_reqs = len(requests)
-    iters = num_reqs if iters == -1 else iters
-
-    if torch.cuda.is_available():
-        torch.cuda.synchronize()
-        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
-        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
-    else:
-        start_events = []
-        end_events = []
-
-    for it in range(iters):
-        req = requests[it % num_reqs]
-
-        indices, offsets, weights = req.unpack_3()
-        if bwd_only:
-            # Run forward before profiling if doing backward only
-            out = func(indices, offsets, weights)
-        start_time = time.time()
-        if torch.cuda.is_available():
-            if flush_gpu_cache_size_mb:
-                _ = torch.rand(
-                    flush_gpu_cache_size_mb * 1024 * 1024 // 4,
-                    dtype=torch.float,
-                    device="cuda",
-                )
-            start_events[it].record()
-
-        if nvtx_range:
-            torch.cuda.nvtx.range_push(f"{nvtx_range}-{it}")
-
-        if bwd_only:
-            out.backward(grad)
-        else:
-            func(indices, offsets, weights)
-
-        if nvtx_range:
-            torch.cuda.nvtx.range_pop()
-
-        if torch.cuda.is_available():
-            end_events[it].record()
-        else:
-            it_time = time.time() - start_time
-            times.append(it_time)
-
-    if torch.cuda.is_available():
-        torch.cuda.synchronize()
-        times = [
-            start.elapsed_time(end) * 1.0e-3
-            for start, end in zip(start_events, end_events)
-        ]
-
-    if periodic_logs:
-        for it in range(100, iters + 1, 100):
-            times_ = times[0:it]
-            avg_time = sum(times_) / len(times_) * 1.0e6
-            last_100_avg = sum(times_[-100:]) / 100 * 1.0e6
-            logging.info(
-                f"Iteration [{it}/{len(requests)}]: Last 100: {last_100_avg:.2f} us, Running avg: {avg_time:.2f} us"
-            )
-
-    avg_time = sum(times) / iters
-    median_time = statistics.median(times)
-    return median_time if check_median else avg_time
-
-
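
The helper removed above combined three pieces: a warmup pass (a wall-clock budget via warmup_ms takes priority over the num_warmups count), one pre-allocated torch.cuda.Event pair recorded around each launch, and a single torch.cuda.synchronize() at the end before reading elapsed times, so kernels are timed on-device rather than by the host clock. A minimal self-contained sketch of that pattern, assuming `ops` is a hypothetical list of zero-argument closures and `time_ops_cuda` is a made-up name (not an fbgemm_gpu API):

import time
from typing import Callable, List, Optional

import torch


def time_ops_cuda(
    ops: List[Callable[[], None]],
    warmup_ms: Optional[int] = None,
    num_warmups: int = 1,
) -> List[float]:
    # Warmup: the wall-clock budget takes priority over the fixed run count,
    # mirroring the removed helper's "warmup_ms is prioritized" rule.
    if warmup_ms is not None:
        deadline = time.time() * 1000 + warmup_ms
        while time.time() * 1000 < deadline:
            ops[0]()
    else:
        for _ in range(num_warmups):
            ops[0]()

    starts = [torch.cuda.Event(enable_timing=True) for _ in ops]
    ends = [torch.cuda.Event(enable_timing=True) for _ in ops]
    torch.cuda.synchronize()
    for op, start, end in zip(ops, starts, ends):
        start.record()  # enqueued on the current CUDA stream, not host-timed
        op()
        end.record()
    torch.cuda.synchronize()  # one sync covers every recorded event
    # Event.elapsed_time returns milliseconds; convert to seconds as the
    # removed code did.
    return [s.elapsed_time(e) * 1.0e-3 for s, e in zip(starts, ends)]

Reporting the median of these samples instead of the mean (the check_median switch above) is the usual guard against stray slow iterations.
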
-def benchmark_requests_refer(
-    requests: List[TBERequest],
-    T: int,
-    B: int,
-    L: int,
-    E: int,
-    D: int,
-    pooling_mode: str,
-    weighted: bool,
-    flush_gpu_cache_size_mb: int = 0,
-    check_median: bool = False,
-) -> float:
-    do_pooling = pooling_mode in ["sum", "mean"]
-
-    if do_pooling:
-        nn_embedding_list = [
-            torch.nn.EmbeddingBag(E, D, mode=pooling_mode, sparse=True).cuda()
-        ] * T
-    else:
-        nn_embedding_list = [torch.nn.Embedding(E, D, sparse=True).cuda()] * T
-
-    times = []
-    if torch.cuda.is_available():
-        torch.cuda.synchronize()
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
-    for req in requests:
-        indices, _, weights = req.unpack_3()
-        indices_list = indices.view(T, B, L).split(1)
-
-        if weighted:
-            assert weights is not None
-            weights_list = weights.view(T, B, L).split(1)
-
-        start_time = time.time()
-        if torch.cuda.is_available():
-            if flush_gpu_cache_size_mb:
-                _ = torch.rand(
-                    flush_gpu_cache_size_mb * 1024 * 1024 // 4,
-                    dtype=torch.float,
-                    device="cuda",
-                )
-            torch.cuda.synchronize()
-            start_event.record()
-
-        nn_embedding_output = (
-            [
-                b_indices(nn_embedding, x, use_cpu=False, do_pooling=do_pooling)
-                for (nn_embedding, x) in zip(nn_embedding_list, indices_list)
-            ]
-            if not weighted
-            else [
-                b_indices(
-                    nn_embedding,
-                    x,
-                    per_sample_weights=xw.view(-1),
-                    use_cpu=False,
-                    do_pooling=do_pooling,
-                )
-                for (nn_embedding, x, xw) in zip(
-                    nn_embedding_list,
-                    indices_list,
-                    # pyre-fixme[61]: `weights_list` is undefined, or not always
-                    #  defined.
-                    weights_list,
-                )
-            ]
-        )
-
-        if do_pooling:
-            final_output = torch.cat(
-                [f.view(B, -1) for f in nn_embedding_output], dim=1
-            )
-        else:
-            final_output = torch.cat(nn_embedding_output, dim=0).view(  # noqa: F841
-                -1, D
-            )
-
-        if torch.cuda.is_available():
-            end_event.record()
-            torch.cuda.synchronize()
-            # pyre-fixme[61]: `end_event` is undefined, or not always defined.
-            it_time = start_event.elapsed_time(end_event) * 1.0e-3
-            times.append(it_time)
-        else:
-            it_time = time.time() - start_time
-            times.append(it_time)
-    avg_time = sum(times) / len(requests)
-    median_time = statistics.median(times)
-    return median_time if check_median else avg_time
-
-
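
Two details of the reference path removed above are easy to miss: `[torch.nn.EmbeddingBag(...).cuda()] * T` repeats one module reference T times, so all T tables share the same weights (harmless for timing, wrong for correctness checks), and the cache flush allocates `flush_gpu_cache_size_mb * 1024 * 1024 // 4` float32 elements, i.e. exactly that many megabytes of fresh GPU memory. A standalone sketch of the flush trick; `flush_gpu_cache` is a hypothetical name:

import torch


def flush_gpu_cache(size_mb: int = 40) -> None:
    """Touch size_mb MB of fresh GPU memory to evict cached lines before timing."""
    if size_mb and torch.cuda.is_available():
        # size_mb * 2**20 bytes divided by 4 bytes per float32 element.
        _ = torch.rand(size_mb * 1024 * 1024 // 4, dtype=torch.float, device="cuda")
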
-def benchmark_pipelined_requests(
-    requests: List[TBERequest],
-    func1: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None],
-    func2: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None],
-    flush_gpu_cache_size_mb: int = 0,
-    check_median: bool = False,
-) -> Tuple[float, float]:
-    torch.cuda.synchronize()
-    start_events = [
-        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
-        for _ in requests
-    ]
-    end_events = [
-        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
-        for _ in requests
-    ]
-    for req, start_event, end_event in zip(requests, start_events, end_events):
-        indices, offsets, indices_weights = req.unpack_3()
-        if flush_gpu_cache_size_mb:
-            _ = torch.rand(
-                flush_gpu_cache_size_mb * 1024 * 1024 // 4,
-                dtype=torch.float,
-                device="cuda",
-            )
-            torch.cuda.synchronize()
-        start_event[0].record()
-        func1(indices, offsets, indices_weights)
-        end_event[0].record()
-        start_event[1].record()
-        func2(indices, offsets, indices_weights)
-        end_event[1].record()
-    torch.cuda.synchronize()
-    avg_time = (
-        sum(
-            start_event[0].elapsed_time(end_event[0]) * 1.0e-3
-            for start_event, end_event in zip(start_events, end_events)
-        )
-        / len(requests),
-        sum(
-            start_event[1].elapsed_time(end_event[1]) * 1.0e-3
-            for start_event, end_event in zip(start_events, end_events)
-        )
-        / len(requests),
-    )
-    median_time = (
-        statistics.median(
-            start_event[0].elapsed_time(end_event[0]) * 1.0e-3
-            for start_event, end_event in zip(start_events, end_events)
-        ),
-        statistics.median(
-            start_event[1].elapsed_time(end_event[1]) * 1.0e-3
-            for start_event, end_event in zip(start_events, end_events)
-        ),
-    )
-    return median_time if check_median else avg_time
-
-
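
The pipelined helper removed above records a separate event pair per stage, back to back on the same stream, so the two functions can be attributed individually within a single pass over the requests. A condensed sketch of that idea; `stage1`, `stage2`, and `time_two_stages` are hypothetical names, not fbgemm_gpu APIs:

import torch
from typing import Callable, Tuple


def time_two_stages(
    stage1: Callable[[], None], stage2: Callable[[], None]
) -> Tuple[float, float]:
    s1 = torch.cuda.Event(enable_timing=True)
    e1 = torch.cuda.Event(enable_timing=True)
    s2 = torch.cuda.Event(enable_timing=True)
    e2 = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    s1.record()
    stage1()
    e1.record()  # the stage-1 window closes exactly where stage 2 opens
    s2.record()
    stage2()
    e2.record()
    torch.cuda.synchronize()
    return s1.elapsed_time(e1) * 1.0e-3, s2.elapsed_time(e2) * 1.0e-3
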
-def benchmark_vbe(
-    requests: List[Tuple[torch.Tensor, torch.Tensor]],
-    func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
-) -> Tuple[float, float]:
-    """
-    A benchmark function that returns the median execution time in seconds of
-    the forward and backward passes of VBE kernels.
-
-    Args:
-        requests (List[Tuple[torch.Tensor, torch.Tensor]]):
-            A list of requests. Each request is a tuple
-            of indices and offsets.
-
-        func (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
-            A function that takes indices and offsets
-            and returns the output of the VBE kernel.
-
-    Returns:
-        Tuple[float, float]:
-            A tuple of the median execution times in seconds of the forward
-            and backward passes of the VBE kernels.
-    """
-
-    fwd_times = []
-    bwd_times = []
-
-    torch.cuda.synchronize()
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-
-    for indices, offsets in requests:
-        # forward
-        start_event.record()
-        out = func(indices, offsets)
-        end_event.record()
-        torch.cuda.synchronize()
-        it_time = start_event.elapsed_time(end_event) * 1.0e-3
-        fwd_times.append(it_time)
-
-        grad = torch.rand_like(out)
-        start_event.record()
-        # backward
-        out.backward(grad)
-        end_event.record()
-        torch.cuda.synchronize()
-        it_time = start_event.elapsed_time(end_event) * 1.0e-3
-        bwd_times.append(it_time)
-
-    fwd_time_sec = statistics.median(fwd_times)
-    bwd_time_sec = statistics.median(bwd_times)
-
-    return fwd_time_sec, bwd_time_sec
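
benchmark_vbe, removed above, split forward and backward timing: it synchronized after the timed forward, built a random gradient with the output's shape via torch.rand_like, timed out.backward(grad) separately, and reported the median of each list. A generic sketch of that split; `model_fn`, `inputs`, and `time_fwd_bwd` are hypothetical stand-ins, and the inputs (or model parameters) must require grad so backward has work to do:

import statistics
from typing import Callable, List, Sequence, Tuple

import torch


def time_fwd_bwd(
    model_fn: Callable[..., torch.Tensor],
    inputs: Sequence[torch.Tensor],
    iters: int = 10,
) -> Tuple[float, float]:
    fwd: List[float] = []
    bwd: List[float] = []
    # Reusing one event pair is fine: record() overwrites the previous mark,
    # just as the removed helper reused start_event/end_event.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(iters):
        start.record()
        out = model_fn(*inputs)  # each forward builds a fresh autograd graph
        end.record()
        torch.cuda.synchronize()
        fwd.append(start.elapsed_time(end) * 1.0e-3)

        grad = torch.rand_like(out)  # gradient with the output's shape/dtype
        start.record()
        out.backward(grad)
        end.record()
        torch.cuda.synchronize()
        bwd.append(start.elapsed_time(end) * 1.0e-3)
    # Median, as the removed helper used, resists stray slow iterations.
    return statistics.median(fwd), statistics.median(bwd)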