@@ -39,6 +39,8 @@ namespace {
39
39
using ::testing::ElementsAre;
40
40
using ::testing::IsEmpty;
41
41
42
+ using AxisRefVector = SmallVector<AxisRefAttr>;
43
+
42
44
class UtilsTest : public ShardyTestBase {
43
45
protected:
44
46
void SetUp () override {
@@ -430,11 +432,10 @@ TEST_F(UtilsTest, GetAxisSetDiff) {
430
432
TEST_F (UtilsTest, SortAndMergeAxes) {
431
433
MeshAttr mesh = createMesh ({{" a" , 8 }, {" b" , 8 }, {" c" , 8 }, {" d" , 8 }});
432
434
433
- SmallVector<AxisRefAttr> axes = {
434
- createSubAxis (" a" , 1 , 2 ), createSubAxis (" a" , 4 , 2 ),
435
- createSubAxis (" b" , 1 , 2 ), createSubAxis (" b" , 2 , 2 ),
436
- createSubAxis (" b" , 4 , 2 ), createAxis (" c" ),
437
- createSubAxis (" d" , 2 , 2 ), createSubAxis (" d" , 4 , 2 )};
435
+ AxisRefVector axes = {createSubAxis (" a" , 1 , 2 ), createSubAxis (" a" , 4 , 2 ),
436
+ createSubAxis (" b" , 1 , 2 ), createSubAxis (" b" , 2 , 2 ),
437
+ createSubAxis (" b" , 4 , 2 ), createAxis (" c" ),
438
+ createSubAxis (" d" , 2 , 2 ), createSubAxis (" d" , 4 , 2 )};
438
439
439
440
// Shuffle the vector
440
441
std::random_device rd;
@@ -448,6 +449,68 @@ TEST_F(UtilsTest, SortAndMergeAxes) {
448
449
AxisRefIs (" c" ), SubAxisRefIs (" d" , 2 , 4 )));
449
450
}
450
451
452
+ TEST_F (UtilsTest, AddAxisOrMergeInserterSingleAxisNoMerge) {
453
+ MeshAttr mesh = createMesh ({{" a" , 8 }, {" b" , 8 }});
454
+ AxisRefVector axes = {createSubAxis (" a" , 1 , 2 ), createAxis (" b" )};
455
+ AxisRefVector newAxes;
456
+ std::transform (axes.begin (), axes.end (),
457
+ AddAxisOrMergeInserter (&newAxes, &mesh),
458
+ [&](AxisRefAttr axis) {
459
+ if (axis.getName () == " b" ) {
460
+ return createSubAxis (" c" , 4 , 2 );
461
+ }
462
+ return axis;
463
+ });
464
+ EXPECT_THAT (newAxes,
465
+ ElementsAre (SubAxisRefIs (" a" , 1 , 2 ), SubAxisRefIs (" c" , 4 , 2 )));
466
+ }
467
+
468
+ TEST_F (UtilsTest, AddAxisOrMergeInserterMultipleAxesNoMerge) {
469
+ MeshAttr mesh = createMesh ({{" a" , 8 }, {" b" , 8 }, {" c" , 8 }, {" d" , 8 }});
470
+ AxisRefVector axes = {createAxis (" a" ), createSubAxis (" b" , 1 , 2 )};
471
+ AxisRefVector newAxes;
472
+ std::transform (axes.begin (), axes.end (),
473
+ AddAxisOrMergeInserter (&newAxes, &mesh),
474
+ [&](AxisRefAttr axis) -> AxisRefVector {
475
+ if (axis.getName () == " b" ) {
476
+ return {createAxis (" c" ), createAxis (" d" )};
477
+ }
478
+ return {};
479
+ });
480
+ EXPECT_THAT (newAxes, ElementsAre (AxisRefIs (" c" ), AxisRefIs (" d" )));
481
+ }
482
+
483
+ TEST_F (UtilsTest, AddAxisOrMergeInserterSingleAxisMerge) {
484
+ MeshAttr mesh = createMesh ({{" a" , 8 }, {" b" , 8 }});
485
+ AxisRefVector axes = {createSubAxis (" a" , 1 , 2 ), createAxis (" b" )};
486
+ AxisRefVector newAxes;
487
+ std::transform (axes.begin (), axes.end (),
488
+ AddAxisOrMergeInserter (&newAxes, &mesh),
489
+ [&](AxisRefAttr axis) {
490
+ if (axis.getName () == " b" ) {
491
+ return createSubAxis (" a" , 2 , 2 );
492
+ }
493
+ return axis;
494
+ });
495
+ EXPECT_THAT (newAxes, ElementsAre (SubAxisRefIs (" a" , 1 , 4 )));
496
+ }
497
+
498
+ TEST_F (UtilsTest, AddAxisOrMergeInserterMultipleAxesMerge) {
499
+ MeshAttr mesh = createMesh ({{" a" , 8 }, {" b" , 16 }});
500
+ AxisRefVector axes = {createAxis (" a" ), createSubAxis (" b" , 8 , 2 )};
501
+ AxisRefVector newAxes;
502
+ std::transform (axes.begin (), axes.end (),
503
+ AddAxisOrMergeInserter (&newAxes, &mesh),
504
+ [&](AxisRefAttr axis) -> AxisRefVector {
505
+ if (axis.getName () == " a" ) {
506
+ return {createSubAxis (" b" , 2 , 2 ),
507
+ createSubAxis (" b" , 4 , 2 )};
508
+ }
509
+ return {axis};
510
+ });
511
+ EXPECT_THAT (newAxes, ElementsAre (SubAxisRefIs (" b" , 2 , 8 )));
512
+ }
513
+
451
514
} // namespace
452
515
453
516
} // namespace sdy
0 commit comments