Skip to content

Commit 4e06547

Browse files
#sdy Move AddAxisOrMergeInserter to Shady OSS utils.
PiperOrigin-RevId: 776780231
1 parent 7087562 commit 4e06547

File tree

2 files changed

+106
-5
lines changed

2 files changed

+106
-5
lines changed

shardy/dialect/sdy/ir/utils.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
#define SHARDY_DIALECT_SDY_IR_UTILS_H_
1818

1919
#include <cstdint>
20+
#include <iterator>
2021
#include <optional>
2122
#include <string>
2223
#include <utility>
@@ -498,6 +499,43 @@ bool hasOnlyUsersOfType(Operation* op) {
498499
});
499500
}
500501

502+
// Helper class to insert or merge `AxisRefAttr`s into a `SmallVector`.
503+
// It implements the interface required by `llvm::transform` to be used as an
504+
// output iterator.
505+
class AddAxisOrMergeInserter {
506+
public:
507+
using iterator_category = std::output_iterator_tag;
508+
using value_type = void;
509+
using difference_type = void;
510+
using pointer = void;
511+
using reference = void;
512+
513+
explicit AddAxisOrMergeInserter(SmallVector<AxisRefAttr>* newAxisRefs,
514+
const MeshAttr* mesh)
515+
: axisRefs(newAxisRefs), mesh(mesh) {}
516+
517+
AddAxisOrMergeInserter& operator=(AxisRefAttr value) {
518+
sdy::addAxisOrMerge(*axisRefs, value, *mesh);
519+
return *this;
520+
}
521+
522+
AddAxisOrMergeInserter& operator=(ArrayRef<AxisRefAttr> values) {
523+
for (AxisRefAttr value : values) {
524+
*this = value;
525+
}
526+
return *this;
527+
}
528+
529+
AddAxisOrMergeInserter& operator*() { return *this; }
530+
AddAxisOrMergeInserter& operator++() { return *this; }
531+
AddAxisOrMergeInserter& operator++(int) { return *this; }
532+
533+
private:
534+
// Use pointers so that callers like `llvm::transform` can copy the inserter.
535+
SmallVector<AxisRefAttr>* axisRefs;
536+
const MeshAttr* mesh;
537+
};
538+
501539
} // namespace sdy
502540
} // namespace mlir
503541

shardy/dialect/sdy/ir/utils_test.cc

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ namespace {
3939
using ::testing::ElementsAre;
4040
using ::testing::IsEmpty;
4141

42+
using AxisRefVector = SmallVector<AxisRefAttr>;
43+
4244
class UtilsTest : public ShardyTestBase {
4345
protected:
4446
void SetUp() override {
@@ -430,11 +432,10 @@ TEST_F(UtilsTest, GetAxisSetDiff) {
430432
TEST_F(UtilsTest, SortAndMergeAxes) {
431433
MeshAttr mesh = createMesh({{"a", 8}, {"b", 8}, {"c", 8}, {"d", 8}});
432434

433-
SmallVector<AxisRefAttr> axes = {
434-
createSubAxis("a", 1, 2), createSubAxis("a", 4, 2),
435-
createSubAxis("b", 1, 2), createSubAxis("b", 2, 2),
436-
createSubAxis("b", 4, 2), createAxis("c"),
437-
createSubAxis("d", 2, 2), createSubAxis("d", 4, 2)};
435+
AxisRefVector axes = {createSubAxis("a", 1, 2), createSubAxis("a", 4, 2),
436+
createSubAxis("b", 1, 2), createSubAxis("b", 2, 2),
437+
createSubAxis("b", 4, 2), createAxis("c"),
438+
createSubAxis("d", 2, 2), createSubAxis("d", 4, 2)};
438439

439440
// Shuffle the vector
440441
std::random_device rd;
@@ -448,6 +449,68 @@ TEST_F(UtilsTest, SortAndMergeAxes) {
448449
AxisRefIs("c"), SubAxisRefIs("d", 2, 4)));
449450
}
450451

452+
TEST_F(UtilsTest, AddAxisOrMergeInserterSingleAxisNoMerge) {
453+
MeshAttr mesh = createMesh({{"a", 8}, {"b", 8}});
454+
AxisRefVector axes = {createSubAxis("a", 1, 2), createAxis("b")};
455+
AxisRefVector newAxes;
456+
std::transform(axes.begin(), axes.end(),
457+
AddAxisOrMergeInserter(&newAxes, &mesh),
458+
[&](AxisRefAttr axis) {
459+
if (axis.getName() == "b") {
460+
return createSubAxis("c", 4, 2);
461+
}
462+
return axis;
463+
});
464+
EXPECT_THAT(newAxes,
465+
ElementsAre(SubAxisRefIs("a", 1, 2), SubAxisRefIs("c", 4, 2)));
466+
}
467+
468+
TEST_F(UtilsTest, AddAxisOrMergeInserterMultipleAxesNoMerge) {
469+
MeshAttr mesh = createMesh({{"a", 8}, {"b", 8}, {"c", 8}, {"d", 8}});
470+
AxisRefVector axes = {createAxis("a"), createSubAxis("b", 1, 2)};
471+
AxisRefVector newAxes;
472+
std::transform(axes.begin(), axes.end(),
473+
AddAxisOrMergeInserter(&newAxes, &mesh),
474+
[&](AxisRefAttr axis) -> AxisRefVector {
475+
if (axis.getName() == "b") {
476+
return {createAxis("c"), createAxis("d")};
477+
}
478+
return {};
479+
});
480+
EXPECT_THAT(newAxes, ElementsAre(AxisRefIs("c"), AxisRefIs("d")));
481+
}
482+
483+
TEST_F(UtilsTest, AddAxisOrMergeInserterSingleAxisMerge) {
484+
MeshAttr mesh = createMesh({{"a", 8}, {"b", 8}});
485+
AxisRefVector axes = {createSubAxis("a", 1, 2), createAxis("b")};
486+
AxisRefVector newAxes;
487+
std::transform(axes.begin(), axes.end(),
488+
AddAxisOrMergeInserter(&newAxes, &mesh),
489+
[&](AxisRefAttr axis) {
490+
if (axis.getName() == "b") {
491+
return createSubAxis("a", 2, 2);
492+
}
493+
return axis;
494+
});
495+
EXPECT_THAT(newAxes, ElementsAre(SubAxisRefIs("a", 1, 4)));
496+
}
497+
498+
TEST_F(UtilsTest, AddAxisOrMergeInserterMultipleAxesMerge) {
499+
MeshAttr mesh = createMesh({{"a", 8}, {"b", 16}});
500+
AxisRefVector axes = {createAxis("a"), createSubAxis("b", 8, 2)};
501+
AxisRefVector newAxes;
502+
std::transform(axes.begin(), axes.end(),
503+
AddAxisOrMergeInserter(&newAxes, &mesh),
504+
[&](AxisRefAttr axis) -> AxisRefVector {
505+
if (axis.getName() == "a") {
506+
return {createSubAxis("b", 2, 2),
507+
createSubAxis("b", 4, 2)};
508+
}
509+
return {axis};
510+
});
511+
EXPECT_THAT(newAxes, ElementsAre(SubAxisRefIs("b", 2, 8)));
512+
}
513+
451514
} // namespace
452515

453516
} // namespace sdy

0 commit comments

Comments
 (0)