
Commit e747934

sryap authored and facebook-github-bot committed
Refactor generate_vbe_metadata (pytorch#3087)
Summary:
Pull Request resolved: pytorch#3087

Moves `generate_vbe_metadata` into `fbgemm_gpu.split_table_batched_embeddings_ops_training_common`. This is preparation for VBE enablement in SSD-TBE.

Reviewed By: q10

Differential Revision: D62215222
1 parent 3fce106 commit e747934
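
After this refactor, callers import the helper from the common module instead of the training module. Below is a minimal sketch of the new call site, assuming `fbgemm_gpu` is installed; the import paths and parameter names come from the diff, while the offsets, batch sizes, and feature dims are hypothetical values chosen only for illustration.

import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
    generate_vbe_metadata,
)

# Two features served by two ranks; batch sizes per (feature, rank) are made up.
vbe_metadata = generate_vbe_metadata(
    offsets=torch.zeros(11, dtype=torch.int64),        # 10 bags total -> 11 offsets
    batch_size_per_feature_per_rank=[[2, 3], [1, 4]],  # shape: feature x rank
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    pooling_mode=PoolingMode.SUM,
    feature_dims_cpu=torch.tensor([4, 8]),             # embedding dim per feature
    device=torch.device("cpu"),
)
print(vbe_metadata.B_offsets)  # tensor([ 0,  5, 10], dtype=torch.int32)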

2 files changed: +170, -136 lines

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 26 additions & 136 deletions
@@ -45,6 +45,10 @@
     RecordCacheMetrics,
     SplitState,
 )
+from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+    generate_vbe_metadata,
+    is_torchdynamo_compiling,
+)

 try:
     if torch.version.hip:
@@ -62,23 +66,6 @@
     pass


-try:
-    try:
-        from torch.compiler import is_compiling
-
-        def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
-            # at least one test fails if we import is_compiling as a different name
-            return is_compiling()
-
-    except Exception:
-        # torch.compiler.is_compiling is not available in torch 1.10
-        from torch._dynamo import is_compiling as is_torchdynamo_compiling
-except Exception:
-
-    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
-        return False
-
-
 DEFAULT_ASSOC = 32 if torch.version.hip is None else 64
 INT8_EMB_ROW_DIM_OFFSET = 8

@@ -334,125 +321,6 @@ def apply_split_helper(
     )


-def generate_vbe_metadata(
-    offsets: Tensor,
-    batch_size_per_feature_per_rank: Optional[List[List[int]]],
-    optimizer: OptimType,
-    pooling_mode: PoolingMode,
-    feature_dims_cpu: Tensor,
-    device: torch.device,
-) -> invokers.lookup_args.VBEMetadata:
-    """
-    Generate VBE metadata based on batch_size_per_feature_per_rank.
-    Metadata includes:
-        1) B_offsets - A tensor that contains batch size offsets for each
-           feature
-        2) output_offsets_feature_rank - A tensor that contains output
-           offsets for each feature
-        3) B_offsets_per_rank_per_feature - A tensor that contains batch
-           size offsets for each feature
-           and rank
-        4) max_B - The maximum batch size for all features
-        5) max_B_feature_rank - The maximum batch size for all ranks and
-           features
-        6) output_size - The output size (number of elements)
-    """
-    if batch_size_per_feature_per_rank is not None:
-        assert optimizer in (
-            OptimType.EXACT_ROWWISE_ADAGRAD,
-            OptimType.EXACT_SGD,
-            OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
-            OptimType.NONE,
-        ), "Variable batch size TBE support is enabled for OptimType.EXACT_ROWWISE_ADAGRAD and ENSEMBLE_ROWWISE_ADAGRAD only"
-        assert (
-            pooling_mode != PoolingMode.NONE
-        ), "Variable batch size TBE support is not enabled for PoolingMode.NONE"
-        # TODO: Add input check
-        zero_tensor = torch.zeros(1, device="cpu", dtype=torch.int32)
-
-        # Create B offsets
-        total_batch_size_per_feature = torch.tensor(
-            batch_size_per_feature_per_rank, dtype=torch.int32, device="cpu"
-        ).sum(dim=1)
-
-        max_B = total_batch_size_per_feature.max().item()
-        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
-            torch._check_is_size(max_B)
-            torch._check(max_B < offsets.numel())
-
-        Bs = torch.concat([zero_tensor, total_batch_size_per_feature])
-        B_offsets = Bs.cumsum(dim=0).to(torch.int)
-
-        # Create output offsets
-        B_feature_rank = torch.tensor(
-            batch_size_per_feature_per_rank,
-            device="cpu",
-            dtype=torch.int64,
-        )
-        max_B_feature_rank = B_feature_rank.max().item()
-        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
-            torch._check_is_size(max_B_feature_rank)
-            torch._check(max_B_feature_rank <= offsets.size(0))
-        output_sizes_feature_rank = B_feature_rank.transpose(
-            0, 1
-        ) * feature_dims_cpu.view(1, -1)
-        output_offsets_feature_rank = torch.concat(
-            [
-                zero_tensor.to(torch.int64),
-                output_sizes_feature_rank.flatten().cumsum(dim=0),
-            ]
-        )
-        output_size = output_offsets_feature_rank[-1].item()
-        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
-            torch._check_is_size(output_size)
-
-        # TODO: Support INT8 output
-        # B_offsets_rank_per_feature is for rank and (b, t) mapping
-        B_offsets_rank_per_feature = (
-            torch.tensor(
-                [
-                    [0] + batch_size_per_feature
-                    for batch_size_per_feature in batch_size_per_feature_per_rank
-                ],
-                device="cpu",
-                dtype=torch.int32,
-            )
-            .cumsum(dim=1)
-            .to(torch.int)
-        )
-
-        B_offsets = B_offsets.to(device, non_blocking=True)
-        output_offsets_feature_rank = output_offsets_feature_rank.to(
-            device, non_blocking=True
-        )
-        B_offsets_rank_per_feature = B_offsets_rank_per_feature.to(
-            device, non_blocking=True
-        )
-
-        # TODO: Use int32 for B_offsets and int64 for output_offsets_feature_rank
-        vbe_metadata = invokers.lookup_args.VBEMetadata(
-            B_offsets=B_offsets,
-            output_offsets_feature_rank=output_offsets_feature_rank,
-            B_offsets_rank_per_feature=B_offsets_rank_per_feature,
-            # pyre-ignore
-            max_B=max_B,
-            # pyre-ignore
-            max_B_feature_rank=max_B_feature_rank,
-            # pyre-ignore
-            output_size=output_size,
-        )
-    else:
-        vbe_metadata = invokers.lookup_args.VBEMetadata(
-            B_offsets=None,
-            output_offsets_feature_rank=None,
-            B_offsets_rank_per_feature=None,
-            max_B=-1,
-            max_B_feature_rank=-1,
-            output_size=-1,
-        )
-    return vbe_metadata
-
-
 # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
 # pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
 class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
@@ -1379,6 +1247,17 @@ def _generate_vbe_metadata(
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
+        if batch_size_per_feature_per_rank is not None:
+            assert self.optimizer in (
+                OptimType.EXACT_ROWWISE_ADAGRAD,
+                OptimType.EXACT_SGD,
+                OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+                OptimType.NONE,
+            ), (
+                "Variable batch size TBE support is enabled for "
+                "OptimType.EXACT_ROWWISE_ADAGRAD and "
+                "ENSEMBLE_ROWWISE_ADAGRAD only"
+            )
         return generate_vbe_metadata(
             offsets,
             batch_size_per_feature_per_rank,
@@ -3043,6 +2922,17 @@ def _generate_vbe_metadata(
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
+        if batch_size_per_feature_per_rank is not None:
+            assert self.optimizer in (
+                OptimType.EXACT_ROWWISE_ADAGRAD,
+                OptimType.EXACT_SGD,
+                OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+                OptimType.NONE,
+            ), (
+                "Variable batch size TBE support is enabled for "
+                "OptimType.EXACT_ROWWISE_ADAGRAD and "
+                "ENSEMBLE_ROWWISE_ADAGRAD only"
+            )
         return generate_vbe_metadata(
             offsets,
             batch_size_per_feature_per_rank,
fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional
+
+import torch
+from torch import Tensor
+
+import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
+from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
+
+try:
+    try:
+        from torch.compiler import is_compiling
+
+        def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
+            # at least one test fails if we import is_compiling as a different name
+            return is_compiling()
+
+    except Exception:
+        # torch.compiler.is_compiling is not available in torch 1.10
+        from torch._dynamo import is_compiling as is_torchdynamo_compiling
+except Exception:
+
+    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
+        return False
+
+
+def generate_vbe_metadata(
+    offsets: Tensor,
+    batch_size_per_feature_per_rank: Optional[List[List[int]]],
+    optimizer: OptimType,
+    pooling_mode: PoolingMode,
+    feature_dims_cpu: Tensor,
+    device: torch.device,
+) -> invokers.lookup_args.VBEMetadata:
+    """
+    Generate VBE metadata based on batch_size_per_feature_per_rank.
+    Metadata includes:
+        1) B_offsets - A tensor that contains batch size offsets for each
+           feature
+        2) output_offsets_feature_rank - A tensor that contains output
+           offsets for each feature
+        3) B_offsets_per_rank_per_feature - A tensor that contains batch
+           size offsets for each feature
+           and rank
+        4) max_B - The maximum batch size for all features
+        5) max_B_feature_rank - The maximum batch size for all ranks and
+           features
+        6) output_size - The output size (number of elements)
+    """
+    if batch_size_per_feature_per_rank is not None:
+        assert (
+            pooling_mode != PoolingMode.NONE
+        ), "Variable batch size TBE support is not enabled for PoolingMode.NONE"
+        # TODO: Add input check
+        zero_tensor = torch.zeros(1, device="cpu", dtype=torch.int32)
+
+        # Create B offsets
+        total_batch_size_per_feature = torch.tensor(
+            batch_size_per_feature_per_rank, dtype=torch.int32, device="cpu"
+        ).sum(dim=1)
+
+        max_B = total_batch_size_per_feature.max().item()
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            torch._check_is_size(max_B)
+            torch._check(max_B < offsets.numel())
+
+        Bs = torch.concat([zero_tensor, total_batch_size_per_feature])
+        B_offsets = Bs.cumsum(dim=0).to(torch.int)
+
+        # Create output offsets
+        B_feature_rank = torch.tensor(
+            batch_size_per_feature_per_rank,
+            device="cpu",
+            dtype=torch.int64,
+        )
+        max_B_feature_rank = B_feature_rank.max().item()
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            torch._check_is_size(max_B_feature_rank)
+            torch._check(max_B_feature_rank <= offsets.size(0))
+        output_sizes_feature_rank = B_feature_rank.transpose(
+            0, 1
+        ) * feature_dims_cpu.view(1, -1)
+        output_offsets_feature_rank = torch.concat(
+            [
+                zero_tensor.to(torch.int64),
+                output_sizes_feature_rank.flatten().cumsum(dim=0),
+            ]
+        )
+        output_size = output_offsets_feature_rank[-1].item()
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            torch._check_is_size(output_size)
+
+        # TODO: Support INT8 output
+        # B_offsets_rank_per_feature is for rank and (b, t) mapping
+        B_offsets_rank_per_feature = (
+            torch.tensor(
+                [
+                    [0] + batch_size_per_feature
+                    for batch_size_per_feature in batch_size_per_feature_per_rank
+                ],
+                device="cpu",
+                dtype=torch.int32,
+            )
+            .cumsum(dim=1)
+            .to(torch.int)
+        )
+
+        B_offsets = B_offsets.to(device, non_blocking=True)
+        output_offsets_feature_rank = output_offsets_feature_rank.to(
+            device, non_blocking=True
+        )
+        B_offsets_rank_per_feature = B_offsets_rank_per_feature.to(
+            device, non_blocking=True
+        )
+
+        # TODO: Use int32 for B_offsets and int64 for output_offsets_feature_rank
+        vbe_metadata = invokers.lookup_args.VBEMetadata(
+            B_offsets=B_offsets,
+            output_offsets_feature_rank=output_offsets_feature_rank,
+            B_offsets_rank_per_feature=B_offsets_rank_per_feature,
+            # pyre-ignore
+            max_B=max_B,
+            # pyre-ignore
+            max_B_feature_rank=max_B_feature_rank,
+            # pyre-ignore
+            output_size=output_size,
+        )
+    else:
+        vbe_metadata = invokers.lookup_args.VBEMetadata(
+            B_offsets=None,
+            output_offsets_feature_rank=None,
+            B_offsets_rank_per_feature=None,
+            max_B=-1,
+            max_B_feature_rank=-1,
+            output_size=-1,
+        )
+    return vbe_metadata
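
For intuition about the metadata described in the docstring above, here is a small worked example in plain PyTorch that mirrors the B_offsets and output_offsets computation. It does not call fbgemm_gpu, and the batch sizes and feature dims are invented purely for illustration.

import torch

# Hypothetical inputs: 2 features x 2 ranks, embedding dims 4 and 8.
batch_size_per_feature_per_rank = [[2, 3], [1, 4]]  # feature x rank
feature_dims_cpu = torch.tensor([4, 8], dtype=torch.int64)

zero = torch.zeros(1, dtype=torch.int32)

# B_offsets: cumulative total batch size per feature.
total_B = torch.tensor(batch_size_per_feature_per_rank, dtype=torch.int32).sum(dim=1)
B_offsets = torch.concat([zero, total_B]).cumsum(dim=0).to(torch.int)
# -> tensor([ 0,  5, 10], dtype=torch.int32)

# output_offsets_feature_rank: cumulative output sizes in (rank, feature) order,
# where each output size is batch size * embedding dim.
B_feature_rank = torch.tensor(batch_size_per_feature_per_rank, dtype=torch.int64)
output_sizes = B_feature_rank.transpose(0, 1) * feature_dims_cpu.view(1, -1)
output_offsets = torch.concat(
    [zero.to(torch.int64), output_sizes.flatten().cumsum(dim=0)]
)
# -> tensor([ 0,  8, 16, 28, 60]); the last entry (60) is output_size.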
