@@ -337,6 +337,7 @@ class AdjointGenerator
337
337
addToDiffe (orig_op0, dif1, Builder2, TR.addingType (size, orig_op0));
338
338
return ;
339
339
}
340
+ case DerivativeMode::ForwardModeSplit:
340
341
case DerivativeMode::ForwardMode: {
341
342
IRBuilder<> BuilderZ (&inst);
342
343
getForwardBuilder (BuilderZ);
@@ -383,6 +384,7 @@ class AdjointGenerator
383
384
dif1->getType ()->getScalarType ());
384
385
break ;
385
386
}
387
+ case DerivativeMode::ForwardModeSplit:
386
388
case DerivativeMode::ForwardMode: {
387
389
IRBuilder<> Builder2 (&inst);
388
390
getForwardBuilder (Builder2);
@@ -482,6 +484,7 @@ class AdjointGenerator
482
484
}
483
485
break ;
484
486
}
487
+ case DerivativeMode::ForwardModeSplit:
485
488
case DerivativeMode::ForwardMode: {
486
489
newip = gutils->invertPointerM (&I, BuilderZ);
487
490
assert (newip->getType () == type);
@@ -605,6 +608,7 @@ class AdjointGenerator
605
608
if (isfloat) {
606
609
607
610
switch (Mode) {
611
+ case DerivativeMode::ForwardModeSplit:
608
612
case DerivativeMode::ForwardMode: {
609
613
IRBuilder<> Builder2 (&I);
610
614
getForwardBuilder (Builder2);
@@ -889,6 +893,7 @@ class AdjointGenerator
889
893
}
890
894
break ;
891
895
}
896
+ case DerivativeMode::ForwardModeSplit:
892
897
case DerivativeMode::ForwardMode: {
893
898
IRBuilder<> Builder2 (&I);
894
899
getForwardBuilder (Builder2);
@@ -1056,6 +1061,7 @@ class AdjointGenerator
1056
1061
1057
1062
break ;
1058
1063
}
1064
+ case DerivativeMode::ForwardModeSplit:
1059
1065
case DerivativeMode::ForwardMode: {
1060
1066
IRBuilder<> Builder2 (&I);
1061
1067
getForwardBuilder (Builder2);
@@ -1096,6 +1102,7 @@ class AdjointGenerator
1096
1102
createSelectInstAdjoint (SI);
1097
1103
return ;
1098
1104
}
1105
+ case DerivativeMode::ForwardModeSplit:
1099
1106
case DerivativeMode::ForwardMode: {
1100
1107
createSelectInstDual (SI);
1101
1108
return ;
@@ -1247,6 +1254,7 @@ class AdjointGenerator
1247
1254
return ;
1248
1255
1249
1256
switch (Mode) {
1257
+ case DerivativeMode::ForwardModeSplit:
1250
1258
case DerivativeMode::ForwardMode: {
1251
1259
IRBuilder<> Builder2 (&EEI);
1252
1260
getForwardBuilder (Builder2);
@@ -1304,6 +1312,7 @@ class AdjointGenerator
1304
1312
return ;
1305
1313
1306
1314
switch (Mode) {
1315
+ case DerivativeMode::ForwardModeSplit:
1307
1316
case DerivativeMode::ForwardMode: {
1308
1317
IRBuilder<> Builder2 (&IEI);
1309
1318
getForwardBuilder (Builder2);
@@ -1387,6 +1396,7 @@ class AdjointGenerator
1387
1396
return ;
1388
1397
1389
1398
switch (Mode) {
1399
+ case DerivativeMode::ForwardModeSplit:
1390
1400
case DerivativeMode::ForwardMode: {
1391
1401
IRBuilder<> Builder2 (&SVI);
1392
1402
getForwardBuilder (Builder2);
@@ -1475,6 +1485,7 @@ class AdjointGenerator
1475
1485
return ;
1476
1486
1477
1487
switch (Mode) {
1488
+ case DerivativeMode::ForwardModeSplit:
1478
1489
case DerivativeMode::ForwardMode: {
1479
1490
IRBuilder<> Builder2 (&EVI);
1480
1491
getForwardBuilder (Builder2);
@@ -3453,6 +3464,7 @@ class AdjointGenerator
3453
3464
}
3454
3465
return ;
3455
3466
}
3467
+ case DerivativeMode::ForwardModeSplit:
3456
3468
case DerivativeMode::ForwardMode: {
3457
3469
3458
3470
IRBuilder<> Builder2 (&I);
@@ -7237,6 +7249,7 @@ class AdjointGenerator
7237
7249
return ;
7238
7250
7239
7251
switch (Mode) {
7252
+ case DerivativeMode::ForwardModeSplit:
7240
7253
case DerivativeMode::ForwardMode: {
7241
7254
IRBuilder<> Builder2 (&call);
7242
7255
getForwardBuilder (Builder2);
@@ -7296,6 +7309,7 @@ class AdjointGenerator
7296
7309
return ;
7297
7310
7298
7311
switch (Mode) {
7312
+ case DerivativeMode::ForwardModeSplit:
7299
7313
case DerivativeMode::ForwardMode: {
7300
7314
IRBuilder<> Builder2 (&call);
7301
7315
getForwardBuilder (Builder2);
@@ -7338,6 +7352,7 @@ class AdjointGenerator
7338
7352
return ;
7339
7353
7340
7354
switch (Mode) {
7355
+ case DerivativeMode::ForwardModeSplit:
7341
7356
case DerivativeMode::ForwardMode: {
7342
7357
IRBuilder<> Builder2 (&call);
7343
7358
getForwardBuilder (Builder2);
@@ -7402,6 +7417,7 @@ class AdjointGenerator
7402
7417
return ;
7403
7418
7404
7419
switch (Mode) {
7420
+ case DerivativeMode::ForwardModeSplit:
7405
7421
case DerivativeMode::ForwardMode: {
7406
7422
IRBuilder<> Builder2 (&call);
7407
7423
getForwardBuilder (Builder2);
@@ -7455,6 +7471,7 @@ class AdjointGenerator
7455
7471
return ;
7456
7472
7457
7473
switch (Mode) {
7474
+ case DerivativeMode::ForwardModeSplit:
7458
7475
case DerivativeMode::ForwardMode: {
7459
7476
IRBuilder<> Builder2 (&call);
7460
7477
getForwardBuilder (Builder2);
@@ -7505,6 +7522,7 @@ class AdjointGenerator
7505
7522
return ;
7506
7523
7507
7524
switch (Mode) {
7525
+ case DerivativeMode::ForwardModeSplit:
7508
7526
case DerivativeMode::ForwardMode: {
7509
7527
IRBuilder<> Builder2 (&call);
7510
7528
getForwardBuilder (Builder2);
@@ -7828,6 +7846,7 @@ class AdjointGenerator
7828
7846
return ;
7829
7847
7830
7848
switch (Mode) {
7849
+ case DerivativeMode::ForwardModeSplit:
7831
7850
case DerivativeMode::ForwardMode: {
7832
7851
IRBuilder<> Builder2 (&call);
7833
7852
getForwardBuilder (Builder2);
@@ -7884,6 +7903,7 @@ class AdjointGenerator
7884
7903
return ;
7885
7904
7886
7905
switch (Mode) {
7906
+ case DerivativeMode::ForwardModeSplit:
7887
7907
case DerivativeMode::ForwardMode: {
7888
7908
IRBuilder<> Builder2 (&call);
7889
7909
getForwardBuilder (Builder2);
@@ -7964,6 +7984,7 @@ class AdjointGenerator
7964
7984
return ;
7965
7985
7966
7986
switch (Mode) {
7987
+ case DerivativeMode::ForwardModeSplit:
7967
7988
case DerivativeMode::ForwardMode: {
7968
7989
IRBuilder<> Builder2 (&call);
7969
7990
getForwardBuilder (Builder2);
@@ -8124,6 +8145,7 @@ class AdjointGenerator
8124
8145
}
8125
8146
8126
8147
switch (Mode) {
8148
+ case DerivativeMode::ForwardModeSplit:
8127
8149
case DerivativeMode::ForwardMode: {
8128
8150
IRBuilder<> Builder2 (&call);
8129
8151
getForwardBuilder (Builder2);
@@ -8199,6 +8221,7 @@ class AdjointGenerator
8199
8221
}
8200
8222
8201
8223
switch (Mode) {
8224
+ case DerivativeMode::ForwardModeSplit:
8202
8225
case DerivativeMode::ForwardMode: {
8203
8226
IRBuilder<> Builder2 (&call);
8204
8227
getForwardBuilder (Builder2);
@@ -8282,6 +8305,7 @@ class AdjointGenerator
8282
8305
}
8283
8306
8284
8307
switch (Mode) {
8308
+ case DerivativeMode::ForwardModeSplit:
8285
8309
case DerivativeMode::ForwardMode: {
8286
8310
IRBuilder<> Builder2 (&call);
8287
8311
getForwardBuilder (Builder2);
@@ -8348,12 +8372,10 @@ class AdjointGenerator
8348
8372
Mode == DerivativeMode::ReverseModePrimal) {
8349
8373
8350
8374
bool backwardsShadow = false ;
8351
- bool forwardsShadow = true ;
8352
8375
{
8353
8376
auto found = gutils->backwardsOnlyShadows .find (orig);
8354
8377
if (found != gutils->backwardsOnlyShadows .end ()) {
8355
8378
backwardsShadow = true ;
8356
- forwardsShadow = found->second .second ;
8357
8379
}
8358
8380
}
8359
8381
0 commit comments