Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit a1987be

Browse files
authoredOct 31, 2024··
Reland "[NVPTX] Prefer prmt.b32 over bfi.b32" (#114326)
Fix [failure](#110766 (comment)) identified by @akuegel. --- In [[NVPTX] Improve lowering of v4i8](cbafb6f) @Artem-B add the ability to lower ISD::BUILD_VECTOR with bfi PTX instructions. @Artem-B did this because: (#67866 (comment)) Under the hood byte extraction/insertion ends up as BFI/BFE instructions, so we may as well do that in PTX, too. https://godbolt.org/z/Tb3zWbj9b However, the example that @Artem-B linked was targeting sm_52. On modern architectures, ptxas uses prmt.b32. [Example](https://godbolt.org/z/Ye4W1n84o). Thus, remove uses of NVPTXISD::BFI in favor of NVPTXISD::PRMT.
1 parent 19b4f17 commit a1987be

File tree

3 files changed

+339
-334
lines changed

3 files changed

+339
-334
lines changed
 

‎llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,32 +2318,33 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
23182318
EVT VT = Op->getValueType(0);
23192319
if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
23202320
return Op;
2321-
23222321
SDLoc DL(Op);
23232322

23242323
if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
23252324
return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
23262325
isa<ConstantFPSDNode>(Operand);
23272326
})) {
2327+
if (VT != MVT::v4i8)
2328+
return Op;
23282329
// Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
23292330
// to optimize calculation of constant parts.
2330-
if (VT == MVT::v4i8) {
2331-
SDValue C8 = DAG.getConstant(8, DL, MVT::i32);
2332-
SDValue E01 = DAG.getNode(
2333-
NVPTXISD::BFI, DL, MVT::i32,
2334-
DAG.getAnyExtOrTrunc(Op->getOperand(1), DL, MVT::i32),
2335-
DAG.getAnyExtOrTrunc(Op->getOperand(0), DL, MVT::i32), C8, C8);
2336-
SDValue E012 =
2337-
DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2338-
DAG.getAnyExtOrTrunc(Op->getOperand(2), DL, MVT::i32),
2339-
E01, DAG.getConstant(16, DL, MVT::i32), C8);
2340-
SDValue E0123 =
2341-
DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2342-
DAG.getAnyExtOrTrunc(Op->getOperand(3), DL, MVT::i32),
2343-
E012, DAG.getConstant(24, DL, MVT::i32), C8);
2344-
return DAG.getNode(ISD::BITCAST, DL, VT, E0123);
2345-
}
2346-
return Op;
2331+
auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
2332+
uint64_t SelectionValue) -> SDValue {
2333+
SDValue L = Left;
2334+
SDValue R = Right;
2335+
if (Cast) {
2336+
L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
2337+
R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
2338+
}
2339+
return DAG.getNode(
2340+
NVPTXISD::PRMT, DL, MVT::v4i8,
2341+
{L, R, DAG.getConstant(SelectionValue, DL, MVT::i32),
2342+
DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
2343+
};
2344+
auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
2345+
auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
2346+
auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2347+
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210);
23472348
}
23482349

23492350
// Get value or the Nth operand as an APInt(32). Undef values treated as 0.
@@ -2374,8 +2375,8 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
23742375
} else {
23752376
llvm_unreachable("Unsupported type");
23762377
}
2377-
SDValue Const = DAG.getConstant(Value, SDLoc(Op), MVT::i32);
2378-
return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
2378+
SDValue Const = DAG.getConstant(Value, DL, MVT::i32);
2379+
return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Const);
23792380
}
23802381

23812382
SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,

‎llvm/test/CodeGen/NVPTX/i8x4-instructions.ll

Lines changed: 309 additions & 305 deletions
Large diffs are not rendered by default.

‎llvm/test/CodeGen/NVPTX/sext-setcc.ll

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,35 +33,35 @@ define <4 x i8> @sext_setcc_v4i1_to_v4i8(ptr %p) {
3333
; CHECK: {
3434
; CHECK-NEXT: .reg .pred %p<5>;
3535
; CHECK-NEXT: .reg .b16 %rs<9>;
36-
; CHECK-NEXT: .reg .b32 %r<14>;
36+
; CHECK-NEXT: .reg .b32 %r<13>;
3737
; CHECK-NEXT: .reg .b64 %rd<2>;
3838
; CHECK-EMPTY:
3939
; CHECK-NEXT: // %bb.0: // %entry
4040
; CHECK-NEXT: ld.param.u64 %rd1, [sext_setcc_v4i1_to_v4i8_param_0];
4141
; CHECK-NEXT: ld.u32 %r1, [%rd1];
42-
; CHECK-NEXT: bfe.u32 %r2, %r1, 24, 8;
42+
; CHECK-NEXT: bfe.u32 %r2, %r1, 0, 8;
4343
; CHECK-NEXT: cvt.u16.u32 %rs1, %r2;
4444
; CHECK-NEXT: and.b16 %rs2, %rs1, 255;
4545
; CHECK-NEXT: setp.eq.s16 %p1, %rs2, 0;
46-
; CHECK-NEXT: bfe.u32 %r3, %r1, 16, 8;
46+
; CHECK-NEXT: bfe.u32 %r3, %r1, 8, 8;
4747
; CHECK-NEXT: cvt.u16.u32 %rs3, %r3;
4848
; CHECK-NEXT: and.b16 %rs4, %rs3, 255;
4949
; CHECK-NEXT: setp.eq.s16 %p2, %rs4, 0;
50-
; CHECK-NEXT: bfe.u32 %r4, %r1, 8, 8;
50+
; CHECK-NEXT: bfe.u32 %r4, %r1, 16, 8;
5151
; CHECK-NEXT: cvt.u16.u32 %rs5, %r4;
5252
; CHECK-NEXT: and.b16 %rs6, %rs5, 255;
5353
; CHECK-NEXT: setp.eq.s16 %p3, %rs6, 0;
54-
; CHECK-NEXT: bfe.u32 %r5, %r1, 0, 8;
54+
; CHECK-NEXT: bfe.u32 %r5, %r1, 24, 8;
5555
; CHECK-NEXT: cvt.u16.u32 %rs7, %r5;
5656
; CHECK-NEXT: and.b16 %rs8, %rs7, 255;
5757
; CHECK-NEXT: setp.eq.s16 %p4, %rs8, 0;
5858
; CHECK-NEXT: selp.s32 %r6, -1, 0, %p4;
5959
; CHECK-NEXT: selp.s32 %r7, -1, 0, %p3;
60-
; CHECK-NEXT: bfi.b32 %r8, %r7, %r6, 8, 8;
60+
; CHECK-NEXT: prmt.b32 %r8, %r7, %r6, 13120;
6161
; CHECK-NEXT: selp.s32 %r9, -1, 0, %p2;
62-
; CHECK-NEXT: bfi.b32 %r10, %r9, %r8, 16, 8;
63-
; CHECK-NEXT: selp.s32 %r11, -1, 0, %p1;
64-
; CHECK-NEXT: bfi.b32 %r12, %r11, %r10, 24, 8;
62+
; CHECK-NEXT: selp.s32 %r10, -1, 0, %p1;
63+
; CHECK-NEXT: prmt.b32 %r11, %r10, %r9, 13120;
64+
; CHECK-NEXT: prmt.b32 %r12, %r11, %r8, 21520;
6565
; CHECK-NEXT: st.param.b32 [func_retval0], %r12;
6666
; CHECK-NEXT: ret;
6767
entry:

0 commit comments

Comments
 (0)
Please sign in to comment.