From f678fc7660b36ce0ad6ce4f05eaa28f3e9fdedb5 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Fri, 10 Jan 2020 10:13:45 -0800 Subject: [PATCH] [LegalizeVectorOps] Improve handling of multi-result operations. This system wasn't very well designed for multi-result nodes. As a consequence they weren't consistently registered in the LegalizedNodes map leading to nodes being revisited for different results. I've removed the "Result" variable from the main LegalizeOp method and used a SDNode* instead. The result number from the incoming Op SDValue is only used for deciding which result to return to the caller. When LegalizeOp is called it should always register a legalized result for all of its results. Future calls for any other result should be pulled for the LegalizedNodes map. Legal nodes will now register all of their results in the map instead of just the one we were called for. The Expand and Promote handling to use a vector of results similar to LegalizeDAG. Each of the new results is then re-legalized and logged in the LegalizedNodes map for all of the Results for the node being legalized. None of the handles register their own results now. And none call ReplaceAllUsesOfValueWith now. Custom handling now always passes result number 0 to LowerOperation. This matches what LegalizeDAG does. Since the introduction of STRICT nodes, I've encountered several issues with X86's custom handling being called with an SDValue pointing at the chain and our custom handlers using that to get a VT instead of result 0. This should prevent us from having any more of those issues. On return we will update the LegalizedNodes map for all results so we shouldn't call the custom handler again for each result number. I want to push SDNode* further into the Expand and Promote handlers, but I've left that for a follow to keep this patch size down. I've created a dummy SDValue(Node, 0) to keep the handlers working. Differential Revision: https://reviews.llvm.org/D72224 --- .../SelectionDAG/LegalizeVectorOps.cpp | 446 +++++++++++------- llvm/test/CodeGen/X86/avx512-cmp.ll | 36 ++ 2 files changed, 308 insertions(+), 174 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index 13813008eff0..557bf495c85d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -75,7 +75,17 @@ class VectorLegalizer { SDValue LegalizeOp(SDValue Op); /// Assuming the node is legal, "legalize" the results. - SDValue TranslateLegalizeResults(SDValue Op, SDValue Result); + SDValue TranslateLegalizeResults(SDValue Op, SDNode *Result); + + /// Make sure Results are legal and update the translation cache. + SDValue RecursivelyLegalizeResults(SDValue Op, + MutableArrayRef Results); + + /// Wrapper to interface LowerOperation with a vector of Results. + /// Returns false if the target wants to use default expansion. Otherwise + /// returns true. If return is true and the Results are empty, then the + /// target wants to keep the input node as is. + bool LowerOperationWrapper(SDNode *N, SmallVectorImpl &Results); /// Implements unrolling a VSETCC. SDValue UnrollVSETCC(SDValue Op); @@ -84,15 +94,15 @@ class VectorLegalizer { /// /// This is just a high-level routine to dispatch to specific code paths for /// operations to legalize them. - SDValue Expand(SDValue Op); + void Expand(SDNode *Node, SmallVectorImpl &Results); /// Implements expansion for FP_TO_UINT; falls back to UnrollVectorOp if /// FP_TO_SINT isn't legal. - SDValue ExpandFP_TO_UINT(SDValue Op); + void ExpandFP_TO_UINT(SDValue Op, SmallVectorImpl &Results); /// Implements expansion for UINT_TO_FLOAT; falls back to UnrollVectorOp if /// SINT_TO_FLOAT and SHR on vectors isn't legal. - SDValue ExpandUINT_TO_FLOAT(SDValue Op); + void ExpandUINT_TO_FLOAT(SDValue Op, SmallVectorImpl &Results); /// Implement expansion for SIGN_EXTEND_INREG using SRL and SRA. SDValue ExpandSEXTINREG(SDValue Op); @@ -130,8 +140,8 @@ class VectorLegalizer { /// supported by the target. SDValue ExpandVSELECT(SDValue Op); SDValue ExpandSELECT(SDValue Op); - std::pair ExpandLoad(SDValue Op); - SDValue ExpandStore(SDValue Op); + std::pair ExpandLoad(SDNode *N); + SDValue ExpandStore(SDNode *N); SDValue ExpandFNEG(SDValue Op); SDValue ExpandFSUB(SDValue Op); SDValue ExpandBITREVERSE(SDValue Op); @@ -141,32 +151,33 @@ class VectorLegalizer { SDValue ExpandFunnelShift(SDValue Op); SDValue ExpandROT(SDValue Op); SDValue ExpandFMINNUM_FMAXNUM(SDValue Op); - SDValue ExpandUADDSUBO(SDValue Op); - SDValue ExpandSADDSUBO(SDValue Op); - SDValue ExpandMULO(SDValue Op); + void ExpandUADDSUBO(SDValue Op, SmallVectorImpl &Results); + void ExpandSADDSUBO(SDValue Op, SmallVectorImpl &Results); + void ExpandMULO(SDValue Op, SmallVectorImpl &Results); SDValue ExpandAddSubSat(SDValue Op); SDValue ExpandFixedPointMul(SDValue Op); SDValue ExpandFixedPointDiv(SDValue Op); SDValue ExpandStrictFPOp(SDValue Op); + void ExpandStrictFPOp(SDValue Op, SmallVectorImpl &Results); - SDValue UnrollStrictFPOp(SDValue Op); + void UnrollStrictFPOp(SDValue Op, SmallVectorImpl &Results); /// Implements vector promotion. /// /// This is essentially just bitcasting the operands to a different type and /// bitcasting the result back to the original type. - SDValue Promote(SDValue Op); + void Promote(SDNode *Node, SmallVectorImpl &Results); /// Implements [SU]INT_TO_FP vector promotion. /// /// This is a [zs]ext of the input operand to a larger integer type. - SDValue PromoteINT_TO_FP(SDValue Op); + void PromoteINT_TO_FP(SDValue Op, SmallVectorImpl &Results); /// Implements FP_TO_[SU]INT vector promotion of the result type. /// /// It is promoted to a larger integer type. The result is then /// truncated back to the original type. - SDValue PromoteFP_TO_INT(SDValue Op); + void PromoteFP_TO_INT(SDValue Op, SmallVectorImpl &Results); public: VectorLegalizer(SelectionDAG& dag) : @@ -222,11 +233,27 @@ bool VectorLegalizer::Run() { return Changed; } -SDValue VectorLegalizer::TranslateLegalizeResults(SDValue Op, SDValue Result) { +SDValue VectorLegalizer::TranslateLegalizeResults(SDValue Op, SDNode *Result) { + assert(Op->getNumValues() == Result->getNumValues() && + "Unexpected number of results"); // Generic legalization: just pass the operand through. - for (unsigned i = 0, e = Op.getNode()->getNumValues(); i != e; ++i) - AddLegalizedOperand(Op.getValue(i), Result.getValue(i)); - return Result.getValue(Op.getResNo()); + for (unsigned i = 0, e = Op->getNumValues(); i != e; ++i) + AddLegalizedOperand(Op.getValue(i), SDValue(Result, i)); + return SDValue(Result, Op.getResNo()); +} + +SDValue +VectorLegalizer::RecursivelyLegalizeResults(SDValue Op, + MutableArrayRef Results) { + assert(Results.size() == Op->getNumValues() && + "Unexpected number of results"); + // Make sure that the generated code is itself legal. + for (unsigned i = 0, e = Results.size(); i != e; ++i) { + Results[i] = LegalizeOp(Results[i]); + AddLegalizedOperand(Op.getValue(i), Results[i]); + } + + return Results[Op.getResNo()]; } SDValue VectorLegalizer::LegalizeOp(SDValue Op) { @@ -235,18 +262,15 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { DenseMap::iterator I = LegalizedNodes.find(Op); if (I != LegalizedNodes.end()) return I->second; - SDNode* Node = Op.getNode(); - // Legalize the operands SmallVector Ops; - for (const SDValue &Op : Node->op_values()) - Ops.push_back(LegalizeOp(Op)); + for (const SDValue &Oper : Op->op_values()) + Ops.push_back(LegalizeOp(Oper)); - SDValue Result = SDValue(DAG.UpdateNodeOperands(Op.getNode(), Ops), - Op.getResNo()); + SDNode *Node = DAG.UpdateNodeOperands(Op.getNode(), Ops); if (Op.getOpcode() == ISD::LOAD) { - LoadSDNode *LD = cast(Op.getNode()); + LoadSDNode *LD = cast(Node); ISD::LoadExtType ExtType = LD->getExtensionType(); if (LD->getMemoryVT().isVector() && ExtType != ISD::NON_EXTLOAD) { LLVM_DEBUG(dbgs() << "\nLegalizing extending vector load: "; @@ -255,22 +279,21 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { LD->getMemoryVT())) { default: llvm_unreachable("This action is not supported yet!"); case TargetLowering::Legal: - return TranslateLegalizeResults(Op, Result); - case TargetLowering::Custom: - if (SDValue Lowered = TLI.LowerOperation(Result, DAG)) { - assert(Lowered->getNumValues() == Op->getNumValues() && - "Unexpected number of results"); - if (Lowered != Result) { - // Make sure the new code is also legal. - Lowered = LegalizeOp(Lowered); - Changed = true; - } - return TranslateLegalizeResults(Op, Lowered); + return TranslateLegalizeResults(Op, Node); + case TargetLowering::Custom: { + SmallVector ResultVals; + if (LowerOperationWrapper(Node, ResultVals)) { + if (ResultVals.empty()) + return TranslateLegalizeResults(Op, Node); + + Changed = true; + return RecursivelyLegalizeResults(Op, ResultVals); } LLVM_FALLTHROUGH; + } case TargetLowering::Expand: { Changed = true; - std::pair Tmp = ExpandLoad(Result); + std::pair Tmp = ExpandLoad(Node); AddLegalizedOperand(Op.getValue(0), Tmp.first); AddLegalizedOperand(Op.getValue(1), Tmp.second); return Op.getResNo() ? Tmp.first : Tmp.second; @@ -278,7 +301,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { } } } else if (Op.getOpcode() == ISD::STORE) { - StoreSDNode *ST = cast(Op.getNode()); + StoreSDNode *ST = cast(Node); EVT StVT = ST->getMemoryVT(); MVT ValVT = ST->getValue().getSimpleValueType(); if (StVT.isVector() && ST->isTruncatingStore()) { @@ -287,19 +310,21 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { switch (TLI.getTruncStoreAction(ValVT, StVT)) { default: llvm_unreachable("This action is not supported yet!"); case TargetLowering::Legal: - return TranslateLegalizeResults(Op, Result); + return TranslateLegalizeResults(Op, Node); case TargetLowering::Custom: { - SDValue Lowered = TLI.LowerOperation(Result, DAG); - if (Lowered != Result) { - // Make sure the new code is also legal. - Lowered = LegalizeOp(Lowered); + SmallVector ResultVals; + if (LowerOperationWrapper(Node, ResultVals)) { + if (ResultVals.empty()) + return TranslateLegalizeResults(Op, Node); + Changed = true; + return RecursivelyLegalizeResults(Op, ResultVals); } - return TranslateLegalizeResults(Op, Lowered); + LLVM_FALLTHROUGH; } case TargetLowering::Expand: { Changed = true; - SDValue Chain = ExpandStore(Result); + SDValue Chain = ExpandStore(Node); AddLegalizedOperand(Op, Chain); return Chain; } @@ -310,17 +335,17 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { bool HasVectorValueOrOp = false; for (auto J = Node->value_begin(), E = Node->value_end(); J != E; ++J) HasVectorValueOrOp |= J->isVector(); - for (const SDValue &Op : Node->op_values()) - HasVectorValueOrOp |= Op.getValueType().isVector(); + for (const SDValue &Oper : Node->op_values()) + HasVectorValueOrOp |= Oper.getValueType().isVector(); if (!HasVectorValueOrOp) - return TranslateLegalizeResults(Op, Result); + return TranslateLegalizeResults(Op, Node); TargetLowering::LegalizeAction Action = TargetLowering::Legal; EVT ValVT; switch (Op.getOpcode()) { default: - return TranslateLegalizeResults(Op, Result); + return TranslateLegalizeResults(Op, Node); #define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \ case ISD::STRICT_##DAGN: #include "llvm/IR/ConstrainedOps.def" @@ -473,42 +498,70 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { LLVM_DEBUG(dbgs() << "\nLegalizing vector op: "; Node->dump(&DAG)); + SmallVector ResultVals; switch (Action) { default: llvm_unreachable("This action is not supported yet!"); case TargetLowering::Promote: - Result = Promote(Op); - Changed = true; + LLVM_DEBUG(dbgs() << "Promoting\n"); + Promote(Node, ResultVals); + assert(!ResultVals.empty() && "No results for promotion?"); break; case TargetLowering::Legal: LLVM_DEBUG(dbgs() << "Legal node: nothing to do\n"); break; - case TargetLowering::Custom: { + case TargetLowering::Custom: LLVM_DEBUG(dbgs() << "Trying custom legalization\n"); - if (SDValue Tmp1 = TLI.LowerOperation(Op, DAG)) { - LLVM_DEBUG(dbgs() << "Successfully custom legalized node\n"); - Result = Tmp1; + if (LowerOperationWrapper(Node, ResultVals)) break; - } LLVM_DEBUG(dbgs() << "Could not custom legalize node\n"); LLVM_FALLTHROUGH; - } case TargetLowering::Expand: - Result = Expand(Op); + LLVM_DEBUG(dbgs() << "Expanding\n"); + Expand(Node, ResultVals); + break; } - // Make sure that the generated code is itself legal. - if (Result != Op) { - Result = LegalizeOp(Result); - Changed = true; - } + if (ResultVals.empty()) + return TranslateLegalizeResults(Op, Node); - // Note that LegalizeOp may be reentered even from single-use nodes, which - // means that we always must cache transformed nodes. - AddLegalizedOperand(Op, Result); - return Result; + Changed = true; + return RecursivelyLegalizeResults(Op, ResultVals); } -SDValue VectorLegalizer::Promote(SDValue Op) { +// FIME: This is very similar to the X86 override of +// TargetLowering::LowerOperationWrapper. Can we merge them somehow? +bool VectorLegalizer::LowerOperationWrapper(SDNode *Node, + SmallVectorImpl &Results) { + SDValue Res = TLI.LowerOperation(SDValue(Node, 0), DAG); + + if (!Res.getNode()) + return false; + + if (Res == SDValue(Node, 0)) + return true; + + // If the original node has one result, take the return value from + // LowerOperation as is. It might not be result number 0. + if (Node->getNumValues() == 1) { + Results.push_back(Res); + return true; + } + + // If the original node has multiple results, then the return node should + // have the same number of results. + assert((Node->getNumValues() == Res->getNumValues()) && + "Lowering returned the wrong number of results!"); + + // Places new result values base on N result number. + for (unsigned I = 0, E = Node->getNumValues(); I != E; ++I) + Results.push_back(Res.getValue(I)); + + return true; +} + +void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl &Results) { + SDValue Op(Node, 0); // FIXME: Use Node throughout. + // For a few operations there is a specific concept for promotion based on // the operand's type. switch (Op.getOpcode()) { @@ -517,13 +570,15 @@ SDValue VectorLegalizer::Promote(SDValue Op) { case ISD::STRICT_SINT_TO_FP: case ISD::STRICT_UINT_TO_FP: // "Promote" the operation by extending the operand. - return PromoteINT_TO_FP(Op); + PromoteINT_TO_FP(Op, Results); + return; case ISD::FP_TO_UINT: case ISD::FP_TO_SINT: case ISD::STRICT_FP_TO_UINT: case ISD::STRICT_FP_TO_SINT: // Promote the operation by extending the operand. - return PromoteFP_TO_INT(Op); + PromoteFP_TO_INT(Op, Results); + return; case ISD::FP_ROUND: case ISD::FP_EXTEND: // These operations are used to do promotion so they can't be promoted @@ -558,15 +613,20 @@ SDValue VectorLegalizer::Promote(SDValue Op) { } Op = DAG.getNode(Op.getOpcode(), dl, NVT, Operands, Op.getNode()->getFlags()); + + SDValue Res; if ((VT.isFloatingPoint() && NVT.isFloatingPoint()) || (VT.isVector() && VT.getVectorElementType().isFloatingPoint() && NVT.isVector() && NVT.getVectorElementType().isFloatingPoint())) - return DAG.getNode(ISD::FP_ROUND, dl, VT, Op, DAG.getIntPtrConstant(0, dl)); + Res = DAG.getNode(ISD::FP_ROUND, dl, VT, Op, DAG.getIntPtrConstant(0, dl)); else - return DAG.getNode(ISD::BITCAST, dl, VT, Op); + Res = DAG.getNode(ISD::BITCAST, dl, VT, Op); + + Results.push_back(Res); } -SDValue VectorLegalizer::PromoteINT_TO_FP(SDValue Op) { +void VectorLegalizer::PromoteINT_TO_FP(SDValue Op, + SmallVectorImpl &Results) { // INT_TO_FP operations may require the input operand be promoted even // when the type is otherwise legal. bool IsStrict = Op->isStrictFPOpcode(); @@ -589,18 +649,24 @@ SDValue VectorLegalizer::PromoteINT_TO_FP(SDValue Op) { Operands[j] = Op.getOperand(j); } - if (IsStrict) - return DAG.getNode(Op.getOpcode(), dl, {Op.getValueType(), MVT::Other}, - Operands); + if (IsStrict) { + SDValue Res = DAG.getNode(Op.getOpcode(), dl, + {Op.getValueType(), MVT::Other}, Operands); + Results.push_back(Res); + Results.push_back(Res.getValue(1)); + return; + } - return DAG.getNode(Op.getOpcode(), dl, Op.getValueType(), Operands); + SDValue Res = DAG.getNode(Op.getOpcode(), dl, Op.getValueType(), Operands); + Results.push_back(Res); } // For FP_TO_INT we promote the result type to a vector type with wider // elements and then truncate the result. This is different from the default // PromoteVector which uses bitcast to promote thus assumning that the // promoted vector type has the same overall size. -SDValue VectorLegalizer::PromoteFP_TO_INT(SDValue Op) { +void VectorLegalizer::PromoteFP_TO_INT(SDValue Op, + SmallVectorImpl &Results) { MVT VT = Op.getSimpleValueType(); MVT NVT = TLI.getTypeToPromoteTo(Op.getOpcode(), VT); bool IsStrict = Op->isStrictFPOpcode(); @@ -639,14 +705,13 @@ SDValue VectorLegalizer::PromoteFP_TO_INT(SDValue Op) { Promoted = DAG.getNode(NewOpc, dl, NVT, Promoted, DAG.getValueType(VT.getScalarType())); Promoted = DAG.getNode(ISD::TRUNCATE, dl, VT, Promoted); + Results.push_back(Promoted); if (IsStrict) - return DAG.getMergeValues({Promoted, Chain}, dl); - - return Promoted; + Results.push_back(Chain); } -std::pair VectorLegalizer::ExpandLoad(SDValue Op) { - LoadSDNode *LD = cast(Op.getNode()); +std::pair VectorLegalizer::ExpandLoad(SDNode *N) { + LoadSDNode *LD = cast(N); EVT SrcVT = LD->getMemoryVT(); EVT SrcEltVT = SrcVT.getScalarType(); @@ -655,7 +720,7 @@ std::pair VectorLegalizer::ExpandLoad(SDValue Op) { SDValue NewChain; SDValue Value; if (SrcVT.getVectorNumElements() > 1 && !SrcEltVT.isByteSized()) { - SDLoc dl(Op); + SDLoc dl(N); SmallVector Vals; SmallVector LoadChains; @@ -767,7 +832,7 @@ std::pair VectorLegalizer::ExpandLoad(SDValue Op) { } NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, LoadChains); - Value = DAG.getBuildVector(Op.getNode()->getValueType(0), dl, Vals); + Value = DAG.getBuildVector(N->getValueType(0), dl, Vals); } else { std::tie(Value, NewChain) = TLI.scalarizeVectorLoad(LD, DAG); } @@ -775,90 +840,122 @@ std::pair VectorLegalizer::ExpandLoad(SDValue Op) { return std::make_pair(Value, NewChain); } -SDValue VectorLegalizer::ExpandStore(SDValue Op) { - StoreSDNode *ST = cast(Op.getNode()); +SDValue VectorLegalizer::ExpandStore(SDNode *N) { + StoreSDNode *ST = cast(N); SDValue TF = TLI.scalarizeVectorStore(ST, DAG); return TF; } -SDValue VectorLegalizer::Expand(SDValue Op) { +void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl &Results) { + SDValue Op(Node, 0); // FIXME: Just pass Node to all the expanders. + switch (Op->getOpcode()) { case ISD::SIGN_EXTEND_INREG: - return ExpandSEXTINREG(Op); + Results.push_back(ExpandSEXTINREG(Op)); + return; case ISD::ANY_EXTEND_VECTOR_INREG: - return ExpandANY_EXTEND_VECTOR_INREG(Op); + Results.push_back(ExpandANY_EXTEND_VECTOR_INREG(Op)); + return; case ISD::SIGN_EXTEND_VECTOR_INREG: - return ExpandSIGN_EXTEND_VECTOR_INREG(Op); + Results.push_back(ExpandSIGN_EXTEND_VECTOR_INREG(Op)); + return; case ISD::ZERO_EXTEND_VECTOR_INREG: - return ExpandZERO_EXTEND_VECTOR_INREG(Op); + Results.push_back(ExpandZERO_EXTEND_VECTOR_INREG(Op)); + return; case ISD::BSWAP: - return ExpandBSWAP(Op); + Results.push_back(ExpandBSWAP(Op)); + return; case ISD::VSELECT: - return ExpandVSELECT(Op); + Results.push_back(ExpandVSELECT(Op)); + return; case ISD::SELECT: - return ExpandSELECT(Op); + Results.push_back(ExpandSELECT(Op)); + return; case ISD::FP_TO_UINT: - return ExpandFP_TO_UINT(Op); + ExpandFP_TO_UINT(Op, Results); + return; case ISD::UINT_TO_FP: - return ExpandUINT_TO_FLOAT(Op); + ExpandUINT_TO_FLOAT(Op, Results); + return; case ISD::FNEG: - return ExpandFNEG(Op); + Results.push_back(ExpandFNEG(Op)); + return; case ISD::FSUB: - return ExpandFSUB(Op); + if (SDValue Tmp = ExpandFSUB(Op)) + Results.push_back(Tmp); + return; case ISD::SETCC: - return UnrollVSETCC(Op); + Results.push_back(UnrollVSETCC(Op)); + return; case ISD::ABS: - return ExpandABS(Op); + Results.push_back(ExpandABS(Op)); + return; case ISD::BITREVERSE: - return ExpandBITREVERSE(Op); + if (SDValue Tmp = ExpandBITREVERSE(Op)) + Results.push_back(Tmp); + return; case ISD::CTPOP: - return ExpandCTPOP(Op); + Results.push_back(ExpandCTPOP(Op)); + return; case ISD::CTLZ: case ISD::CTLZ_ZERO_UNDEF: - return ExpandCTLZ(Op); + Results.push_back(ExpandCTLZ(Op)); + return; case ISD::CTTZ: case ISD::CTTZ_ZERO_UNDEF: - return ExpandCTTZ(Op); + Results.push_back(ExpandCTTZ(Op)); + return; case ISD::FSHL: case ISD::FSHR: - return ExpandFunnelShift(Op); + Results.push_back(ExpandFunnelShift(Op)); + return; case ISD::ROTL: case ISD::ROTR: - return ExpandROT(Op); + Results.push_back(ExpandROT(Op)); + return; case ISD::FMINNUM: case ISD::FMAXNUM: - return ExpandFMINNUM_FMAXNUM(Op); + Results.push_back(ExpandFMINNUM_FMAXNUM(Op)); + return; case ISD::UADDO: case ISD::USUBO: - return ExpandUADDSUBO(Op); + ExpandUADDSUBO(Op, Results); + return; case ISD::SADDO: case ISD::SSUBO: - return ExpandSADDSUBO(Op); + ExpandSADDSUBO(Op, Results); + return; case ISD::UMULO: case ISD::SMULO: - return ExpandMULO(Op); + ExpandMULO(Op, Results); + return; case ISD::USUBSAT: case ISD::SSUBSAT: case ISD::UADDSAT: case ISD::SADDSAT: - return ExpandAddSubSat(Op); + Results.push_back(ExpandAddSubSat(Op)); + return; case ISD::SMULFIX: case ISD::UMULFIX: - return ExpandFixedPointMul(Op); + Results.push_back(ExpandFixedPointMul(Op)); + return; case ISD::SMULFIXSAT: case ISD::UMULFIXSAT: // FIXME: We do not expand SMULFIXSAT/UMULFIXSAT here yet, not sure exactly // why. Maybe it results in worse codegen compared to the unroll for some // targets? This should probably be investigated. And if we still prefer to // unroll an explanation could be helpful. - return DAG.UnrollVectorOp(Op.getNode()); + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); + return; case ISD::SDIVFIX: case ISD::UDIVFIX: - return ExpandFixedPointDiv(Op); + Results.push_back(ExpandFixedPointDiv(Op)); + return; #define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \ case ISD::STRICT_##DAGN: #include "llvm/IR/ConstrainedOps.def" - return ExpandStrictFPOp(Op); + ExpandStrictFPOp(Op, Results); + return; case ISD::VECREDUCE_ADD: case ISD::VECREDUCE_MUL: case ISD::VECREDUCE_AND: @@ -872,9 +969,11 @@ SDValue VectorLegalizer::Expand(SDValue Op) { case ISD::VECREDUCE_FMUL: case ISD::VECREDUCE_FMAX: case ISD::VECREDUCE_FMIN: - return TLI.expandVecReduce(Op.getNode(), DAG); + Results.push_back(TLI.expandVecReduce(Op.getNode(), DAG)); + return; default: - return DAG.UnrollVectorOp(Op.getNode()); + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); + return; } } @@ -1120,7 +1219,7 @@ SDValue VectorLegalizer::ExpandBITREVERSE(SDValue Op) { return DAG.UnrollVectorOp(Op.getNode()); // Let LegalizeDAG handle this later. - return Op; + return SDValue(); } SDValue VectorLegalizer::ExpandVSELECT(SDValue Op) { @@ -1180,23 +1279,28 @@ SDValue VectorLegalizer::ExpandABS(SDValue Op) { return DAG.UnrollVectorOp(Op.getNode()); } -SDValue VectorLegalizer::ExpandFP_TO_UINT(SDValue Op) { +void VectorLegalizer::ExpandFP_TO_UINT(SDValue Op, + SmallVectorImpl &Results) { // Attempt to expand using TargetLowering. SDValue Result, Chain; if (TLI.expandFP_TO_UINT(Op.getNode(), Result, Chain, DAG)) { + Results.push_back(Result); if (Op->isStrictFPOpcode()) - // Relink the chain - DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), Chain); - return Result; + Results.push_back(Chain); + return; } // Otherwise go ahead and unroll. - if (Op->isStrictFPOpcode()) - return UnrollStrictFPOp(Op); - return DAG.UnrollVectorOp(Op.getNode()); + if (Op->isStrictFPOpcode()) { + UnrollStrictFPOp(Op, Results); + return; + } + + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); } -SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) { +void VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op, + SmallVectorImpl &Results) { bool IsStrict = Op.getNode()->isStrictFPOpcode(); unsigned OpNo = IsStrict ? 1 : 0; SDValue Src = Op.getOperand(OpNo); @@ -1207,10 +1311,10 @@ SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) { SDValue Result; SDValue Chain; if (TLI.expandUINT_TO_FP(Op.getNode(), Result, Chain, DAG)) { + Results.push_back(Result); if (IsStrict) - // Relink the chain - DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), Chain); - return Result; + Results.push_back(Chain); + return; } // Make sure that the SINT_TO_FP and SRL instructions are available. @@ -1219,9 +1323,13 @@ SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) { (IsStrict && TLI.getOperationAction(ISD::STRICT_SINT_TO_FP, VT) == TargetLowering::Expand)) || TLI.getOperationAction(ISD::SRL, VT) == TargetLowering::Expand) { - if (IsStrict) - return UnrollStrictFPOp(Op); - return DAG.UnrollVectorOp(Op.getNode()); + if (IsStrict) { + UnrollStrictFPOp(Op, Results); + return; + } + + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); + return; } unsigned BW = VT.getScalarSizeInBits(); @@ -1261,9 +1369,9 @@ SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) { DAG.getNode(ISD::STRICT_FADD, DL, {Op.getValueType(), MVT::Other}, {SDValue(fLO.getNode(), 1), fHI, fLO}); - // Relink the chain - DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), SDValue(Result.getNode(), 1)); - return Result; + Results.push_back(Result); + Results.push_back(Result.getValue(1)); + return; } // Convert hi and lo to floats @@ -1274,7 +1382,7 @@ SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) { SDValue fLO = DAG.getNode(ISD::SINT_TO_FP, DL, Op.getValueType(), LO); // Add the two halves - return DAG.getNode(ISD::FADD, DL, Op.getValueType(), fHI, fLO); + Results.push_back(DAG.getNode(ISD::FADD, DL, Op.getValueType(), fHI, fLO)); } SDValue VectorLegalizer::ExpandFNEG(SDValue Op) { @@ -1295,7 +1403,7 @@ SDValue VectorLegalizer::ExpandFSUB(SDValue Op) { EVT VT = Op.getValueType(); if (TLI.isOperationLegalOrCustom(ISD::FNEG, VT) && TLI.isOperationLegalOrCustom(ISD::FADD, VT)) - return Op; // Defer to LegalizeDAG + return SDValue(); // Defer to LegalizeDAG return DAG.UnrollVectorOp(Op.getNode()); } @@ -1346,44 +1454,30 @@ SDValue VectorLegalizer::ExpandFMINNUM_FMAXNUM(SDValue Op) { return DAG.UnrollVectorOp(Op.getNode()); } -SDValue VectorLegalizer::ExpandUADDSUBO(SDValue Op) { +void VectorLegalizer::ExpandUADDSUBO(SDValue Op, + SmallVectorImpl &Results) { SDValue Result, Overflow; TLI.expandUADDSUBO(Op.getNode(), Result, Overflow, DAG); - - if (Op.getResNo() == 0) { - AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow)); - return Result; - } else { - AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result)); - return Overflow; - } + Results.push_back(Result); + Results.push_back(Overflow); } -SDValue VectorLegalizer::ExpandSADDSUBO(SDValue Op) { +void VectorLegalizer::ExpandSADDSUBO(SDValue Op, + SmallVectorImpl &Results) { SDValue Result, Overflow; TLI.expandSADDSUBO(Op.getNode(), Result, Overflow, DAG); - - if (Op.getResNo() == 0) { - AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow)); - return Result; - } else { - AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result)); - return Overflow; - } + Results.push_back(Result); + Results.push_back(Overflow); } -SDValue VectorLegalizer::ExpandMULO(SDValue Op) { +void VectorLegalizer::ExpandMULO(SDValue Op, + SmallVectorImpl &Results) { SDValue Result, Overflow; if (!TLI.expandMULO(Op.getNode(), Result, Overflow, DAG)) std::tie(Result, Overflow) = DAG.UnrollVectorOverflowOp(Op.getNode()); - if (Op.getResNo() == 0) { - AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow)); - return Result; - } else { - AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result)); - return Overflow; - } + Results.push_back(Result); + Results.push_back(Overflow); } SDValue VectorLegalizer::ExpandAddSubSat(SDValue Op) { @@ -1406,16 +1500,22 @@ SDValue VectorLegalizer::ExpandFixedPointDiv(SDValue Op) { return DAG.UnrollVectorOp(N); } -SDValue VectorLegalizer::ExpandStrictFPOp(SDValue Op) { - if (Op.getOpcode() == ISD::STRICT_UINT_TO_FP) - return ExpandUINT_TO_FLOAT(Op); - if (Op.getOpcode() == ISD::STRICT_FP_TO_UINT) - return ExpandFP_TO_UINT(Op); +void VectorLegalizer::ExpandStrictFPOp(SDValue Op, + SmallVectorImpl &Results) { + if (Op.getOpcode() == ISD::STRICT_UINT_TO_FP) { + ExpandUINT_TO_FLOAT(Op, Results); + return; + } + if (Op.getOpcode() == ISD::STRICT_FP_TO_UINT) { + ExpandFP_TO_UINT(Op, Results); + return; + } - return UnrollStrictFPOp(Op); + UnrollStrictFPOp(Op, Results); } -SDValue VectorLegalizer::UnrollStrictFPOp(SDValue Op) { +void VectorLegalizer::UnrollStrictFPOp(SDValue Op, + SmallVectorImpl &Results) { EVT VT = Op.getValue(0).getValueType(); EVT EltVT = VT.getVectorElementType(); unsigned NumElems = VT.getVectorNumElements(); @@ -1472,10 +1572,8 @@ SDValue VectorLegalizer::UnrollStrictFPOp(SDValue Op) { SDValue Result = DAG.getBuildVector(VT, dl, OpValues); SDValue NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OpChains); - AddLegalizedOperand(Op.getValue(0), Result); - AddLegalizedOperand(Op.getValue(1), NewChain); - - return Op.getResNo() ? NewChain : Result; + Results.push_back(Result); + Results.push_back(NewChain); } SDValue VectorLegalizer::UnrollVSETCC(SDValue Op) { diff --git a/llvm/test/CodeGen/X86/avx512-cmp.ll b/llvm/test/CodeGen/X86/avx512-cmp.ll index 3f3141e8876c..bd902dde2a26 100644 --- a/llvm/test/CodeGen/X86/avx512-cmp.ll +++ b/llvm/test/CodeGen/X86/avx512-cmp.ll @@ -181,3 +181,39 @@ if.then.i: if.end.i: ret i32 6 } + +; This test previously caused an infinite loop in legalize vector ops. Due to +; CSE triggering on the call to UpdateNodeOperands and the resulting node not +; being passed to LowerOperation. The add is needed to force the zext into a +; sext on that path. The shuffle keeps the zext alive. The xor somehow +; influences the zext to be visited before the sext exposing the CSE opportunity +; for the sext since zext of setcc is custom legalized to a sext and shift. +define <8 x i32> @legalize_loop(<8 x double> %arg) { +; KNL-LABEL: legalize_loop: +; KNL: ## %bb.0: +; KNL-NEXT: vxorpd %xmm1, %xmm1, %xmm1 +; KNL-NEXT: vcmpnltpd %zmm0, %zmm1, %k1 +; KNL-NEXT: vpternlogd $255, %zmm0, %zmm0, %zmm0 {%k1} {z} +; KNL-NEXT: vpsrld $31, %ymm0, %ymm1 +; KNL-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[3,2,1,0,7,6,5,4] +; KNL-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,0,1] +; KNL-NEXT: vpsubd %ymm0, %ymm1, %ymm0 +; KNL-NEXT: retq +; +; SKX-LABEL: legalize_loop: +; SKX: ## %bb.0: +; SKX-NEXT: vxorpd %xmm1, %xmm1, %xmm1 +; SKX-NEXT: vcmpnltpd %zmm0, %zmm1, %k0 +; SKX-NEXT: vpmovm2d %k0, %ymm0 +; SKX-NEXT: vpsrld $31, %ymm0, %ymm1 +; SKX-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[3,2,1,0,7,6,5,4] +; SKX-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,0,1] +; SKX-NEXT: vpsubd %ymm0, %ymm1, %ymm0 +; SKX-NEXT: retq + %tmp = fcmp ogt <8 x double> %arg, zeroinitializer + %tmp1 = xor <8 x i1> %tmp, + %tmp2 = zext <8 x i1> %tmp1 to <8 x i32> + %tmp3 = shufflevector <8 x i32> %tmp2, <8 x i32> undef, <8 x i32> + %tmp4 = add <8 x i32> %tmp2, %tmp3 + ret <8 x i32> %tmp4 +}