[MLIR] Avoid some pointer element type accesses

Determine the element type from the MLIR LLVMPointerType, rather
than the LLVM PointerType.
This commit is contained in:
Nikita Popov 2022-03-29 18:21:39 +02:00
parent 8a72391f60
commit ea043ea183
2 changed files with 14 additions and 9 deletions

View File

@ -280,9 +280,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee"))
return builder.CreateCall(
moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
auto *calleeType = operandsRef.front()->getType();
auto *calleeFunctionType =
cast<llvm::FunctionType>(calleeType->getPointerElementType());
auto calleeType =
op.getOperands().front().getType().cast<LLVMPointerType>();
auto *calleeFunctionType = cast<llvm::FunctionType>(
moduleTranslation.convertType(calleeType.getElementType()));
return builder.CreateCall(calleeFunctionType, operandsRef.front(),
operandsRef.drop_front());
};
@ -367,9 +368,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef);
} else {
auto *calleeType = operandsRef.front()->getType();
auto *calleeFunctionType =
cast<llvm::FunctionType>(calleeType->getPointerElementType());
auto calleeType =
invOp.getCalleeOperands().front().getType().cast<LLVMPointerType>();
auto *calleeFunctionType = cast<llvm::FunctionType>(
moduleTranslation.convertType(calleeType.getElementType()));
result = builder.CreateInvoke(
calleeFunctionType, operandsRef.front(),
moduleTranslation.lookupBlock(invOp.getSuccessor(0)),

View File

@ -874,11 +874,14 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr;
if (owningAtomicReductionGens[i])
atomicGen = owningAtomicReductionGens[i];
auto reductionType =
loop.reduction_vars()[i].getType().cast<LLVM::LLVMPointerType>();
llvm::Value *variable =
moduleTranslation.lookupValue(loop.reduction_vars()[i]);
reductionInfos.push_back({variable->getType()->getPointerElementType(),
variable, privateReductionVariables[i],
owningReductionGens[i], atomicGen});
reductionInfos.push_back(
{moduleTranslation.convertType(reductionType.getElementType()),
variable, privateReductionVariables[i], owningReductionGens[i],
atomicGen});
}
// The call to createReductions below expects the block to have a