Skip to content

[NVPTX] Improve lowering of v4i8 #67866

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 9, 2023
Merged
31 changes: 31 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
@@ -309,3 +309,34 @@ void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
O << Sym.getName();
}

void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
raw_ostream &O, const char *Modifier) {
const MCOperand &MO = MI->getOperand(OpNum);
int64_t Imm = MO.getImm();

switch (Imm) {
default:
return;
case NVPTX::PTXPrmtMode::NONE:
break;
case NVPTX::PTXPrmtMode::F4E:
O << ".f4e";
break;
case NVPTX::PTXPrmtMode::B4E:
O << ".b4e";
break;
case NVPTX::PTXPrmtMode::RC8:
O << ".rc8";
break;
case NVPTX::PTXPrmtMode::ECL:
O << ".ecl";
break;
case NVPTX::PTXPrmtMode::ECR:
O << ".ecr";
break;
case NVPTX::PTXPrmtMode::RC16:
O << ".rc16";
break;
}
}
2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
Original file line number Diff line number Diff line change
@@ -47,6 +47,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
raw_ostream &O, const char *Modifier = nullptr);
void printProtoIdent(const MCInst *MI, int OpNum,
raw_ostream &O, const char *Modifier = nullptr);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
const char *Modifier = nullptr);
};

}
12 changes: 12 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
@@ -181,6 +181,18 @@ enum CmpMode {
FTZ_FLAG = 0x100
};
}

namespace PTXPrmtMode {
enum PrmtMode {
NONE,
F4E,
B4E,
RC8,
ECL,
ECR,
RC16,
};
}
}
void initializeNVPTXDAGToDAGISelPass(PassRegistry &);
} // namespace llvm
21 changes: 14 additions & 7 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
@@ -829,6 +830,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
case MVT::v4i8:
return Opcode_i32;
case MVT::f32:
return Opcode_f32;
@@ -910,7 +912,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
// Vector Setting
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
if (SimpleVT.isVector()) {
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
"Unexpected vector type");
// v2f16/v2bf16/v2i16 is loaded using ld.b32
fromTypeWidth = 32;
}
@@ -1254,19 +1257,23 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
SDLoc DL(N);
SDNode *LD;
SDValue Base, Offset, Addr;
EVT OrigType = N->getValueType(0);

EVT EltVT = Mem->getMemoryVT();
unsigned NumElts = 1;
if (EltVT.isVector()) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
(EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
(EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = N->getValueType(0);
EltVT = OrigType;
NumElts /= 2;
} else if (OrigType == MVT::v4i8) {
EltVT = OrigType;
NumElts = 1;
}
}

@@ -1601,7 +1608,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// concept of sign-/zero-extension, so emulate it here by adding an explicit
// CVT instruction. Ptxas should clean up any redundancies here.

EVT OrigType = N->getValueType(0);
LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);

if (OrigType != EltVT &&
@@ -1679,7 +1685,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
MVT ScalarVT = SimpleVT.getScalarType();
unsigned toTypeWidth = ScalarVT.getSizeInBits();
if (SimpleVT.isVector()) {
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
"Unexpected vector type");
// v2x16 is stored using st.b32
toTypeWidth = 32;
}
Loading