@@ -176,31 +176,15 @@ def execute_forward_(
                     else managed[t]
                 )
         else:
-            if use_cpu:
-                managed = [split_table_batched_embeddings_ops.EmbeddingLocation.HOST] * T
-                compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
-            elif use_cache:
-                managed = [
-                    split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
-                ] * T
-            if mixed:
-                average_D = sum(Ds) // T
-                for t, d in enumerate(Ds):
-                    managed[t] = (
-                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
-                        if d < average_D
-                        else managed[t]
-                    )
-            else:
-                managed = [
-                    np.random.choice(
-                        [
-                            split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
-                            split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
-                        ]
-                    )
-                    for _ in range(T)
-                ]
+            managed = [
+                np.random.choice(
+                    [
+                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
+                        split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
+                    ]
+                )
+                for _ in range(T)
+            ]
         if do_pooling:
             bs = [
                 to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
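Reviewer note: this hunk in execute_forward_ is pure dead-code removal. Inside the outer `else:` branch, `use_cpu` and `use_cache` are both already false, so the nested duplicate of the `if use_cpu:` / `elif use_cache:` / `if mixed:` chain could never select anything useful; the replacement keeps only the live random DEVICE/MANAGED scatter. A minimal standalone sketch of the consolidated placement logic, using a stand-in enum and a hypothetical `choose_locations` helper so it runs without fbgemm_gpu (the real enum is `split_table_batched_embeddings_ops.EmbeddingLocation`):

import enum
from typing import List

import numpy as np


class EmbeddingLocation(enum.IntEnum):
    # Stand-in for split_table_batched_embeddings_ops.EmbeddingLocation;
    # the numeric values here are for the sketch only.
    DEVICE = 0
    MANAGED = 1
    MANAGED_CACHING = 2
    HOST = 3


def choose_locations(
    T: int, Ds: List[int], use_cpu: bool, use_cache: bool, mixed: bool
) -> List[EmbeddingLocation]:
    # One if/elif/else chain, no dead nested copy.
    if use_cpu:
        return [EmbeddingLocation.HOST] * T
    if use_cache:
        managed = [EmbeddingLocation.MANAGED_CACHING] * T
        if mixed:
            # Tables narrower than the average dim stay in device HBM;
            # only the wider ones go through the cached-UVM path.
            average_D = sum(Ds) // T
            for t, d in enumerate(Ds):
                if d < average_D:
                    managed[t] = EmbeddingLocation.DEVICE
        return managed
    # Default: scatter tables randomly across device HBM and managed UVM.
    return [
        EmbeddingLocation(
            np.random.choice([EmbeddingLocation.DEVICE, EmbeddingLocation.MANAGED])
        )
        for _ in range(T)
    ]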
@@ -1368,34 +1352,26 @@ def execute_backward_sgd_(  # noqa C901
             managed = [split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE] * T
         elif use_cache:
             managed = [
-                split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
+                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
             ] * T
-        else:
-            if use_cpu:
-                managed = [split_table_batched_embeddings_ops.EmbeddingLocation.HOST] * T
-                compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
-            elif use_cache:
-                managed = [
-                    split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
-                ] * T
-            if mixed:
-                average_D = sum(Ds) // T
-                for t, d in enumerate(Ds):
-                    managed[t] = (
-                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
-                        if d < average_D
-                        else managed[t]
-                    )
-            else:
-                managed = [
-                    np.random.choice(
-                        [
-                            split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
-                            split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
-                        ]
+            if mixed:
+                average_D = sum(Ds) // T
+                for t, d in enumerate(Ds):
+                    managed[t] = (
+                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
+                        if d < average_D
+                        else managed[t]
                     )
-                    for _ in range(T)
-                ]
+        else:
+            managed = [
+                np.random.choice(
+                    [
+                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
+                        split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
+                    ]
+                )
+                for _ in range(T)
+            ]
         if do_pooling:
             bs = [
                 to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
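The substantive change in this hunk is the first removed/added pair: with `use_cache=True`, the SGD backward test used to pin every table to `EmbeddingLocation.DEVICE`, so the cached-UVM path was never exercised; the dead nested block (same shape as in the forward hunk) is dropped and the `if mixed:` / `else:` handling is hoisted into the live chain. Checking the intended behaviour against the hypothetical `choose_locations` sketch above:

# use_cache with mixed=False: every table should now land in
# MANAGED_CACHING (UVM fronted by a cache), not plain DEVICE HBM.
locs = choose_locations(T=4, Ds=[16, 32, 64, 128], use_cpu=False, use_cache=True, mixed=False)
assert locs == [EmbeddingLocation.MANAGED_CACHING] * 4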
@@ -1727,34 +1703,26 @@ def execute_backward_adagrad_(  # noqa C901
             managed = [split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE] * T
         elif use_cache:
             managed = [
-                split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
+                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
             ] * T
-        else:
-            if use_cpu:
-                managed = [split_table_batched_embeddings_ops.EmbeddingLocation.HOST] * T
-                compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
-            elif use_cache:
-                managed = [
-                    split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
-                ] * T
-            if mixed:
-                average_D = sum(Ds) // T
-                for t, d in enumerate(Ds):
-                    managed[t] = (
-                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
-                        if d < average_D
-                        else managed[t]
-                    )
-            else:
-                managed = [
-                    np.random.choice(
-                        [
-                            split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
-                            split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
-                        ]
+            if mixed:
+                average_D = sum(Ds) // T
+                for t, d in enumerate(Ds):
+                    managed[t] = (
+                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
+                        if d < average_D
+                        else managed[t]
                     )
-                    for _ in range(T)
-                ]
+        else:
+            managed = [
+                np.random.choice(
+                    [
+                        split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
+                        split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
+                    ]
+                )
+                for _ in range(T)
+            ]
         if do_pooling:
            bs = [
                to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
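execute_backward_adagrad_ receives the identical fix: `MANAGED_CACHING` instead of `DEVICE` under `use_cache`, dead nested block removed, and the mixed/else handling hoisted into the live chain. The mixed case, again against the hypothetical sketch above:

# mixed=True under use_cache: Ds averages to 240 // 4 = 60, so the two
# narrow tables stay in DEVICE and the two wide ones are cached.
locs = choose_locations(T=4, Ds=[16, 32, 64, 128], use_cpu=False, use_cache=True, mixed=True)
assert locs == [
    EmbeddingLocation.DEVICE,
    EmbeddingLocation.DEVICE,
    EmbeddingLocation.MANAGED_CACHING,
    EmbeddingLocation.MANAGED_CACHING,
]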