Skip to content

Commit cbafb6f

Browse files
authored Oct 9, 2023
[NVPTX] Improve lowering of v4i8 (#67866)
Make v4i8 a legal type and plumb through lowering of relevant instructions.
1 parent 67b675e commit cbafb6f

15 files changed

+1897
-540
lines changed
 

‎llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,34 @@ void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
309309
const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
310310
O << Sym.getName();
311311
}
312+
313+
// Prints the optional mode suffix of a `prmt` instruction.
//
// Operand OpNum of MI is read as an immediate and compared against the
// NVPTX::PTXPrmtMode enumerators. NONE prints nothing; every other listed
// mode writes its suffix (".f4e", ".b4e", ".rc8", ".ecl", ".ecr", ".rc16")
// to O. An immediate that matches no enumerator is silently ignored.
// Modifier is not used by this function.
void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
314+
raw_ostream &O, const char *Modifier) {
315+
const MCOperand &MO = MI->getOperand(OpNum);
316+
int64_t Imm = MO.getImm();
317+
318+
switch (Imm) {
319+
default:
320+
return; // Not a known PTXPrmtMode value: emit nothing.
321+
case NVPTX::PTXPrmtMode::NONE:
322+
break; // NONE has no textual suffix.
323+
case NVPTX::PTXPrmtMode::F4E:
324+
O << ".f4e";
325+
break;
326+
case NVPTX::PTXPrmtMode::B4E:
327+
O << ".b4e";
328+
break;
329+
case NVPTX::PTXPrmtMode::RC8:
330+
O << ".rc8";
331+
break;
332+
case NVPTX::PTXPrmtMode::ECL:
333+
O << ".ecl";
334+
break;
335+
case NVPTX::PTXPrmtMode::ECR:
336+
O << ".ecr";
337+
break;
338+
case NVPTX::PTXPrmtMode::RC16:
339+
O << ".rc16";
340+
break;
341+
}
342+
}

‎llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
4747
raw_ostream &O, const char *Modifier = nullptr);
4848
void printProtoIdent(const MCInst *MI, int OpNum,
4949
raw_ostream &O, const char *Modifier = nullptr);
50+
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
51+
const char *Modifier = nullptr);
5052
};
5153

5254
}

‎llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,18 @@ enum CmpMode {
181181
FTZ_FLAG = 0x100
182182
};
183183
}
184+
185+
// Mode selector for the `prmt` instruction's optional mode suffix.
//
// The enumerators carry no explicit values, so they are numbered
// consecutively from NONE = 0 to RC16 = 6. Keep them in sync with the Prmt*
// PatLeaf constants (PrmtNONE = 0x0 ... PrmtRC16 = 0x6) in NVPTXInstrInfo.td
// and with the cases handled by NVPTXInstPrinter::printPrmtMode.
namespace PTXPrmtMode {
186+
enum PrmtMode {
187+
NONE,
188+
F4E,
189+
B4E,
190+
RC8,
191+
ECL,
192+
ECR,
193+
RC16,
194+
};
195+
}
184196
}
185197
void initializeNVPTXDAGToDAGISelPass(PassRegistry &);
186198
} // namespace llvm

‎llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "MCTargetDesc/NVPTXBaseInfo.h"
1515
#include "NVPTXUtilities.h"
1616
#include "llvm/Analysis/ValueTracking.h"
17+
#include "llvm/CodeGen/ISDOpcodes.h"
1718
#include "llvm/IR/GlobalValue.h"
1819
#include "llvm/IR/Instructions.h"
1920
#include "llvm/IR/IntrinsicsNVPTX.h"
@@ -829,6 +830,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
829830
case MVT::v2f16:
830831
case MVT::v2bf16:
831832
case MVT::v2i16:
833+
case MVT::v4i8:
832834
return Opcode_i32;
833835
case MVT::f32:
834836
return Opcode_f32;
@@ -910,7 +912,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
910912
// Vector Setting
911913
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
912914
if (SimpleVT.isVector()) {
913-
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
915+
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
916+
"Unexpected vector type");
914917
// v2f16/v2bf16/v2i16/v4i8 are loaded using ld.b32
915918
fromTypeWidth = 32;
916919
}
@@ -1254,19 +1257,23 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12541257
SDLoc DL(N);
12551258
SDNode *LD;
12561259
SDValue Base, Offset, Addr;
1260+
EVT OrigType = N->getValueType(0);
12571261

12581262
EVT EltVT = Mem->getMemoryVT();
12591263
unsigned NumElts = 1;
12601264
if (EltVT.isVector()) {
12611265
NumElts = EltVT.getVectorNumElements();
12621266
EltVT = EltVT.getVectorElementType();
12631267
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
1264-
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
1265-
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
1266-
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
1268+
if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
1269+
(EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
1270+
(EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
12671271
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
1268-
EltVT = N->getValueType(0);
1272+
EltVT = OrigType;
12691273
NumElts /= 2;
1274+
} else if (OrigType == MVT::v4i8) {
1275+
EltVT = OrigType;
1276+
NumElts = 1;
12701277
}
12711278
}
12721279

@@ -1601,7 +1608,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
16011608
// concept of sign-/zero-extension, so emulate it here by adding an explicit
16021609
// CVT instruction. Ptxas should clean up any redundancies here.
16031610

1604-
EVT OrigType = N->getValueType(0);
16051611
LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
16061612

16071613
if (OrigType != EltVT &&
@@ -1679,7 +1685,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
16791685
MVT ScalarVT = SimpleVT.getScalarType();
16801686
unsigned toTypeWidth = ScalarVT.getSizeInBits();
16811687
if (SimpleVT.isVector()) {
1682-
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
1688+
assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
1689+
"Unexpected vector type");
16831690
// v2x16 and v4i8 are stored using st.b32
16841691
toTypeWidth = 32;
16851692
}

‎llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 228 additions & 47 deletions
Large diffs are not rendered by default.

‎llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ enum NodeType : unsigned {
5757
MUL_WIDE_UNSIGNED,
5858
IMAD,
5959
SETP_F16X2,
60+
BFE,
61+
BFI,
62+
PRMT,
6063
Dummy,
6164

6265
LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
@@ -590,6 +593,8 @@ class NVPTXTargetLowering : public TargetLowering {
590593
SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
591594
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
592595
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
596+
SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
597+
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
593598

594599
SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
595600
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;

‎llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 165 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ def CmpLT : PatLeaf<(i32 2)>;
7676
def CmpLE : PatLeaf<(i32 3)>;
7777
def CmpGT : PatLeaf<(i32 4)>;
7878
def CmpGE : PatLeaf<(i32 5)>;
79+
def CmpLO : PatLeaf<(i32 6)>;
80+
def CmpLS : PatLeaf<(i32 7)>;
81+
def CmpHI : PatLeaf<(i32 8)>;
82+
def CmpHS : PatLeaf<(i32 9)>;
7983
def CmpEQU : PatLeaf<(i32 10)>;
8084
def CmpNEU : PatLeaf<(i32 11)>;
8185
def CmpLTU : PatLeaf<(i32 12)>;
@@ -107,6 +111,21 @@ def VecElement : Operand<i32> {
107111
let PrintMethod = "printVecElement";
108112
}
109113

114+
// PRMT modes
115+
// These must match the enum in NVPTX.h
116+
def PrmtNONE : PatLeaf<(i32 0x0)>;
117+
def PrmtF4E : PatLeaf<(i32 0x1)>;
118+
def PrmtB4E : PatLeaf<(i32 0x2)>;
119+
def PrmtRC8 : PatLeaf<(i32 0x3)>;
120+
def PrmtECL : PatLeaf<(i32 0x4)>;
121+
def PrmtECR : PatLeaf<(i32 0x5)>;
122+
def PrmtRC16 : PatLeaf<(i32 0x6)>;
123+
124+
def PrmtMode : Operand<i32> {
125+
let PrintMethod = "printPrmtMode";
126+
}
127+
128+
110129
//===----------------------------------------------------------------------===//
111130
// NVPTX Instruction Predicate Definitions
112131
//===----------------------------------------------------------------------===//
@@ -742,7 +761,7 @@ defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
742761
// def v2f16imm : Operand<v2f16>;
743762
// defm SELP_f16x2 : SELP_PATTERN<"b32", v2f16, Int32Regs, v2f16imm, imm>;
744763

745-
foreach vt = [v2f16, v2bf16, v2i16] in {
764+
foreach vt = [v2f16, v2bf16, v2i16, v4i8] in {
746765
def : Pat<(vt (select Int1Regs:$p, (vt Int32Regs:$a), (vt Int32Regs:$b))),
747766
(SELP_b32rr Int32Regs:$a, Int32Regs:$b, Int1Regs:$p)>;
748767
}
@@ -1486,23 +1505,24 @@ defm OR : BITWISE<"or", or>;
14861505
defm AND : BITWISE<"and", and>;
14871506
defm XOR : BITWISE<"xor", xor>;
14881507

1489-
// Lower logical v2i16 ops as bitwise ops on b32.
1490-
def: Pat<(or (v2i16 Int32Regs:$a), (v2i16 Int32Regs:$b)),
1491-
(ORb32rr Int32Regs:$a, Int32Regs:$b)>;
1492-
def: Pat<(xor (v2i16 Int32Regs:$a), (v2i16 Int32Regs:$b)),
1493-
(XORb32rr Int32Regs:$a, Int32Regs:$b)>;
1494-
def: Pat<(and (v2i16 Int32Regs:$a), (v2i16 Int32Regs:$b)),
1495-
(ANDb32rr Int32Regs:$a, Int32Regs:$b)>;
1496-
1497-
// The constants get legalized into a bitcast from i32, so that's what we need
1498-
// to match here.
1499-
def: Pat<(or Int32Regs:$a, (v2i16 (bitconvert (i32 imm:$b)))),
1500-
(ORb32ri Int32Regs:$a, imm:$b)>;
1501-
def: Pat<(xor Int32Regs:$a, (v2i16 (bitconvert (i32 imm:$b)))),
1502-
(XORb32ri Int32Regs:$a, imm:$b)>;
1503-
def: Pat<(and Int32Regs:$a, (v2i16 (bitconvert (i32 imm:$b)))),
1504-
(ANDb32ri Int32Regs:$a, imm:$b)>;
1505-
1508+
// Lower logical v2i16/v4i8 ops as bitwise ops on b32.
1509+
foreach vt = [v2i16, v4i8] in {
1510+
def: Pat<(or (vt Int32Regs:$a), (vt Int32Regs:$b)),
1511+
(ORb32rr Int32Regs:$a, Int32Regs:$b)>;
1512+
def: Pat<(xor (vt Int32Regs:$a), (vt Int32Regs:$b)),
1513+
(XORb32rr Int32Regs:$a, Int32Regs:$b)>;
1514+
def: Pat<(and (vt Int32Regs:$a), (vt Int32Regs:$b)),
1515+
(ANDb32rr Int32Regs:$a, Int32Regs:$b)>;
1516+
1517+
// The constants get legalized into a bitcast from i32, so that's what we need
1518+
// to match here.
1519+
def: Pat<(or Int32Regs:$a, (vt (bitconvert (i32 imm:$b)))),
1520+
(ORb32ri Int32Regs:$a, imm:$b)>;
1521+
def: Pat<(xor Int32Regs:$a, (vt (bitconvert (i32 imm:$b)))),
1522+
(XORb32ri Int32Regs:$a, imm:$b)>;
1523+
def: Pat<(and Int32Regs:$a, (vt (bitconvert (i32 imm:$b)))),
1524+
(ANDb32ri Int32Regs:$a, imm:$b)>;
1525+
}
15061526

15071527
def NOT1 : NVPTXInst<(outs Int1Regs:$dst), (ins Int1Regs:$src),
15081528
"not.pred \t$dst, $src;",
@@ -1737,34 +1757,119 @@ def FUNSHFRCLAMP :
17371757
// restriction in PTX?
17381758
//
17391759
// dest and src may be int32 or int64, but start and end are always int32.
1740-
multiclass BFX<string Instr, RegisterClass RC> {
1760+
def SDTBFE :
1761+
SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisInt<0>,
1762+
SDTCisVT<2, i32>, SDTCisVT<3, i32>]>;
1763+
def bfe : SDNode<"NVPTXISD::BFE", SDTBFE>;
1764+
1765+
def SDTBFI :
1766+
SDTypeProfile<1, 4, [SDTCisInt<0>, SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>,
1767+
SDTCisVT<3, i32>, SDTCisVT<4, i32>]>;
1768+
def bfi : SDNode<"NVPTXISD::BFI", SDTBFI>;
1769+
1770+
def SDTPRMT :
1771+
SDTypeProfile<1, 4, [SDTCisVT<0, i32>, SDTCisVT<1, i32>,
1772+
SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>,]>;
1773+
def prmt : SDNode<"NVPTXISD::PRMT", SDTPRMT>;
1774+
1775+
multiclass BFE<string Instr, ValueType T, RegisterClass RC> {
17411776
def rrr
17421777
: NVPTXInst<(outs RC:$d),
17431778
(ins RC:$a, Int32Regs:$b, Int32Regs:$c),
1744-
!strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
1779+
!strconcat(Instr, " \t$d, $a, $b, $c;"),
1780+
[(set (T RC:$d), (bfe (T RC:$a), (i32 Int32Regs:$b), (i32 Int32Regs:$c)))]>;
17451781
def rri
17461782
: NVPTXInst<(outs RC:$d),
17471783
(ins RC:$a, Int32Regs:$b, i32imm:$c),
1748-
!strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
1784+
!strconcat(Instr, " \t$d, $a, $b, $c;"),
1785+
[(set (T RC:$d), (bfe (T RC:$a), (i32 Int32Regs:$b), (i32 imm:$c)))]>;
17491786
def rii
17501787
: NVPTXInst<(outs RC:$d),
17511788
(ins RC:$a, i32imm:$b, i32imm:$c),
1752-
!strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
1789+
!strconcat(Instr, " \t$d, $a, $b, $c;"),
1790+
[(set (T RC:$d), (bfe (T RC:$a), (i32 imm:$b), (i32 imm:$c)))]>;
1791+
}
1792+
1793+
// Defines the instruction variants that select the `bfi` (NVPTXISD::BFI)
// node.
//
// Parameters:
//   Instr  - assembly mnemonic placed in front of the operand list.
//   T      - value type of the result and of operands $a and $b.
//   RC     - register class holding values of type T.
//   ImmCls - immediate operand class used when $a is an immediate.
//
// Each def name has one letter per input operand ($a, $b, $c, $d): `r` for a
// register, `i` for an immediate. $b is always a register of class RC, while
// $c and $d are always i32 (Int32Regs or i32imm) regardless of T.
multiclass BFI<string Instr, ValueType T, RegisterClass RC, Operand ImmCls> {
1794+
def rrrr
1795+
: NVPTXInst<(outs RC:$f),
1796+
(ins RC:$a, RC:$b, Int32Regs:$c, Int32Regs:$d),
1797+
!strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1798+
[(set (T RC:$f), (bfi (T RC:$a), (T RC:$b), (i32 Int32Regs:$c), (i32 Int32Regs:$d)))]>;
1799+
def rrri
1800+
: NVPTXInst<(outs RC:$f),
1801+
(ins RC:$a, RC:$b, Int32Regs:$c, i32imm:$d),
1802+
!strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1803+
[(set (T RC:$f), (bfi (T RC:$a), (T RC:$b), (i32 Int32Regs:$c), (i32 imm:$d)))]>;
1804+
def rrii
1805+
: NVPTXInst<(outs RC:$f),
1806+
(ins RC:$a, RC:$b, i32imm:$c, i32imm:$d),
1807+
!strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1808+
[(set (T RC:$f), (bfi (T RC:$a), (T RC:$b), (i32 imm:$c), (i32 imm:$d)))]>;
1809+
// The variants below take $a as an immediate of class ImmCls.
def irrr
1810+
: NVPTXInst<(outs RC:$f),
1811+
(ins ImmCls:$a, RC:$b, Int32Regs:$c, Int32Regs:$d),
1812+
!strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1813+
[(set (T RC:$f), (bfi (T imm:$a), (T RC:$b), (i32 Int32Regs:$c), (i32 Int32Regs:$d)))]>;
1814+
def irri
1815+
: NVPTXInst<(outs RC:$f),
1816+
(ins ImmCls:$a, RC:$b, Int32Regs:$c, i32imm:$d),
1817+
!strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1818+
[(set (T RC:$f), (bfi (T imm:$a), (T RC:$b), (i32 Int32Regs:$c), (i32 imm:$d)))]>;
1819+
def irii
1820+
: NVPTXInst<(outs RC:$f),
1821+
(ins ImmCls:$a, RC:$b, i32imm:$c, i32imm:$d),
1822+
!strconcat(Instr, " \t$f, $a, $b, $c, $d;"),
1823+
[(set (T RC:$f), (bfi (T imm:$a), (T RC:$b), (i32 imm:$c), (i32 imm:$d)))]>;
1824+
}
1825+
1826+
// Defines the instruction variants that select the `prmt` (NVPTXISD::PRMT)
// node.
//
// Parameters:
//   T  - value type of the result and of operands $a and $b.
//   RC - register class holding the result and operand $a.
//
// The def suffix has one letter per data operand ($a, $b, $c): `r` for a
// register, `i` for an immediate. $a is always a register. $mode is an
// immediate PrmtMode operand; it is appended directly to the "prmt.b32"
// mnemonic and rendered by the operand's print method (printPrmtMode).
multiclass PRMT<ValueType T, RegisterClass RC> {
1827+
def rrr
1828+
: NVPTXInst<(outs RC:$d),
1829+
(ins RC:$a, Int32Regs:$b, Int32Regs:$c, PrmtMode:$mode),
1830+
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1831+
[(set (T RC:$d), (prmt (T RC:$a), (T RC:$b), (i32 Int32Regs:$c), imm:$mode))]>;
1832+
def rri
1833+
: NVPTXInst<(outs RC:$d),
1834+
(ins RC:$a, Int32Regs:$b, i32imm:$c, PrmtMode:$mode),
1835+
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1836+
[(set (T RC:$d), (prmt (T RC:$a), (T RC:$b), (i32 imm:$c), imm:$mode))]>;
1837+
def rii
1838+
: NVPTXInst<(outs RC:$d),
1839+
(ins RC:$a, i32imm:$b, i32imm:$c, PrmtMode:$mode),
1840+
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1841+
[(set (T RC:$d), (prmt (T RC:$a), (T imm:$b), (i32 imm:$c), imm:$mode))]>;
17531842
}
17541843

17551844
let hasSideEffects = false in {
1756-
defm BFE_S32 : BFX<"bfe.s32", Int32Regs>;
1757-
defm BFE_U32 : BFX<"bfe.u32", Int32Regs>;
1758-
defm BFE_S64 : BFX<"bfe.s64", Int64Regs>;
1759-
defm BFE_U64 : BFX<"bfe.u64", Int64Regs>;
1845+
defm BFE_S32 : BFE<"bfe.s32", i32, Int32Regs>;
1846+
defm BFE_U32 : BFE<"bfe.u32", i32, Int32Regs>;
1847+
defm BFE_S64 : BFE<"bfe.s64", i64, Int64Regs>;
1848+
defm BFE_U64 : BFE<"bfe.u64", i64, Int64Regs>;
1849+
1850+
defm BFI_B32 : BFI<"bfi.b32", i32, Int32Regs, i32imm>;
1851+
defm BFI_B64 : BFI<"bfi.b64", i64, Int64Regs, i64imm>;
17601852

1761-
defm BFI_S32 : BFX<"bfi.s32", Int32Regs>;
1762-
defm BFI_U32 : BFX<"bfi.u32", Int32Regs>;
1763-
defm BFI_S64 : BFX<"bfi.s64", Int64Regs>;
1764-
defm BFI_U64 : BFX<"bfi.u64", Int64Regs>;
1853+
defm PRMT_B32 : PRMT<i32, Int32Regs>;
17651854
}
17661855

1767-
// Common byte extraction patterns
1856+
1857+
// byte extraction + signed/unsigned extension to i32.
1858+
def : Pat<(i32 (sext_inreg (bfe (i32 Int32Regs:$s), (i32 Int32Regs:$o), 8), i8)),
1859+
(BFE_S32rri Int32Regs:$s, Int32Regs:$o, 8)>;
1860+
def : Pat<(i32 (sext_inreg (bfe (i32 Int32Regs:$s), (i32 imm:$o), 8), i8)),
1861+
(BFE_S32rii Int32Regs:$s, imm:$o, 8)>;
1862+
def : Pat<(i32 (and (bfe (i32 Int32Regs:$s), (i32 Int32Regs:$o), 8), 255)),
1863+
(BFE_U32rri Int32Regs:$s, Int32Regs:$o, 8)>;
1864+
def : Pat<(i32 (and (bfe (i32 Int32Regs:$s), (i32 imm:$o), 8), 255)),
1865+
(BFE_U32rii Int32Regs:$s, imm:$o, 8)>;
1866+
1867+
// byte extraction + signed extension to i16
1868+
def : Pat<(i16 (sext_inreg (trunc (bfe (i32 Int32Regs:$s), (i32 imm:$o), 8)), i8)),
1869+
(CVT_s8_s32 (BFE_S32rii Int32Regs:$s, imm:$o, 8), CvtNONE)>;
1870+
1871+
1872+
// Byte extraction via shift/trunc/sext
17681873
def : Pat<(i16 (sext_inreg (trunc Int32Regs:$s), i8)),
17691874
(CVT_s8_s32 Int32Regs:$s, CvtNONE)>;
17701875
def : Pat<(i16 (sext_inreg (trunc (srl (i32 Int32Regs:$s), (i32 imm:$o))), i8)),
@@ -1773,7 +1878,6 @@ def : Pat<(sext_inreg (srl (i32 Int32Regs:$s), (i32 imm:$o)), i8),
17731878
(BFE_S32rii Int32Regs:$s, imm:$o, 8)>;
17741879
def : Pat<(i16 (sra (i16 (trunc Int32Regs:$s)), (i32 8))),
17751880
(CVT_s8_s32 (BFE_S32rii Int32Regs:$s, 8, 8), CvtNONE)>;
1776-
17771881
def : Pat<(sext_inreg (srl (i64 Int64Regs:$s), (i32 imm:$o)), i8),
17781882
(BFE_S64rii Int64Regs:$s, imm:$o, 8)>;
17791883
def : Pat<(i16 (sext_inreg (trunc Int64Regs:$s), i8)),
@@ -2110,6 +2214,29 @@ def : Pat<(seteq Int1Regs:$a, Int1Regs:$b),
21102214
def : Pat<(setueq Int1Regs:$a, Int1Regs:$b),
21112215
(NOT1 (XORb1rr Int1Regs:$a, Int1Regs:$b))>;
21122216

2217+
// comparisons of i8 extracted with BFE as i32
2218+
def: Pat<(setgt (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
2219+
(SETP_s32rr Int32Regs:$a, Int32Regs:$b, CmpGT)>;
2220+
def: Pat<(setge (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
2221+
(SETP_s32rr Int32Regs:$a, Int32Regs:$b, CmpGE)>;
2222+
def: Pat<(setlt (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
2223+
(SETP_s32rr Int32Regs:$a, Int32Regs:$b, CmpLT)>;
2224+
def: Pat<(setle (sext_inreg (trunc Int32Regs:$a), i8), (sext_inreg (trunc Int32Regs:$b), i8)),
2225+
(SETP_s32rr Int32Regs:$a, Int32Regs:$b, CmpLE)>;
2226+
2227+
def: Pat<(setugt (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2228+
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpHI)>;
2229+
def: Pat<(setuge (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2230+
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpHS)>;
2231+
def: Pat<(setult (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2232+
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpLO)>;
2233+
def: Pat<(setule (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2234+
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpLS)>;
2235+
def: Pat<(seteq (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2236+
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpEQ)>;
2237+
def: Pat<(setne (i16 (and (trunc Int32Regs:$a), 255)), (i16 (and (trunc Int32Regs:$b), 255))),
2238+
(SETP_u32rr Int32Regs:$a, Int32Regs:$b, CmpNE)>;
2239+
21132240
// i1 compare -> i32
21142241
def : Pat<(i32 (setne Int1Regs:$a, Int1Regs:$b)),
21152242
(SELP_u32ii -1, 0, (XORb1rr Int1Regs:$a, Int1Regs:$b))>;
@@ -2682,7 +2809,7 @@ foreach vt = [f16, bf16] in {
26822809
def: Pat<(vt (ProxyReg vt:$src)), (ProxyRegI16 Int16Regs:$src)>;
26832810
}
26842811

2685-
foreach vt = [v2f16, v2bf16, v2i16] in {
2812+
foreach vt = [v2f16, v2bf16, v2i16, v4i8] in {
26862813
def: Pat<(vt (ProxyReg vt:$src)), (ProxyRegI32 Int32Regs:$src)>;
26872814
}
26882815

@@ -2995,8 +3122,8 @@ def: Pat<(i16 (bitconvert (vt Int16Regs:$a))),
29953122
(ProxyRegI16 Int16Regs:$a)>;
29963123
}
29973124

2998-
foreach ta = [v2f16, v2bf16, v2i16, i32] in {
2999-
foreach tb = [v2f16, v2bf16, v2i16, i32] in {
3125+
foreach ta = [v2f16, v2bf16, v2i16, v4i8, i32] in {
3126+
foreach tb = [v2f16, v2bf16, v2i16, v4i8, i32] in {
30003127
if !ne(ta, tb) then {
30013128
def: Pat<(ta (bitconvert (tb Int32Regs:$a))),
30023129
(ProxyRegI32 Int32Regs:$a)>;
@@ -3342,6 +3469,9 @@ def : Pat<(i32 (trunc (srl Int64Regs:$s, (i32 32)))),
33423469
def : Pat<(i32 (trunc (sra Int64Regs:$s, (i32 32)))),
33433470
(I64toI32H Int64Regs:$s)>;
33443471

3472+
def: Pat<(i32 (sext (extractelt (v2i16 Int32Regs:$src), 0))),
3473+
(CVT_INREG_s32_s16 Int32Regs:$src)>;
3474+
33453475
foreach vt = [v2f16, v2bf16, v2i16] in {
33463476
def : Pat<(extractelt (vt Int32Regs:$src), 0),
33473477
(I32toI16L Int32Regs:$src)>;

‎llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ foreach i = 0...31 in {
5858
//===----------------------------------------------------------------------===//
5959
def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 4))>;
6060
def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4))>;
61-
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16], 32,
61+
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
6262
(add (sequence "R%u", 0, 4),
6363
VRFrame32, VRFrameLocal32)>;
6464
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;

‎llvm/test/CodeGen/NVPTX/extractelement.ll

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ define i16 @test_v2i8(i16 %a) {
1818

1919
; CHECK-LABEL: test_v4i8
2020
; CHECK: ld.param.u32 [[R:%r[0-9+]]], [test_v4i8_param_0];
21-
; CHECK-DAG: cvt.s8.s32 [[E0:%rs[0-9+]]], [[R]];
21+
; CHECK-DAG: bfe.s32 [[R0:%r[0-9+]]], [[R]], 0, 8;
22+
; CHECK-DAG: cvt.s8.s32 [[E0:%rs[0-9+]]], [[R0]];
2223
; CHECK-DAG: bfe.s32 [[R1:%r[0-9+]]], [[R]], 8, 8;
2324
; CHECK-DAG: cvt.s8.s32 [[E1:%rs[0-9+]]], [[R1]];
2425
; CHECK-DAG: bfe.s32 [[R2:%r[0-9+]]], [[R]], 16, 8;
@@ -41,6 +42,58 @@ define i16 @test_v4i8(i32 %a) {
4142
ret i16 %r
4243
}
4344

45+
; CHECK-LABEL: test_v4i8_s32
46+
; CHECK: ld.param.u32 [[R:%r[0-9+]]], [test_v4i8_s32_param_0];
47+
; CHECK-DAG: bfe.s32 [[R0:%r[0-9+]]], [[R]], 0, 8;
48+
; CHECK-DAG: bfe.s32 [[R1:%r[0-9+]]], [[R]], 8, 8;
49+
; CHECK-DAG: bfe.s32 [[R2:%r[0-9+]]], [[R]], 16, 8;
50+
; CHECK-DAG: bfe.s32 [[R3:%r[0-9+]]], [[R]], 24, 8;
51+
; CHECK-DAG: add.s32 [[R01:%r[0-9+]]], [[R0]], [[R1]]
52+
; CHECK-DAG: add.s32 [[R23:%r[0-9+]]], [[R2]], [[R3]]
53+
; CHECK-DAG: add.s32 [[R0123:%r[0-9+]]], [[R01]], [[R23]]
54+
define i32 @test_v4i8_s32(i32 %a) {
55+
%v = bitcast i32 %a to <4 x i8>
56+
%r0 = extractelement <4 x i8> %v, i64 0
57+
%r1 = extractelement <4 x i8> %v, i64 1
58+
%r2 = extractelement <4 x i8> %v, i64 2
59+
%r3 = extractelement <4 x i8> %v, i64 3
60+
%r0i = sext i8 %r0 to i32
61+
%r1i = sext i8 %r1 to i32
62+
%r2i = sext i8 %r2 to i32
63+
%r3i = sext i8 %r3 to i32
64+
%r01 = add i32 %r0i, %r1i
65+
%r23 = add i32 %r2i, %r3i
66+
%r = add i32 %r01, %r23
67+
ret i32 %r
68+
}
69+
70+
; CHECK-LABEL: test_v4i8_u32
71+
; CHECK: ld.param.u32 [[R:%r[0-9+]]], [test_v4i8_u32_param_0];
72+
; CHECK-DAG: bfe.u32 [[R0:%r[0-9+]]], [[R]], 0, 8;
73+
; CHECK-DAG: bfe.u32 [[R1:%r[0-9+]]], [[R]], 8, 8;
74+
; CHECK-DAG: bfe.u32 [[R2:%r[0-9+]]], [[R]], 16, 8;
75+
; CHECK-DAG: bfe.u32 [[R3:%r[0-9+]]], [[R]], 24, 8;
76+
; CHECK-DAG: add.s32 [[R01:%r[0-9+]]], [[R0]], [[R1]]
77+
; CHECK-DAG: add.s32 [[R23:%r[0-9+]]], [[R2]], [[R3]]
78+
; CHECK-DAG: add.s32 [[R0123:%r[0-9+]]], [[R01]], [[R23]]
79+
define i32 @test_v4i8_u32(i32 %a) {
80+
%v = bitcast i32 %a to <4 x i8>
81+
%r0 = extractelement <4 x i8> %v, i64 0
82+
%r1 = extractelement <4 x i8> %v, i64 1
83+
%r2 = extractelement <4 x i8> %v, i64 2
84+
%r3 = extractelement <4 x i8> %v, i64 3
85+
%r0i = zext i8 %r0 to i32
86+
%r1i = zext i8 %r1 to i32
87+
%r2i = zext i8 %r2 to i32
88+
%r3i = zext i8 %r3 to i32
89+
%r01 = add i32 %r0i, %r1i
90+
%r23 = add i32 %r2i, %r3i
91+
%r = add i32 %r01, %r23
92+
ret i32 %r
93+
}
94+
95+
96+
4497
; CHECK-LABEL: test_v8i8
4598
; CHECK: ld.param.u64 [[R:%rd[0-9+]]], [test_v8i8_param_0];
4699
; CHECK-DAG: cvt.s8.s64 [[E0:%rs[0-9+]]], [[R]];

‎llvm/test/CodeGen/NVPTX/i16x2-instructions.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
; RUN: %if ptxas %{ \
66
; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -asm-verbose=false \
77
; RUN: -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
8-
; RUN: | %ptxas-verify -arch=sm_53 \
8+
; RUN: | %ptxas-verify -arch=sm_90 \
99
; RUN: %}
1010
; ## No support for i16x2 instructions
1111
; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_53 -asm-verbose=false \

‎llvm/test/CodeGen/NVPTX/i8x4-instructions.ll

Lines changed: 1272 additions & 0 deletions
Large diffs are not rendered by default.

‎llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ define void @foo12(ptr noalias readonly %from, ptr %to) {
130130
}
131131

132132
; SM20-LABEL: .visible .entry foo13(
133-
; SM20: ld.global.v4.u8
133+
; SM20: ld.global.u32
134134
; SM35-LABEL: .visible .entry foo13(
135-
; SM35: ld.global.nc.v4.u8
135+
; SM35: ld.global.nc.u32
136136
define void @foo13(ptr noalias readonly %from, ptr %to) {
137137
%1 = load <4 x i8>, ptr %from
138138
store <4 x i8> %1, ptr %to

‎llvm/test/CodeGen/NVPTX/param-load-store.ll

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -212,18 +212,16 @@ define signext i8 @test_i8s(i8 signext %a) {
212212
; CHECK: .func (.param .align 4 .b8 func_retval0[4])
213213
; CHECK-LABEL: test_v3i8(
214214
; CHECK-NEXT: .param .align 4 .b8 test_v3i8_param_0[4]
215-
; CHECK-DAG: ld.param.u8 [[E2:%rs[0-9]+]], [test_v3i8_param_0+2];
216-
; CHECK-DAG: ld.param.v2.u8 {[[E0:%rs[0-9]+]], [[E1:%rs[0-9]+]]}, [test_v3i8_param_0];
215+
; CHECK: ld.param.u32 [[R:%r[0-9]+]], [test_v3i8_param_0];
217216
; CHECK: .param .align 4 .b8 param0[4];
218-
; CHECK: st.param.v2.b8 [param0+0], {[[E0]], [[E1]]};
219-
; CHECK: st.param.b8 [param0+2], [[E2]];
217+
; CHECK: st.param.b32 [param0+0], [[R]]
220218
; CHECK: .param .align 4 .b8 retval0[4];
221219
; CHECK: call.uni (retval0),
222220
; CHECK-NEXT: test_v3i8,
223-
; CHECK-DAG: ld.param.v2.b8 {[[RE0:%rs[0-9]+]], [[RE1:%rs[0-9]+]]}, [retval0+0];
224-
; CHECK-DAG: ld.param.b8 [[RE2:%rs[0-9]+]], [retval0+2];
225-
; CHECK-DAG: st.param.v2.b8 [func_retval0+0], {[[RE0]], [[RE1]]};
226-
; CHECK-DAG: st.param.b8 [func_retval0+2], [[RE2]];
221+
; CHECK: ld.param.b32 [[RE:%r[0-9]+]], [retval0+0];
222+
; v4i8/i32->{v3i8 elements}->v4i8/i32 conversion is messy and not very
223+
; interesting here, so it's skipped.
224+
; CHECK: st.param.b32 [func_retval0+0],
227225
; CHECK-NEXT: ret;
228226
define <3 x i8> @test_v3i8(<3 x i8> %a) {
229227
%r = tail call <3 x i8> @test_v3i8(<3 x i8> %a);
@@ -233,14 +231,14 @@ define <3 x i8> @test_v3i8(<3 x i8> %a) {
233231
; CHECK: .func (.param .align 4 .b8 func_retval0[4])
234232
; CHECK-LABEL: test_v4i8(
235233
; CHECK-NEXT: .param .align 4 .b8 test_v4i8_param_0[4]
236-
; CHECK: ld.param.v4.u8 {[[E0:%rs[0-9]+]], [[E1:%rs[0-9]+]], [[E2:%rs[0-9]+]], [[E3:%rs[0-9]+]]}, [test_v4i8_param_0]
234+
; CHECK: ld.param.u32 [[R:%r[0-9]+]], [test_v4i8_param_0]
237235
; CHECK: .param .align 4 .b8 param0[4];
238-
; CHECK: st.param.v4.b8 [param0+0], {[[E0]], [[E1]], [[E2]], [[E3]]};
236+
; CHECK: st.param.b32 [param0+0], [[R]];
239237
; CHECK: .param .align 4 .b8 retval0[4];
240238
; CHECK: call.uni (retval0),
241239
; CHECK-NEXT: test_v4i8,
242-
; CHECK: ld.param.v4.b8 {[[RE0:%rs[0-9]+]], [[RE1:%rs[0-9]+]], [[RE2:%rs[0-9]+]], [[RE3:%rs[0-9]+]]}, [retval0+0];
243-
; CHECK: st.param.v4.b8 [func_retval0+0], {[[RE0]], [[RE1]], [[RE2]], [[RE3]]}
240+
; CHECK: ld.param.b32 [[RET:%r[0-9]+]], [retval0+0];
241+
; CHECK: st.param.b32 [func_retval0+0], [[RET]];
244242
; CHECK-NEXT: ret;
245243
define <4 x i8> @test_v4i8(<4 x i8> %a) {
246244
%r = tail call <4 x i8> @test_v4i8(<4 x i8> %a);
@@ -250,10 +248,10 @@ define <4 x i8> @test_v4i8(<4 x i8> %a) {
250248
; CHECK: .func (.param .align 8 .b8 func_retval0[8])
251249
; CHECK-LABEL: test_v5i8(
252250
; CHECK-NEXT: .param .align 8 .b8 test_v5i8_param_0[8]
251+
; CHECK-DAG: ld.param.u32 [[E0:%r[0-9]+]], [test_v5i8_param_0]
253252
; CHECK-DAG: ld.param.u8 [[E4:%rs[0-9]+]], [test_v5i8_param_0+4];
254-
; CHECK-DAG: ld.param.v4.u8 {[[E0:%rs[0-9]+]], [[E1:%rs[0-9]+]], [[E2:%rs[0-9]+]], [[E3:%rs[0-9]+]]}, [test_v5i8_param_0]
255253
; CHECK: .param .align 8 .b8 param0[8];
256-
; CHECK-DAG: st.param.v4.b8 [param0+0], {[[E0]], [[E1]], [[E2]], [[E3]]};
254+
; CHECK-DAG: st.param.v4.b8 [param0+0],
257255
; CHECK-DAG: st.param.b8 [param0+4], [[E4]];
258256
; CHECK: .param .align 8 .b8 retval0[8];
259257
; CHECK: call.uni (retval0),

‎llvm/test/CodeGen/NVPTX/unfold-masked-merge-vector-variablemask.ll

Lines changed: 96 additions & 429 deletions
Large diffs are not rendered by default.

‎llvm/test/CodeGen/NVPTX/vec8.ll

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@ target triple = "nvptx-unknown-cuda"
55

66
; CHECK: .visible .func foo
77
define void @foo(<8 x i8> %a, ptr %b) {
8-
; CHECK-DAG: ld.param.v4.u8 {[[E0:%rs[0-9]+]], [[E1:%rs[0-9]+]], [[E2:%rs[0-9]+]], [[E3:%rs[0-9]+]]}, [foo_param_0]
9-
; CHECK-DAG: ld.param.v4.u8 {[[E4:%rs[0-9]+]], [[E5:%rs[0-9]+]], [[E6:%rs[0-9]+]], [[E7:%rs[0-9]+]]}, [foo_param_0+4]
8+
; CHECK-DAG: ld.param.v2.u32 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [foo_param_0]
109
; CHECK-DAG: ld.param.u64 %[[B:rd[0-9+]]], [foo_param_1]
11-
; CHECK: add.s16 [[T:%rs[0-9+]]], [[E1]], [[E6]];
10+
; CHECK: add.s16 [[T:%rs[0-9+]]],
1211
; CHECK: st.u8 [%[[B]]], [[T]];
1312
%t0 = extractelement <8 x i8> %a, i32 1
1413
%t1 = extractelement <8 x i8> %a, i32 6

0 commit comments

Comments
 (0)
Please sign in to comment.