Skip to content

Commit 9437e79

Browse files
committed
Introduce multi-operand collective permute
1 parent 66ad14f commit 9437e79

9 files changed

+220
-27
lines changed

test/test_mp_collective_permute.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,68 @@
55
import torch_xla.core.xla_model as xm
66

77

8+
def _test_single_tensor_collective_permute(device, world_size, ordinal, pairs):
  """Permutes one tensor around the replica ring and validates what arrives."""
  sent = torch.tensor([ordinal] * 100, dtype=torch.int32, device=device)
  received = xm.collective_permute(sent, pairs).cpu().tolist()
  # Each core receives from its predecessor; core 0 wraps around and
  # receives from the last core in the ring.
  source_ordinal = world_size - 1 if ordinal == 0 else ordinal - 1
  expected = [source_ordinal] * 100
  if received == expected:
    return True
  print(f"Wrong result from core {ordinal}: {received}", file=sys.stderr)
  return False
19+
20+
21+
def _test_multi_tensor_collective_permute(device, world_size, ordinal, pairs):
  """Permutes several tensors (mixed sizes and dtypes) in one call.

  Builds three operands whose payloads are distinguished by a per-operand
  offset, runs a single multi-operand `collective_permute`, and checks that
  each received tensor carries the value sent by the source replica.

  Returns:
    True if every operand matched; False (after logging the offending
    result to stderr) on the first mismatch.
  """
  # (offset, length, dtype) per operand; offsets make payloads distinct so a
  # cross-wired operand cannot accidentally pass.
  specs = [
      (0, 50, torch.int32),
      (100, 75, torch.int32),
      (200, 25, torch.float32),
  ]
  tensors = [
      torch.tensor([ordinal + offset] * length, dtype=dtype, device=device)
      for offset, length, dtype in specs
  ]

  result_list = xm.collective_permute(tensors, pairs)
  expected_ordinal = ordinal - 1 if ordinal != 0 else world_size - 1

  for result_tensor, (offset, length, _) in zip(result_list, specs):
    result = result_tensor.cpu().tolist()
    # int == float comparison is exact here (small integral values).
    expected = [expected_ordinal + offset] * length
    if result != expected:
      print(f"Wrong result from core {ordinal}: {result}", file=sys.stderr)
      return False
  return True
49+
50+
851
def _mp_fn(index):
  """Per-process entry point: runs the collective permute tests on TPU/NEURON."""
  device = torch_xla.device()
  if xm.xla_device_hw(device) not in ('TPU', 'NEURON'):
    print(
        f"Device {device} is not a supported device for this test",
        file=sys.stderr)
    return

  world_size = xr.world_size()
  ordinal = xr.global_ordinal()
  # Ring topology: replica i-1 sends to replica i, and the last replica
  # wraps around to replica 0.
  pairs = [[i - 1, i] for i in range(1, world_size)]
  pairs.append([world_size - 1, 0])

  for test_fn in (_test_single_tensor_collective_permute,
                  _test_multi_tensor_collective_permute):
    if not test_fn(device, world_size, ordinal, pairs):
      sys.exit(1)
2870

2971

3072
if __name__ == '__main__':

torch_xla/core/xla_model.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -744,8 +744,9 @@ def all_to_all(value: torch.Tensor,
744744
return result[0]
745745

746746

747-
def collective_permute(
    tensors: Union[torch.Tensor, List[torch.Tensor]],
    pairs: List[List[int]]) -> Union[torch.Tensor, List[torch.Tensor]]:
  """Performs a XLA `CollectivePermute()` operation on the input tensor(s).

  WARNING: This function is not very reliable, may produce wrong results under
           certain inputs. Use it only for experimenting.

  See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute

  Args:
    tensors: Either a single `torch.Tensor` or a list of `torch.Tensor` to
      perform the collective permute over.
    pairs (list): A list of (source_replica_id, target_replica_id) pairs,
      representing the sender and receiver for the `collective_permute()`
      operation. Example: `[[0, 1], [1, 2], [2, 0]]` defines three pairs. The
      tensor will be sent from replica 0 to replica 1, replica 1 to replica 2,
      and replica 2 to replica 0.

  Returns:
    A single `torch.Tensor` when a single tensor was passed in, otherwise the
    list of `torch.Tensor` results of the `collective_permute()` operation.
  """
  is_single_operand = isinstance(tensors, torch.Tensor)
  assert is_single_operand or (isinstance(tensors, list) and all(
      isinstance(v, torch.Tensor) for v in tensors)), (
          "Input to collective_permute must be a torch.Tensor or a list of "
          "torch.Tensor")

  token, devctx = _get_all_reduce_token()
  # The binding returns the permuted tensor(s) followed by the new ordering
  # token; the token must be threaded back so later collectives sequence
  # correctly after this one.
  result = torch_xla._XLAC._xla_collective_permute(tensors, token, pairs)
  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
  return result[0] if is_single_operand else result[:-1]
771777

772778

773779
def collective_broadcast(tensors: List[torch.Tensor],

torch_xla/csrc/cross_replica_reduces.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,24 @@ CollectivePermuteResult BuildCollectivePermute(
378378
return {result, token_handler.GetNewToken(result)};
379379
}
380380

381+
// Multi-operand variant: lowers a single MultiCollectivePermute over all
// inputs and threads the ordering token through the operands.
// NOTE(review): assumes `inputs` is non-empty (results[0] is used to derive
// the new token) — confirm callers guarantee at least one operand.
MultiCollectivePermuteResult BuildCollectivePermute(
    absl::Span<const xla::XlaOp> inputs, xla::XlaOp token,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
  TokenHandler token_handler(token);
  std::vector<xla::XlaOp> input_ops;
  input_ops.reserve(inputs.size());
  for (const xla::XlaOp& input : inputs) {
    const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
    input_ops.push_back(token_handler.GetInput(input, &input_shape));
  }
  // The XLA op returns a tuple with one element per input operand.
  xla::XlaOp collective_result =
      xla::MultiCollectivePermute(input_ops, source_target_pairs);
  std::vector<xla::XlaOp> results;
  results.reserve(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    results.push_back(xla::GetTupleElement(collective_result, i));
  }
  return {results, token_handler.GetNewToken(results[0])};
}
398+
381399
SendResult BuildSendWithToken(xla::XlaOp input, xla::XlaOp token,
382400
int64_t channel_id) {
383401
xla::ChannelHandle channel_handle;

torch_xla/csrc/cross_replica_reduces.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ struct ReduceScatterResultCoalesced {
6060
xla::XlaOp token;
6161
};
6262

63+
struct MultiCollectivePermuteResult {
64+
std::vector<xla::XlaOp> results;
65+
xla::XlaOp token;
66+
};
67+
6368
std::vector<xla::XlaOp> BuildAllReduce(
6469
AllReduceType reduce_type, absl::Span<const xla::XlaOp> operands,
6570
xla::XlaOp token, double scale,
@@ -90,6 +95,10 @@ CollectivePermuteResult BuildCollectivePermute(
9095
xla::XlaOp input, xla::XlaOp token,
9196
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);
9297

98+
MultiCollectivePermuteResult BuildCollectivePermute(
99+
absl::Span<const xla::XlaOp> inputs, xla::XlaOp token,
100+
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);
101+
93102
SendResult BuildSendWithToken(xla::XlaOp input, xla::XlaOp token,
94103
int64_t channel_id);
95104

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -601,9 +601,8 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> AllToAll(
601601
std::tie(result, new_token) = tensor_methods::all_to_all(
602602
bridge::GetXlaTensor(input), *token, split_dimension, concat_dimension,
603603
split_count, replica_groups, pin_layout);
604-
return std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>>(
605-
bridge::AtenFromXlaTensor(std::move(result)),
606-
std::make_shared<torch::lazy::Value>(new_token));
604+
return {bridge::AtenFromXlaTensor(std::move(result)),
605+
std::make_shared<torch::lazy::Value>(new_token)};
607606
}
608607

609608
std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> CollectivePermute(
@@ -618,6 +617,24 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> CollectivePermute(
618617
std::make_shared<torch::lazy::Value>(new_token));
619618
}
620619

620+
// Multi-operand overload: permutes every tensor in `tensors` according to
// `source_target_pairs`, returning the permuted aten tensors and the new
// ordering token.
std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
CollectivePermute(
    const std::vector<at::Tensor>& tensors,
    const std::shared_ptr<torch::lazy::Value>& token,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
  std::vector<XLATensorPtr> xtensors =
      GetXlaTensors(tensors, /*want_all=*/true);
  std::vector<XLATensorPtr> results;
  torch::lazy::Value new_token;
  std::tie(results, new_token) =
      tensor_methods::collective_permute(xtensors, *token, source_target_pairs);
  std::vector<at::Tensor> aten_results;
  aten_results.reserve(results.size());  // one result per input tensor
  for (auto& xt : results) {
    aten_results.emplace_back(bridge::AtenFromXlaTensor(std::move(xt)));
  }
  return {std::move(aten_results), std::make_shared<torch::lazy::Value>(new_token)};
}
637+
621638
void OptimizationBarrier_(std::vector<at::Tensor>& tensors) {
622639
std::vector<XLATensorPtr> xtensors =
623640
GetXlaTensors(tensors, /*want_all=*/false);
@@ -1990,6 +2007,27 @@ void InitXlaModuleBindings(py::module m) {
19902007
result_tuple[1] = new_token;
19912008
return result_tuple;
19922009
})
2010+
.def("_xla_collective_permute",
     [](const std::vector<at::Tensor>& inputs,
        const std::shared_ptr<torch::lazy::Value>& token,
        const py::list& pairs) {
       // Multi-operand binding: returns a py::list of the permuted
       // tensors followed by the new ordering token.
       std::vector<std::pair<int64_t, int64_t>> source_target_pairs =
           CreateSourceTargetPairs(pairs);
       std::vector<at::Tensor> results;
       std::shared_ptr<torch::lazy::Value> new_token;
       {
         // Release the GIL while the lazy graph nodes are built.
         NoGilSection nogil;
         std::tie(results, new_token) =
             CollectivePermute(inputs, token, source_target_pairs);
       }
       auto result_list = py::list(results.size() + 1);
       // size_t index avoids the signed/unsigned comparison of `int i`.
       for (size_t i = 0; i < results.size(); ++i) {
         result_list[i] = torch::autograd::make_variable(
             results[i], /*requires_grad=*/results[i].requires_grad());
       }
       result_list[results.size()] = new_token;
       return result_list;
     })
19932031
.def("_xla_send",
19942032
[](const at::Tensor& input,
19952033
const std::shared_ptr<torch::lazy::Value>& token,

torch_xla/csrc/ops/collective_permute.cpp

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,71 @@ CollectivePermute::CollectivePermute(
3131
/*num_outputs=*/2, torch::lazy::MHash(source_target_pairs)),
3232
source_target_pairs_(std::move(source_target_pairs)) {}
3333

34+
CollectivePermute::CollectivePermute(
35+
c10::ArrayRef<torch::lazy::Value> inputs, const torch::lazy::Value& token,
36+
std::vector<std::pair<int64_t, int64_t>> source_target_pairs)
37+
: XlaNode(
38+
xla_collective_permute, GetOperandListWithToken(inputs, token),
39+
[&]() {
40+
std::vector<xla::Shape> input_shapes;
41+
for (const auto& input : inputs) {
42+
input_shapes.push_back(GetXlaShape(input));
43+
}
44+
input_shapes.push_back(GetXlaShape(token));
45+
auto shape_fn =
46+
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
47+
std::vector<xla::XlaOp> input_ops(operands.begin(),
48+
operands.end() - 1);
49+
xla::XlaOp token_op = operands.back();
50+
MultiCollectivePermuteResult result = BuildCollectivePermute(
51+
input_ops, token_op, source_target_pairs);
52+
std::vector<xla::XlaOp> outputs = result.results;
53+
outputs.push_back(result.token);
54+
return xla::Tuple(operands[0].builder(), outputs);
55+
};
56+
return InferOutputShape(input_shapes, shape_fn);
57+
},
58+
/*num_outputs=*/inputs.size() + 1,
59+
torch::lazy::MHash(source_target_pairs)),
60+
source_target_pairs_(std::move(source_target_pairs)) {}
61+
3462
torch::lazy::NodePtr CollectivePermute::Clone(
    torch::lazy::OpList operands) const {
  // Exactly two operands means the legacy single-input form: (input, token).
  if (operands.size() <= 2) {
    return torch_xla::MakeNode<CollectivePermute>(
        operands.at(0), operands.at(1), source_target_pairs_);
  }
  // Multi-operand form: every operand except the trailing token is an input.
  std::vector<torch::lazy::Value> inputs(operands.begin(), operands.end() - 1);
  return torch_xla::MakeNode<CollectivePermute>(inputs, operands.back(),
                                                source_target_pairs_);
}
3974

4075
XlaOpVector CollectivePermute::Lower(LoweringContext* loctx) const {
  const auto& all_operands = operands();
  size_t num_operands = all_operands.size();
  if (num_operands <= 2) {
    // Legacy single-input form: operands are (input, token).
    xla::XlaOp input = loctx->GetOutputOp(operand(0));
    xla::XlaOp token = loctx->GetOutputOp(operand(1));
    CollectivePermuteResult result =
        BuildCollectivePermute(input, token, source_target_pairs_);
    return ReturnOps({result.result, result.token}, loctx);
  }

  // Multi-operand form: every operand except the trailing token is an input.
  std::vector<xla::XlaOp> inputs;
  inputs.reserve(num_operands - 1);
  for (size_t i = 0; i + 1 < num_operands; ++i) {
    inputs.push_back(loctx->GetOutputOp(operand(i)));
  }
  xla::XlaOp token = loctx->GetOutputOp(all_operands.back());

  MultiCollectivePermuteResult result =
      BuildCollectivePermute(inputs, token, source_target_pairs_);

  // Emit permuted outputs followed by the new token.
  std::vector<xla::XlaOp> outputs = std::move(result.results);
  outputs.push_back(result.token);
  return ReturnOps(outputs, loctx);
}
47100

48101
std::string CollectivePermute::ToString() const {

torch_xla/csrc/ops/collective_permute.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ class CollectivePermute : public XlaNode {
1212
const torch::lazy::Value& input, const torch::lazy::Value& token,
1313
std::vector<std::pair<int64_t, int64_t>> source_target_pairs);
1414

15+
CollectivePermute(
16+
c10::ArrayRef<torch::lazy::Value> inputs, const torch::lazy::Value& token,
17+
std::vector<std::pair<int64_t, int64_t>> source_target_pairs);
18+
1519
std::string ToString() const override;
1620

1721
torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
@@ -28,4 +32,4 @@ class CollectivePermute : public XlaNode {
2832

2933
} // namespace torch_xla
3034

31-
#endif // XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_
35+
#endif // XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_

torch_xla/csrc/tensor_methods.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,25 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
564564
torch::lazy::Value(node, 1)};
565565
}
566566

567+
std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> collective_permute(
568+
const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
569+
std::vector<std::pair<int64_t, int64_t>> source_target_pairs) {
570+
std::vector<torch::lazy::Value> input_values;
571+
input_values.reserve(inputs.size());
572+
for (const auto& input : inputs) {
573+
input_values.push_back(input->GetIrValue());
574+
}
575+
torch::lazy::NodePtr node = torch_xla::MakeNode<CollectivePermute>(
576+
input_values, token, std::move(source_target_pairs));
577+
578+
std::vector<XLATensorPtr> result;
579+
result.reserve(inputs.size());
580+
for (size_t i = 0; i < inputs.size(); ++i) {
581+
result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
582+
}
583+
return {result, torch::lazy::Value(node, inputs.size())};
584+
}
585+
567586
std::vector<XLATensorPtr> custom_call(
568587
const std::vector<XLATensorPtr>& inputs, const std::string& target,
569588
const std::vector<std::vector<int64_t>>& output_shapes,

torch_xla/csrc/tensor_methods.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
9090
const XLATensorPtr& input, const torch::lazy::Value& token,
9191
std::vector<std::pair<int64_t, int64_t>> source_target_pairs);
9292

93+
std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> collective_permute(
94+
const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
95+
std::vector<std::pair<int64_t, int64_t>> source_target_pairs);
96+
9397
std::vector<XLATensorPtr> custom_call(
9498
const std::vector<XLATensorPtr>& inputs, const std::string& target,
9599
const std::vector<std::vector<int64_t>>& output_shapes,

0 commit comments

Comments
 (0)