[NVPTX] Handle addrspacecast constant expressions in aggregate initializers

We need to track if an AddrSpaceCast expression was seen when
generating an MCExpr for a ConstantExpr.  This change introduces a
custom lowerConstant method to the NVPTX asm printer that will create
NVPTXGenericMCSymbolRefExpr nodes at the appropriate places to encode
the information that a given symbol needs to be casted to a generic
address.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@236000 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Justin Holewinski 2015-04-28 17:18:30 +00:00
parent e48ac32ea2
commit 0292a66bb1
5 changed files with 270 additions and 2 deletions

View File

@ -1983,6 +1983,212 @@ bool NVPTXAsmPrinter::ignoreLoc(const MachineInstr &MI) {
return false;
}
/// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
/// a copy from AsmPrinter::lowerConstant, except customized to only handle
/// expressions that are representable in PTX and create
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
const MCExpr *
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
MCContext &Ctx = OutContext;
if (CV->isNullValue() || isa<UndefValue>(CV))
return MCConstantExpr::Create(0, Ctx);
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
return MCConstantExpr::Create(CI->getZExtValue(), Ctx);
if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
const MCSymbolRefExpr *Expr =
MCSymbolRefExpr::Create(getSymbol(GV), Ctx);
if (ProcessingGeneric) {
return NVPTXGenericMCSymbolRefExpr::Create(Expr, Ctx);
} else {
return Expr;
}
}
const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
if (!CE) {
llvm_unreachable("Unknown constant value to lower!");
}
switch (CE->getOpcode()) {
default:
// If the code isn't optimized, there may be outstanding folding
// opportunities. Attempt to fold the expression using DataLayout as a
// last resort before giving up.
if (Constant *C = ConstantFoldConstantExpression(CE, *TM.getDataLayout()))
if (C != CE)
return lowerConstantForGV(C, ProcessingGeneric);
// Otherwise report the problem to the user.
{
std::string S;
raw_string_ostream OS(S);
OS << "Unsupported expression in static initializer: ";
CE->printAsOperand(OS, /*PrintType=*/false,
!MF ? nullptr : MF->getFunction()->getParent());
report_fatal_error(OS.str());
}
case Instruction::AddrSpaceCast: {
// Strip the addrspacecast and pass along the operand
PointerType *DstTy = cast<PointerType>(CE->getType());
if (DstTy->getAddressSpace() == 0) {
return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
}
std::string S;
raw_string_ostream OS(S);
OS << "Unsupported expression in static initializer: ";
CE->printAsOperand(OS, /*PrintType=*/ false,
!MF ? 0 : MF->getFunction()->getParent());
report_fatal_error(OS.str());
}
case Instruction::GetElementPtr: {
const DataLayout &DL = *TM.getDataLayout();
// Generate a symbolic expression for the byte address
APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
ProcessingGeneric);
if (!OffsetAI)
return Base;
int64_t Offset = OffsetAI.getSExtValue();
return MCBinaryExpr::CreateAdd(Base, MCConstantExpr::Create(Offset, Ctx),
Ctx);
}
case Instruction::Trunc:
// We emit the value and depend on the assembler to truncate the generated
// expression properly. This is important for differences between
// blockaddress labels. Since the two labels are in the same function, it
// is reasonable to treat their delta as a 32-bit value.
// FALL THROUGH.
case Instruction::BitCast:
return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
case Instruction::IntToPtr: {
const DataLayout &DL = *TM.getDataLayout();
// Handle casts to pointers by changing them into casts to the appropriate
// integer type. This promotes constant folding and simplifies this code.
Constant *Op = CE->getOperand(0);
Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
false/*ZExt*/);
return lowerConstantForGV(Op, ProcessingGeneric);
}
case Instruction::PtrToInt: {
const DataLayout &DL = *TM.getDataLayout();
// Support only foldable casts to/from pointers that can be eliminated by
// changing the pointer to the appropriately sized integer type.
Constant *Op = CE->getOperand(0);
Type *Ty = CE->getType();
const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
// We can emit the pointer value into this slot if the slot is an
// integer slot equal to the size of the pointer.
if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
return OpExpr;
// Otherwise the pointer is smaller than the resultant integer, mask off
// the high bits so we are sure to get a proper truncation if the input is
// a constant expr.
unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
const MCExpr *MaskExpr = MCConstantExpr::Create(~0ULL >> (64-InBits), Ctx);
return MCBinaryExpr::CreateAnd(OpExpr, MaskExpr, Ctx);
}
// The MC library also has a right-shift operator, but it isn't consistently
// signed or unsigned between different targets.
case Instruction::Add: {
const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
switch (CE->getOpcode()) {
default: llvm_unreachable("Unknown binary operator constant cast expr");
case Instruction::Add: return MCBinaryExpr::CreateAdd(LHS, RHS, Ctx);
}
}
}
}
// Copy of MCExpr::print customized for NVPTX
void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
switch (Expr.getKind()) {
case MCExpr::Target:
return cast<MCTargetExpr>(&Expr)->PrintImpl(OS);
case MCExpr::Constant:
OS << cast<MCConstantExpr>(Expr).getValue();
return;
case MCExpr::SymbolRef: {
const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
const MCSymbol &Sym = SRE.getSymbol();
OS << Sym;
return;
}
case MCExpr::Unary: {
const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
switch (UE.getOpcode()) {
case MCUnaryExpr::LNot: OS << '!'; break;
case MCUnaryExpr::Minus: OS << '-'; break;
case MCUnaryExpr::Not: OS << '~'; break;
case MCUnaryExpr::Plus: OS << '+'; break;
}
printMCExpr(*UE.getSubExpr(), OS);
return;
}
case MCExpr::Binary: {
const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
// Only print parens around the LHS if it is non-trivial.
if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
printMCExpr(*BE.getLHS(), OS);
} else {
OS << '(';
printMCExpr(*BE.getLHS(), OS);
OS<< ')';
}
switch (BE.getOpcode()) {
case MCBinaryExpr::Add:
// Print "X-42" instead of "X+-42".
if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
if (RHSC->getValue() < 0) {
OS << RHSC->getValue();
return;
}
}
OS << '+';
break;
default: llvm_unreachable("Unhandled binary operator");
}
// Only print parens around the LHS if it is non-trivial.
if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
printMCExpr(*BE.getRHS(), OS);
} else {
OS << '(';
printMCExpr(*BE.getRHS(), OS);
OS << ')';
}
return;
}
}
llvm_unreachable("Invalid expression kind!");
}
/// PrintAsmOperand - Print out an operand for an inline asm expression.
///
bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,

View File

@ -169,8 +169,10 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
} else {
O << *Name;
}
} else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(v)) {
O << *AP.lowerConstant(Cexpr);
} else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
const MCExpr *Expr =
AP.lowerConstantForGV(cast<Constant>(CExpr), false);
AP.printMCExpr(*Expr, O);
} else
llvm_unreachable("symbol type unknown");
nSym++;
@ -241,6 +243,10 @@ private:
bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNo,
unsigned AsmVariant, const char *ExtraCode,
raw_ostream &) override;
const MCExpr *lowerConstantForGV(const Constant *CV, bool ProcessingGeneric);
void printMCExpr(const MCExpr &Expr, raw_ostream &OS);
protected:
bool doInitialization(Module &M) override;
bool doFinalization(Module &M) override;

View File

@ -45,3 +45,13 @@ void NVPTXFloatMCExpr::PrintImpl(raw_ostream &OS) const {
OS << std::string(NumHex - HexStr.length(), '0');
OS << utohexstr(API.getZExtValue());
}
const NVPTXGenericMCSymbolRefExpr*
NVPTXGenericMCSymbolRefExpr::Create(const MCSymbolRefExpr *SymExpr,
MCContext &Ctx) {
return new (Ctx) NVPTXGenericMCSymbolRefExpr(SymExpr);
}
void NVPTXGenericMCSymbolRefExpr::PrintImpl(raw_ostream &OS) const {
OS << "generic(" << *SymExpr << ")";
}

View File

@ -79,6 +79,50 @@ public:
return E->getKind() == MCExpr::Target;
}
};
/// A wrapper for MCSymbolRefExpr that tells the assembly printer that the
/// symbol should be enclosed by generic().
class NVPTXGenericMCSymbolRefExpr : public MCTargetExpr {
private:
const MCSymbolRefExpr *SymExpr;
explicit NVPTXGenericMCSymbolRefExpr(const MCSymbolRefExpr *_SymExpr)
: SymExpr(_SymExpr) {}
public:
/// @name Construction
/// @{
static const NVPTXGenericMCSymbolRefExpr
*Create(const MCSymbolRefExpr *SymExpr, MCContext &Ctx);
/// @}
/// @name Accessors
/// @{
/// getOpcode - Get the kind of this expression.
const MCSymbolRefExpr *getSymbolExpr() const { return SymExpr; }
/// @}
void PrintImpl(raw_ostream &OS) const;
bool EvaluateAsRelocatableImpl(MCValue &Res,
const MCAsmLayout *Layout,
const MCFixup *Fixup) const override {
return false;
}
void visitUsedExpr(MCStreamer &Streamer) const override {};
const MCSection *FindAssociatedSection() const override {
return nullptr;
}
// There are no TLS NVPTXMCExprs at the moment.
void fixELFSymbolsInTLSFixups(MCAssembler &Asm) const override {}
static bool classof(const MCExpr *E) {
return E->getKind() == MCExpr::Target;
}
};
} // end namespace llvm
#endif

View File

@ -4,8 +4,10 @@
; CHECK: .visible .global .align 4 .u32 g2 = generic(g);
; CHECK: .visible .global .align 4 .u32 g3 = g;
; CHECK: .visible .global .align 8 .u32 g4[2] = {0, generic(g)};
; CHECK: .visible .global .align 8 .u32 g5[2] = {0, generic(g)+8};
@g = addrspace(1) global i32 42
@g2 = addrspace(1) global i32* addrspacecast (i32 addrspace(1)* @g to i32*)
@g3 = addrspace(1) global i32 addrspace(1)* @g
@g4 = constant {i32*, i32*} {i32* null, i32* addrspacecast (i32 addrspace(1)* @g to i32*)}
@g5 = constant {i32*, i32*} {i32* null, i32* addrspacecast (i32 addrspace(1)* getelementptr (i32, i32 addrspace(1)* @g, i32 2) to i32*)}