Skip to content

Commit 65cd9f0

Browse files
levythu authored and facebook-github-bot committed
GPU timing and basic reporting framework (rebase of D52716004) (pytorch#2314)
Summary: Implements the reporting framework for internal state per TBE for better visibility. Differential Revision: D53028585
1 parent ad70943 commit 65cd9f0

File tree

4 files changed

+68
-1
lines changed

4 files changed

+68
-1
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import abc
9+
10+
11+
class IEmbeddingOffloadingMetricsReporter(abc.ABC):
    """
    Abstract interface for reporting embedding offloading metrics.

    All the report_XXX functions should be lightweight and fail-safe.
    """

    @abc.abstractmethod
    def should_report(self, iteration_step: int) -> bool:
        """
        Return whether we should report metrics during this step.

        This function should be cheap, side-effect free and return immediately.

        Args:
            iteration_step: The iteration step being considered for reporting.

        Returns:
            True if metrics should be reported for `iteration_step`.
        """
        ...

    @abc.abstractmethod
    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
    ) -> None:
        """
        Report the duration of a timed event.

        Args:
            iteration_step: The iteration step the event belongs to.
            event_name: Name of the timed event.
            duration_ms: Duration of the event, in milliseconds.
            embedding_id: Optional identifier of the embedding the event
                relates to; defaults to the empty string.
            tbe_id: Optional identifier of the TBE the event relates to;
                defaults to the empty string.
        """
        ...

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torch import nn, Tensor # usort:skip
2121

2222
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
23+
from fbgemm_gpu.embedding_offloading_metrics import IEmbeddingOffloadingMetricsReporter
2324
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
2425
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
2526
BoundsCheckMode,
@@ -348,6 +349,7 @@ def __init__( # noqa C901
348349
# If a separate stream is used for prefetch, the optional forward_stream arg of prefetch function
349350
# should be set.
350351
prefetch_pipeline: bool = False,
352+
metrics_reporter: Optional[IEmbeddingOffloadingMetricsReporter] = None,
351353
) -> None:
352354
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
353355

@@ -441,6 +443,8 @@ def __init__( # noqa C901
441443
# 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
442444
# 4: N_conflict_unique_misses, 5: N_conflict_misses
443445

446+
self.metrics_reporter = metrics_reporter
447+
444448
self.int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET
445449

446450
self.feature_table_map: List[int] = (

fbgemm_gpu/test/tbe/cache/cache_common.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88
# pyre-ignore-all-errors[56]
99

10-
from typing import Tuple
10+
from typing import List, Optional, Tuple, Union
1111

1212
import numpy as np
1313
import torch
14+
from fbgemm_gpu.embedding_offloading_metrics import IEmbeddingOffloadingMetricsReporter
1415
from fbgemm_gpu.split_embedding_configs import SparseType
1516

1617
from fbgemm_gpu.split_embedding_utils import round_up
@@ -37,6 +38,27 @@
3738
VERBOSITY: Verbosity = Verbosity.verbose
3839

3940

41+
class TestingEmbeddingOffloadingMetricsReporter(IEmbeddingOffloadingMetricsReporter):
    """
    In-memory metrics reporter for tests.

    Every call to `report_duration` is appended to `reported_data` so a test
    can inspect what was reported.
    """

    # The class name starts with "Test" and is imported into test modules;
    # opt out of pytest collection so it is not mistaken for a test class.
    __test__ = False

    def __init__(self, reporting_interval: int = 1) -> None:
        """
        Args:
            reporting_interval: Report once every `reporting_interval` steps,
                starting at step 1. Must be a positive integer.

        Raises:
            ValueError: If `reporting_interval` is not positive.
        """
        # A zero interval would otherwise surface later as a
        # ZeroDivisionError inside should_report.
        if reporting_interval < 1:
            raise ValueError(
                f"reporting_interval must be >= 1, got {reporting_interval}"
            )
        # One [iteration_step, event_name, duration_ms, embedding_id, tbe_id]
        # record per report_duration call, in call order.
        self.reported_data: List[List[Union[int, str, float]]] = []
        self.reporting_interval = reporting_interval

    def should_report(self, iteration_step: int) -> bool:
        """
        Return True on steps 1, 1 + interval, 1 + 2 * interval, ...
        """
        return (iteration_step - 1) % self.reporting_interval == 0

    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
    ) -> None:
        """
        Record the reported event as one entry of `reported_data`.
        """
        self.reported_data.append(
            [iteration_step, event_name, duration_ms, embedding_id, tbe_id]
        )
60+
61+
4062
def generate_cache_tbes(
4163
T: int,
4264
D: int,
@@ -48,6 +70,7 @@ def generate_cache_tbes(
4870
cache_sets: int = 0,
4971
weights_cache_precision: SparseType = SparseType.FP32,
5072
stochastic_rounding: bool = False,
73+
reporter: Optional[TestingEmbeddingOffloadingMetricsReporter] = None,
5174
) -> Tuple[
5275
SplitTableBatchedEmbeddingBagsCodegen,
5376
SplitTableBatchedEmbeddingBagsCodegen,
@@ -103,6 +126,7 @@ def generate_cache_tbes(
103126
cache_sets=cache_sets,
104127
weights_precision=weights_cache_precision,
105128
cache_precision=weights_cache_precision,
129+
metrics_reporter=reporter,
106130
)
107131

108132
if use_int_weight:

fbgemm_gpu/test/tbe/cache/cache_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
generate_cache_tbes,
3939
gpu_unavailable,
4040
optests,
41+
TestingEmbeddingOffloadingMetricsReporter,
4142
VERBOSITY,
4243
)
4344

@@ -122,6 +123,7 @@ def _test_cache_prefetch_pipeline( # noqa C901
122123
"""
123124

124125
assert prefetch_location in ["before_fwd", "between_fwd_bwd"]
126+
reporter = TestingEmbeddingOffloadingMetricsReporter(reporting_interval=2)
125127
cc, cc_ref, min_Es, sum_Ds = generate_cache_tbes(
126128
T,
127129
D,
@@ -132,6 +134,7 @@ def _test_cache_prefetch_pipeline( # noqa C901
132134
use_int_weight=True,
133135
weights_cache_precision=weights_cache_precision,
134136
stochastic_rounding=stochastic_rounding,
137+
reporter=reporter,
135138
)
136139
iters = 5
137140
requests = generate_requests(iters, B, T, L, min_Es, reuse=0.1)

0 commit comments

Comments
 (0)