Skip to content

Commit a25615e

Browse files
committed
resolve conflicts
1 parent d41f042 commit a25615e

File tree

1 file changed

+45
-77
lines changed

1 file changed

+45
-77
lines changed

fbgemm_gpu/test/split_table_batched_embeddings_test.py

Lines changed: 45 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -176,31 +176,15 @@ def execute_forward_(
176176
else managed[t]
177177
)
178178
else:
179-
if use_cpu:
180-
managed = [split_table_batched_embeddings_ops.EmbeddingLocation.HOST] * T
181-
compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
182-
elif use_cache:
183-
managed = [
184-
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
185-
] * T
186-
if mixed:
187-
average_D = sum(Ds) // T
188-
for t, d in enumerate(Ds):
189-
managed[t] = (
190-
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
191-
if d < average_D
192-
else managed[t]
193-
)
194-
else:
195-
managed = [
196-
np.random.choice(
197-
[
198-
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
199-
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
200-
]
201-
)
202-
for _ in range(T)
203-
]
179+
managed = [
180+
np.random.choice(
181+
[
182+
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
183+
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
184+
]
185+
)
186+
for _ in range(T)
187+
]
204188
if do_pooling:
205189
bs = [
206190
to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
@@ -1368,34 +1352,26 @@ def execute_backward_sgd_( # noqa C901
13681352
managed = [split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE] * T
13691353
elif use_cache:
13701354
managed = [
1371-
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
1355+
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
13721356
] * T
1373-
else:
1374-
if use_cpu:
1375-
managed = [split_table_batched_embeddings_ops.EmbeddingLocation.HOST] * T
1376-
compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
1377-
elif use_cache:
1378-
managed = [
1379-
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
1380-
] * T
1381-
if mixed:
1382-
average_D = sum(Ds) // T
1383-
for t, d in enumerate(Ds):
1384-
managed[t] = (
1385-
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
1386-
if d < average_D
1387-
else managed[t]
1388-
)
1389-
else:
1390-
managed = [
1391-
np.random.choice(
1392-
[
1393-
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
1394-
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
1395-
]
1357+
if mixed:
1358+
average_D = sum(Ds) // T
1359+
for t, d in enumerate(Ds):
1360+
managed[t] = (
1361+
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
1362+
if d < average_D
1363+
else managed[t]
13961364
)
1397-
for _ in range(T)
1398-
]
1365+
else:
1366+
managed = [
1367+
np.random.choice(
1368+
[
1369+
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
1370+
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
1371+
]
1372+
)
1373+
for _ in range(T)
1374+
]
13991375
if do_pooling:
14001376
bs = [
14011377
to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
@@ -1727,34 +1703,26 @@ def execute_backward_adagrad_( # noqa C901
17271703
managed = [split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE] * T
17281704
elif use_cache:
17291705
managed = [
1730-
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
1706+
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
17311707
] * T
1732-
else:
1733-
if use_cpu:
1734-
managed = [split_table_batched_embeddings_ops.EmbeddingLocation.HOST] * T
1735-
compute_device = split_table_batched_embeddings_ops.ComputeDevice.CPU
1736-
elif use_cache:
1737-
managed = [
1738-
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING
1739-
] * T
1740-
if mixed:
1741-
average_D = sum(Ds) // T
1742-
for t, d in enumerate(Ds):
1743-
managed[t] = (
1744-
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
1745-
if d < average_D
1746-
else managed[t]
1747-
)
1748-
else:
1749-
managed = [
1750-
np.random.choice(
1751-
[
1752-
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
1753-
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
1754-
]
1708+
if mixed:
1709+
average_D = sum(Ds) // T
1710+
for t, d in enumerate(Ds):
1711+
managed[t] = (
1712+
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
1713+
if d < average_D
1714+
else managed[t]
17551715
)
1756-
for _ in range(T)
1757-
]
1716+
else:
1717+
managed = [
1718+
np.random.choice(
1719+
[
1720+
split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
1721+
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
1722+
]
1723+
)
1724+
for _ in range(T)
1725+
]
17581726
if do_pooling:
17591727
bs = [
17601728
to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)

0 commit comments

Comments
 (0)