@@ -76,6 +76,10 @@ def CmpLT : PatLeaf<(i32 2)>;
76
76
def CmpLE : PatLeaf<(i32 3)>;
77
77
def CmpGT : PatLeaf<(i32 4)>;
78
78
def CmpGE : PatLeaf<(i32 5)>;
79
+ def CmpLO : PatLeaf<(i32 6)>;
80
+ def CmpLS : PatLeaf<(i32 7)>;
81
+ def CmpHI : PatLeaf<(i32 8)>;
82
+ def CmpHS : PatLeaf<(i32 9)>;
79
83
def CmpEQU : PatLeaf<(i32 10)>;
80
84
def CmpNEU : PatLeaf<(i32 11)>;
81
85
def CmpLTU : PatLeaf<(i32 12)>;
@@ -107,6 +111,21 @@ def VecElement : Operand<i32> {
107
111
let PrintMethod = "printVecElement";
108
112
}
109
113
114
+ // PRMT modes
115
+ // These must match the enum in NVPTX.h
116
+ def PrmtNONE : PatLeaf<(i32 0x0)>;
117
+ def PrmtF4E : PatLeaf<(i32 0x1)>;
118
+ def PrmtB4E : PatLeaf<(i32 0x2)>;
119
+ def PrmtRC8 : PatLeaf<(i32 0x3)>;
120
+ def PrmtECL : PatLeaf<(i32 0x4)>;
121
+ def PrmtECR : PatLeaf<(i32 0x5)>;
122
+ def PrmtRC16 : PatLeaf<(i32 0x6)>;
123
+
124
+ def PrmtMode : Operand<i32> {
125
+ let PrintMethod = "printPrmtMode";
126
+ }
127
+
128
+
110
129
//===----------------------------------------------------------------------===//
111
130
// NVPTX Instruction Predicate Definitions
112
131
//===----------------------------------------------------------------------===//
@@ -742,7 +761,7 @@ defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
742
761
// def v2f16imm : Operand<v2f16>;
743
762
// defm SELP_f16x2 : SELP_PATTERN<"b32", v2f16, Int32Regs, v2f16imm, imm>;
744
763
745
- foreach vt = [v2f16, v2bf16, v2i16] in {
764
+ foreach vt = [v2f16, v2bf16, v2i16, v4i8 ] in {
746
765
def : Pat<(vt (select Int1Regs:$p, (vt Int32Regs:$a), (vt Int32Regs:$b))),
747
766
(SELP_b32rr Int32Regs:$a, Int32Regs:$b, Int1Regs:$p)>;
748
767
}
@@ -1486,23 +1505,24 @@ defm OR : BITWISE<"or", or>;
1486
1505
defm AND : BITWISE<"and", and>;
1487
1506
defm XOR : BITWISE<"xor", xor>;
1488
1507
1489
- // Lower logical v2i16 ops as bitwise ops on b32.
1490
- def: Pat<(or (v2i16 Int32Regs:$a), (v2i16 Int32Regs:$b)),
1491
- (ORb32rr Int32Regs:$a, Int32Regs:$b)>;
1492
- def: Pat<(xor (v2i16 Int32Regs:$a), (v2i16 Int32Regs:$b)),
1493
- (XORb32rr Int32Regs:$a, Int32Regs:$b)>;
1494
- def: Pat<(and (v2i16 Int32Regs:$a), (v2i16 Int32Regs:$b)),
1495
- (ANDb32rr Int32Regs:$a, Int32Regs:$b)>;
1496
-
1497
- // The constants get legalized into a bitcast from i32, so that's what we need
1498
- // to match here.
1499
- def: Pat<(or Int32Regs:$a, (v2i16 (bitconvert (i32 imm:$b)))),
1500
- (ORb32ri Int32Regs:$a, imm:$b)>;
1501
- def: Pat<(xor Int32Regs:$a, (v2i16 (bitconvert (i32 imm:$b)))),
1502
- (XORb32ri Int32Regs:$a, imm:$b)>;
1503
- def: Pat<(and Int32Regs:$a, (v2i16 (bitconvert (i32 imm:$b)))),
1504
- (ANDb32ri Int32Regs:$a, imm:$b)>;
1505
-
1508
+ // Lower logical v2i16/v4i8 ops as bitwise ops on b32.
1509
+ foreach vt = [v2i16, v4i8] in {
1510
+ def: Pat<(or (vt Int32Regs:$a), (vt Int32Regs:$b)),
1511
+ (ORb32rr Int32Regs:$a, Int32Regs:$b)>;
1512
+ def: Pat<(xor (vt Int32Regs:$a), (vt Int32Regs:$b)),
1513
+ (XORb32rr Int32Regs:$a, Int32Regs:$b)>;
1514
+ def: Pat<(and (vt Int32Regs:$a), (vt Int32Regs:$b)),
1515
+ (ANDb32rr Int32Regs:$a, Int32Regs:$b)>;
1516
+
1517
+ // The constants get legalized into a bitcast from i32, so that's what we need
1518
+ // to match here.
1519
+ def: Pat<(or Int32Regs:$a, (vt (bitconvert (i32 imm:$b)))),
1520
+ (ORb32ri Int32Regs:$a, imm:$b)>;
1521
+ def: Pat<(xor Int32Regs:$a, (vt (bitconvert (i32 imm:$b)))),
1522
+ (XORb32ri Int32Regs:$a, imm:$b)>;
1523
+ def: Pat<(and Int32Regs:$a, (vt (bitconvert (i32 imm:$b)))),
1524
+ (ANDb32ri Int32Regs:$a, imm:$b)>;
1525
+ }
1506
1526
1507
1527
def NOT1 : NVPTXInst<(outs Int1Regs:$dst), (ins Int1Regs:$src),
1508
1528
"not.pred \t$dst, $src;",
@@ -1737,34 +1757,119 @@ def FUNSHFRCLAMP :
1737
1757
// restriction in PTX?
1738
1758
//
1739
1759
// dest and src may be int32 or int64, but start and end are always int32.
1740
- multiclass BFX<string Instr, RegisterClass RC> {
1760
+ def SDTBFE :
1761
+ SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisInt<0>,
1762
+ SDTCisVT<2, i32>, SDTCisVT<3, i32>]>;
1763
+ def bfe : SDNode<"NVPTXISD::BFE", SDTBFE>;
1764
+
1765
+ def SDTBFI :
1766
+ SDTypeProfile<1, 4, [SDTCisInt<0>, SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>,
1767
+ SDTCisVT<3, i32>, SDTCisVT<4, i32>]>;
1768
+ def bfi : SDNode<"NVPTXISD::BFI", SDTBFI>;
1769
+
1770
+ // Type profile for NVPTXISD::PRMT: one result and four operands, all constrained to i32.
+ def SDTPRMT :
1771
+ SDTypeProfile<1, 4, [SDTCisVT<0, i32>, SDTCisVT<1, i32>,
1772
+ SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>]>;
1773
+ def prmt : SDNode<"NVPTXISD::PRMT", SDTPRMT>;
1774
+
1775
+ multiclass BFE<string Instr, ValueType T, RegisterClass RC> {
1741
1776
def rrr
1742
1777
: NVPTXInst<(outs RC:$d),
1743
1778
(ins RC:$a, Int32Regs:$b, Int32Regs:$c),
1744
- !strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
1779
+ !strconcat(Instr, " \t$d, $a, $b, $c;"),
1780
+ [(set (T RC:$d), (bfe (T RC:$a), (i32 Int32Regs:$b), (i32 Int32Regs:$c)))]>;
1745
1781
def rri
1746
1782
: NVPTXInst<(outs RC:$d),
1747
1783
(ins RC:$a, Int32Regs:$b, i32imm:$c),
1748
- !strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
1784
+ !strconcat(Instr, " \t$d, $a, $b, $c;"),
1785
+ [(set (T RC:$d), (bfe (T RC:$a), (i32 Int32Regs:$b), (i32 imm:$c)))]>;
1749
1786
def rii
1750
1787
: NVPTXInst<(outs RC:$d),
1751
1788
(ins RC:$a, i32imm:$b, i32imm:$c),
1752
- !strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
1789
+ !strconcat(Instr, " \t$d, $a, $b, $c;"),
1790
+ [(set (T RC:$d), (bfe (T RC:$a), (i32 imm:$b), (i32 imm:$c)))]>;
1791
+ }
1792
+
1793
+ multiclass BFI<string Instr, ValueType T, RegisterClass RC, Operand ImmCls> {
1794
+ def rrrr
1795
+ : NVPTXInst<(outs RC:$f),
1796
+ (ins RC:$a, RC:$b, Int32Regs:$c, Int32Regs:$d),
1797
+ !strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1798
+ [(set (T RC:$f), (bfi (T RC:$a), (T RC:$b), (i32 Int32Regs:$c), (i32 Int32Regs:$d)))]>;
1799
+ def rrri
1800
+ : NVPTXInst<(outs RC:$f),
1801
+ (ins RC:$a, RC:$b, Int32Regs:$c, i32imm:$d),
1802
+ !strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1803
+ [(set (T RC:$f), (bfi (T RC:$a), (T RC:$b), (i32 Int32Regs:$c), (i32 imm:$d)))]>;
1804
+ def rrii
1805
+ : NVPTXInst<(outs RC:$f),
1806
+ (ins RC:$a, RC:$b, i32imm:$c, i32imm:$d),
1807
+ !strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1808
+ [(set (T RC:$f), (bfi (T RC:$a), (T RC:$b), (i32 imm:$c), (i32 imm:$d)))]>;
1809
+ def irrr
1810
+ : NVPTXInst<(outs RC:$f),
1811
+ (ins ImmCls:$a, RC:$b, Int32Regs:$c, Int32Regs:$d),
1812
+ !strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1813
+ [(set (T RC:$f), (bfi (T imm:$a), (T RC:$b), (i32 Int32Regs:$c), (i32 Int32Regs:$d)))]>;
1814
+ def irri
1815
+ : NVPTXInst<(outs RC:$f),
1816
+ (ins ImmCls:$a, RC:$b, Int32Regs:$c, i32imm:$d),
1817
+ !strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1818
+ [(set (T RC:$f), (bfi (T imm:$a), (T RC:$b), (i32 Int32Regs:$c), (i32 imm:$d)))]>;
1819
+ def irii
1820
+ : NVPTXInst<(outs RC:$f),
1821
+ (ins ImmCls:$a, RC:$b, i32imm:$c, i32imm:$d),
1822
+ !strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1823
+ [(set (T RC:$f), (bfi (T imm:$a), (T RC:$b), (i32 imm:$c), (i32 imm:$d)))]>;
1824
+ }
1825
+
1826
+ multiclass PRMT<ValueType T, RegisterClass RC> {
1827
+ def rrr
1828
+ : NVPTXInst<(outs RC:$d),
1829
+ (ins RC:$a, Int32Regs:$b, Int32Regs:$c, PrmtMode:$mode),
1830
+ !strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1831
+ [(set (T RC:$d), (prmt (T RC:$a), (T RC:$b), (i32 Int32Regs:$c), imm:$mode))]>;
1832
+ def rri
1833
+ : NVPTXInst<(outs RC:$d),
1834
+ (ins RC:$a, Int32Regs:$b, i32imm:$c, PrmtMode:$mode),
1835
+ !strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1836
+ [(set (T RC:$d), (prmt (T RC:$a), (T RC:$b), (i32 imm:$c), imm:$mode))]>;
1837
+ def rii
1838
+ : NVPTXInst<(outs RC:$d),
1839
+ (ins RC:$a, i32imm:$b, i32imm:$c, PrmtMode:$mode),
1840
+ !strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1841
+ [(set (T RC:$d), (prmt (T RC:$a), (T imm:$b), (i32 imm:$c), imm:$mode))]>;
1753
1842
}
1754
1843
1755
1844
let hasSideEffects = false in {
1756
- defm BFE_S32 : BFX<"bfe.s32", Int32Regs>;
1757
- defm BFE_U32 : BFX<"bfe.u32", Int32Regs>;
1758
- defm BFE_S64 : BFX<"bfe.s64", Int64Regs>;
1759
- defm BFE_U64 : BFX<"bfe.u64", Int64Regs>;
1845
+ defm BFE_S32 : BFE<"bfe.s32", i32, Int32Regs>;
1846
+ defm BFE_U32 : BFE<"bfe.u32", i32, Int32Regs>;
1847
+ defm BFE_S64 : BFE<"bfe.s64", i64, Int64Regs>;
1848
+ defm BFE_U64 : BFE<"bfe.u64", i64, Int64Regs>;
1849
+
1850
+ defm BFI_B32 : BFI<"bfi.b32", i32, Int32Regs, i32imm>;
1851
+ defm BFI_B64 : BFI<"bfi.b64", i64, Int64Regs, i64imm>;
1760
1852
1761
- defm BFI_S32 : BFX<"bfi.s32", Int32Regs>;
1762
- defm BFI_U32 : BFX<"bfi.u32", Int32Regs>;
1763
- defm BFI_S64 : BFX<"bfi.s64", Int64Regs>;
1764
- defm BFI_U64 : BFX<"bfi.u64", Int64Regs>;
1853
+ defm PRMT_B32 : PRMT<i32, Int32Regs>;
1765
1854
}
1766
1855
1767
- // Common byte extraction patterns
1856
+
1857
+ // byte extraction + signed/unsigned extension to i32.
1858
+ def : Pat<(i32 (sext_inreg (bfe (i32 Int32Regs:$s), (i32 Int32Regs:$o), 8), i8)),
1859
+ (BFE_S32rri Int32Regs:$s, Int32Regs:$o, 8)>;
1860
+ def : Pat<(i32 (sext_inreg (bfe (i32 Int32Regs:$s), (i32 imm:$o), 8), i8)),
1861
+ (BFE_S32rii Int32Regs:$s, imm:$o, 8)>;
1862
+ def : Pat<(i32 (and (bfe (i32 Int32Regs:$s), (i32 Int32Regs:$o), 8), 255)),
1863
+ (BFE_U32rri Int32Regs:$s, Int32Regs:$o, 8)>;
1864
+ def : Pat<(i32 (and (bfe (i32 Int32Regs:$s), (i32 imm:$o), 8), 255)),
1865
+ (BFE_U32rii Int32Regs:$s, imm:$o, 8)>;
1866
+
1867
+ // byte extraction + signed extension to i16
1868
+ def : Pat<(i16 (sext_inreg (trunc (bfe (i32 Int32Regs:$s), (i32 imm:$o), 8)), i8)),
1869
+ (CVT_s8_s32 (BFE_S32rii Int32Regs:$s, imm:$o, 8), CvtNONE)>;
1870
+
1871
+
1872
+ // Byte extraction via shift/trunc/sext
1768
1873
def : Pat<(i16 (sext_inreg (trunc Int32Regs:$s), i8)),
1769
1874
(CVT_s8_s32 Int32Regs:$s, CvtNONE)>;
1770
1875
def : Pat<(i16 (sext_inreg (trunc (srl (i32 Int32Regs:$s), (i32 imm:$o))), i8)),
@@ -1773,7 +1878,6 @@ def : Pat<(sext_inreg (srl (i32 Int32Regs:$s), (i32 imm:$o)), i8),
1773
1878
(BFE_S32rii Int32Regs:$s, imm:$o, 8)>;
1774
1879
def : Pat<(i16 (sra (i16 (trunc Int32Regs:$s)), (i32 8))),
1775
1880
(CVT_s8_s32 (BFE_S32rii Int32Regs:$s, 8, 8), CvtNONE)>;
1776
-
1777
1881
def : Pat<(sext_inreg (srl (i64 Int64Regs:$s), (i32 imm:$o)), i8),
1778
1882
(BFE_S64rii Int64Regs:$s, imm:$o, 8)>;
1779
1883
def : Pat<(i16 (sext_inreg (trunc Int64Regs:$s), i8)),
@@ -2110,6 +2214,29 @@ def : Pat<(seteq Int1Regs:$a, Int1Regs:$b),
2110
2214
def : Pat<(setueq Int1Regs:$a, Int1Regs:$b),
2111
2215
(NOT1 (XORb1rr Int1Regs:$a, Int1Regs:$b))>;
2112
2216
2217
+ // Comparisons of i8 values held in the low byte of an i32. The matched DAG only looks at bits [7:0], so each operand is re-extended (bfe offset 0, length 8: signed for setgt/ge/lt/le, unsigned otherwise) before the 32-bit SETP; comparing the raw registers would let bits [31:8] change the result.
2218
+ def: Pat<(setgt (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
2219
+ (SETP_s32rr (BFE_S32rii Int32Regs:$a, 0, 8), (BFE_S32rii Int32Regs:$b, 0, 8), CmpGT)>;
2220
+ def: Pat<(setge (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
2221
+ (SETP_s32rr (BFE_S32rii Int32Regs:$a, 0, 8), (BFE_S32rii Int32Regs:$b, 0, 8), CmpGE)>;
2222
+ def: Pat<(setlt (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
2223
+ (SETP_s32rr (BFE_S32rii Int32Regs:$a, 0, 8), (BFE_S32rii Int32Regs:$b, 0, 8), CmpLT)>;
2224
+ def: Pat<(setle (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
2225
+ (SETP_s32rr (BFE_S32rii Int32Regs:$a, 0, 8), (BFE_S32rii Int32Regs:$b, 0, 8), CmpLE)>;
2226
+
2227
+ def: Pat<(setugt (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2228
+ (SETP_u32rr (BFE_U32rii Int32Regs:$a, 0, 8), (BFE_U32rii Int32Regs:$b, 0, 8), CmpHI)>;
2229
+ def: Pat<(setuge (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2230
+ (SETP_u32rr (BFE_U32rii Int32Regs:$a, 0, 8), (BFE_U32rii Int32Regs:$b, 0, 8), CmpHS)>;
2231
+ def: Pat<(setult (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2232
+ (SETP_u32rr (BFE_U32rii Int32Regs:$a, 0, 8), (BFE_U32rii Int32Regs:$b, 0, 8), CmpLO)>;
2233
+ def: Pat<(setule (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2234
+ (SETP_u32rr (BFE_U32rii Int32Regs:$a, 0, 8), (BFE_U32rii Int32Regs:$b, 0, 8), CmpLS)>;
2235
+ def: Pat<(seteq (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2236
+ (SETP_u32rr (BFE_U32rii Int32Regs:$a, 0, 8), (BFE_U32rii Int32Regs:$b, 0, 8), CmpEQ)>;
2237
+ def: Pat<(setne (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2238
+ (SETP_u32rr (BFE_U32rii Int32Regs:$a, 0, 8), (BFE_U32rii Int32Regs:$b, 0, 8), CmpNE)>;
2239
+
2113
2240
// i1 compare -> i32
2114
2241
def : Pat<(i32 (setne Int1Regs:$a, Int1Regs:$b)),
2115
2242
(SELP_u32ii -1, 0, (XORb1rr Int1Regs:$a, Int1Regs:$b))>;
@@ -2682,7 +2809,7 @@ foreach vt = [f16, bf16] in {
2682
2809
def: Pat<(vt (ProxyReg vt:$src)), (ProxyRegI16 Int16Regs:$src)>;
2683
2810
}
2684
2811
2685
- foreach vt = [v2f16, v2bf16, v2i16] in {
2812
+ foreach vt = [v2f16, v2bf16, v2i16, v4i8 ] in {
2686
2813
def: Pat<(vt (ProxyReg vt:$src)), (ProxyRegI32 Int32Regs:$src)>;
2687
2814
}
2688
2815
@@ -2995,8 +3122,8 @@ def: Pat<(i16 (bitconvert (vt Int16Regs:$a))),
2995
3122
(ProxyRegI16 Int16Regs:$a)>;
2996
3123
}
2997
3124
2998
- foreach ta = [v2f16, v2bf16, v2i16, i32] in {
2999
- foreach tb = [v2f16, v2bf16, v2i16, i32] in {
3125
+ foreach ta = [v2f16, v2bf16, v2i16, v4i8, i32] in {
3126
+ foreach tb = [v2f16, v2bf16, v2i16, v4i8, i32] in {
3000
3127
if !ne(ta, tb) then {
3001
3128
def: Pat<(ta (bitconvert (tb Int32Regs:$a))),
3002
3129
(ProxyRegI32 Int32Regs:$a)>;
@@ -3342,6 +3469,9 @@ def : Pat<(i32 (trunc (srl Int64Regs:$s, (i32 32)))),
3342
3469
def : Pat<(i32 (trunc (sra Int64Regs:$s, (i32 32)))),
3343
3470
(I64toI32H Int64Regs:$s)>;
3344
3471
3472
+ def: Pat<(i32 (sext (extractelt (v2i16 Int32Regs:$src), 0))),
3473
+ (CVT_INREG_s32_s16 Int32Regs:$src)>;
3474
+
3345
3475
foreach vt = [v2f16, v2bf16, v2i16] in {
3346
3476
def : Pat<(extractelt (vt Int32Regs:$src), 0),
3347
3477
(I32toI16L Int32Regs:$src)>;
0 commit comments