|
45 | 45 | RecordCacheMetrics,
|
46 | 46 | SplitState,
|
47 | 47 | )
|
| 48 | +from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import ( |
| 49 | + generate_vbe_metadata, |
| 50 | + is_torchdynamo_compiling, |
| 51 | +) |
48 | 52 |
|
49 | 53 | try:
|
50 | 54 | if torch.version.hip:
|
|
62 | 66 | pass
|
63 | 67 |
|
64 | 68 |
|
65 |
# Bind `is_torchdynamo_compiling` to whichever implementation can be imported.
# Fallback order:
#   1. a thin wrapper around `torch.compiler.is_compiling`
#   2. `torch._dynamo.is_compiling` itself, imported under the new name
#   3. a stub that always returns False, used when both imports above raise
try:
    try:
        from torch.compiler import is_compiling

        def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
            # at least one test fails if we import is_compiling as a different name
            return is_compiling()

    except Exception:
        # torch.compiler.is_compiling is not available in torch 1.10
        from torch._dynamo import is_compiling as is_torchdynamo_compiling
except Exception:
    # The outer handler also catches a failure of the `torch._dynamo` import
    # performed inside the inner `except` above.

    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
        return False
80 |
| - |
81 |
| - |
# Default `ASSOC` value: 32 when `torch.version.hip` is None, otherwise 64.
# NOTE(review): presumably the cache set associativity - confirm against users.
DEFAULT_ASSOC = 32 if torch.version.hip is None else 64
# NOTE(review): presumably the per-row offset reserved in INT8 embedding rows
# (e.g. for quantization parameters) - confirm against its users.
INT8_EMB_ROW_DIM_OFFSET = 8
|
84 | 71 |
|
@@ -334,125 +321,6 @@ def apply_split_helper(
|
334 | 321 | )
|
335 | 322 |
|
336 | 323 |
|
337 |
def generate_vbe_metadata(
    offsets: Tensor,
    batch_size_per_feature_per_rank: Optional[List[List[int]]],
    optimizer: OptimType,
    pooling_mode: PoolingMode,
    feature_dims_cpu: Tensor,
    device: torch.device,
) -> invokers.lookup_args.VBEMetadata:
    """
    Build the variable-batch-size (VBE) metadata for a lookup.

    When `batch_size_per_feature_per_rank` is None a placeholder is
    returned: every tensor field is None and every size field is -1.

    Otherwise the returned metadata carries:
      1) B_offsets - int32 prefix sums (with a leading 0) of each row of
         `batch_size_per_feature_per_rank` summed over its inner list
      2) output_offsets_feature_rank - prefix sums (with a leading 0) of
         batch size * feature dim, taken over the transposed batch-size
         table after flattening
      3) B_offsets_rank_per_feature - int32 prefix sums (with a leading 0)
         along each inner list of `batch_size_per_feature_per_rank`
      4) max_B - the largest per-row total batch size
      5) max_B_feature_rank - the largest single entry of the table
      6) output_size - the last entry of output_offsets_feature_rank

    Every intermediate tensor is created on CPU; the three tensor fields
    are then copied to `device` with `non_blocking=True`.

    Args:
        offsets: only inspected (numel / size(0)) for the bounds hints
            issued while torchdynamo is compiling.
        batch_size_per_feature_per_rank: nested batch-size list, or None
            to request the placeholder metadata.
        optimizer: must be one of the optimizers accepted by the assertion
            below whenever batch sizes are supplied.
        pooling_mode: must not be PoolingMode.NONE whenever batch sizes
            are supplied.
        feature_dims_cpu: viewed as a (1, -1) row and multiplied into the
            transposed batch-size table.
        device: destination device for the tensor fields.

    Raises:
        AssertionError: if batch sizes are supplied together with an
            unsupported optimizer or with PoolingMode.NONE.
    """
    if batch_size_per_feature_per_rank is None:
        # No variable batch sizes requested: emit the sentinel metadata.
        vbe_metadata = invokers.lookup_args.VBEMetadata(
            B_offsets=None,
            output_offsets_feature_rank=None,
            B_offsets_rank_per_feature=None,
            max_B=-1,
            max_B_feature_rank=-1,
            output_size=-1,
        )
    else:
        assert optimizer in (
            OptimType.EXACT_ROWWISE_ADAGRAD,
            OptimType.EXACT_SGD,
            OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
            OptimType.NONE,
        ), (
            "Variable batch size TBE support is enabled for "
            "OptimType.EXACT_ROWWISE_ADAGRAD and "
            "ENSEMBLE_ROWWISE_ADAGRAD only"
        )
        assert (
            pooling_mode != PoolingMode.NONE
        ), "Variable batch size TBE support is not enabled for PoolingMode.NONE"
        # TODO: Add input check
        zero = torch.zeros(1, device="cpu", dtype=torch.int32)

        # --- B offsets -------------------------------------------------
        per_feature_totals = torch.tensor(
            batch_size_per_feature_per_rank, dtype=torch.int32, device="cpu"
        ).sum(dim=1)

        max_B = per_feature_totals.max().item()
        # The scripting/compiling guard is evaluated in place at each site.
        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
            torch._check_is_size(max_B)
            torch._check(max_B < offsets.numel())

        B_offsets = (
            torch.concat([zero, per_feature_totals]).cumsum(dim=0).to(torch.int)
        )

        # --- output offsets --------------------------------------------
        batch_sizes = torch.tensor(
            batch_size_per_feature_per_rank,
            device="cpu",
            dtype=torch.int64,
        )
        max_B_feature_rank = batch_sizes.max().item()
        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
            torch._check_is_size(max_B_feature_rank)
            torch._check(max_B_feature_rank <= offsets.size(0))

        # Scale every batch size by its feature dimension, then flatten the
        # transposed table and accumulate to obtain running output offsets.
        scaled_sizes = batch_sizes.transpose(0, 1) * feature_dims_cpu.view(1, -1)
        running_output = scaled_sizes.flatten().cumsum(dim=0)
        output_offsets_feature_rank = torch.concat(
            [zero.to(torch.int64), running_output]
        )
        output_size = output_offsets_feature_rank[-1].item()
        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
            torch._check_is_size(output_size)

        # TODO: Support INT8 output
        # B_offsets_rank_per_feature is for rank and (b, t) mapping
        zero_padded_rows = torch.tensor(
            [[0] + row for row in batch_size_per_feature_per_rank],
            device="cpu",
            dtype=torch.int32,
        )
        B_offsets_rank_per_feature = zero_padded_rows.cumsum(dim=1).to(torch.int)

        # TODO: Use int32 for B_offsets and int64 for output_offsets_feature_rank
        vbe_metadata = invokers.lookup_args.VBEMetadata(
            B_offsets=B_offsets.to(device, non_blocking=True),
            output_offsets_feature_rank=output_offsets_feature_rank.to(
                device, non_blocking=True
            ),
            B_offsets_rank_per_feature=B_offsets_rank_per_feature.to(
                device, non_blocking=True
            ),
            # pyre-ignore
            max_B=max_B,
            # pyre-ignore
            max_B_feature_rank=max_B_feature_rank,
            # pyre-ignore
            output_size=output_size,
        )
    return vbe_metadata
454 |
| - |
455 |
| - |
456 | 324 | # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
|
457 | 325 | # pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
|
458 | 326 | class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
@@ -1379,6 +1247,17 @@ def _generate_vbe_metadata(
|
1379 | 1247 | ) -> invokers.lookup_args.VBEMetadata:
|
1380 | 1248 | # Blocking D2H copy, but only runs at first call
|
1381 | 1249 | self.feature_dims = self.feature_dims.cpu()
|
| 1250 | + if batch_size_per_feature_per_rank is not None: |
| 1251 | + assert self.optimizer in ( |
| 1252 | + OptimType.EXACT_ROWWISE_ADAGRAD, |
| 1253 | + OptimType.EXACT_SGD, |
| 1254 | + OptimType.ENSEMBLE_ROWWISE_ADAGRAD, |
| 1255 | + OptimType.NONE, |
| 1256 | + ), ( |
| 1257 | + "Variable batch size TBE support is enabled for " |
| 1258 | + "OptimType.EXACT_ROWWISE_ADAGRAD and " |
| 1259 | + "ENSEMBLE_ROWWISE_ADAGRAD only" |
| 1260 | + ) |
1382 | 1261 | return generate_vbe_metadata(
|
1383 | 1262 | offsets,
|
1384 | 1263 | batch_size_per_feature_per_rank,
|
@@ -3043,6 +2922,17 @@ def _generate_vbe_metadata(
|
3043 | 2922 | ) -> invokers.lookup_args.VBEMetadata:
|
3044 | 2923 | # Blocking D2H copy, but only runs at first call
|
3045 | 2924 | self.feature_dims = self.feature_dims.cpu()
|
| 2925 | + if batch_size_per_feature_per_rank is not None: |
| 2926 | + assert self.optimizer in ( |
| 2927 | + OptimType.EXACT_ROWWISE_ADAGRAD, |
| 2928 | + OptimType.EXACT_SGD, |
| 2929 | + OptimType.ENSEMBLE_ROWWISE_ADAGRAD, |
| 2930 | + OptimType.NONE, |
| 2931 | + ), ( |
| 2932 | + "Variable batch size TBE support is enabled for " |
| 2933 | + "OptimType.EXACT_ROWWISE_ADAGRAD and " |
| 2934 | + "ENSEMBLE_ROWWISE_ADAGRAD only" |
| 2935 | + ) |
3046 | 2936 | return generate_vbe_metadata(
|
3047 | 2937 | offsets,
|
3048 | 2938 | batch_size_per_feature_per_rank,
|
|
0 commit comments