[fir] TargetRewrite: Rewrite fir.address_of(func)

Rewrite AddrOfOp if taking the address of a function.

Differential Revision: https://reviews.llvm.org/D114925

Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
Diana Picus 2021-12-02 04:27:18 +00:00
parent 867cd948ac
commit 3fd250d258
3 changed files with 83 additions and 0 deletions

View File

@ -100,6 +100,10 @@ public:
} else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
if (!hasPortableSignature(dispatch.getFunctionType()))
convertCallOp(dispatch);
} else if (auto addr = dyn_cast<AddrOfOp>(op)) {
if (addr.getType().isa<mlir::FunctionType>() &&
!hasPortableSignature(addr.getType()))
convertAddrOp(addr);
}
});
@ -319,6 +323,55 @@ public:
newInTys.push_back(std::get<mlir::Type>(tup));
}
/// Taking the address of a function. Modify the signature as needed.
void convertAddrOp(AddrOfOp addrOp) {
rewriter->setInsertionPoint(addrOp);
auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
llvm::SmallVector<mlir::Type> newResTys;
llvm::SmallVector<mlir::Type> newInTys;
for (mlir::Type ty : addrTy.getResults()) {
llvm::TypeSwitch<mlir::Type>(ty)
.Case<fir::ComplexType>([&](fir::ComplexType ty) {
lowerComplexSignatureRes(ty, newResTys, newInTys);
})
.Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
lowerComplexSignatureRes(ty, newResTys, newInTys);
})
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
}
llvm::SmallVector<mlir::Type> trailingInTys;
for (mlir::Type ty : addrTy.getInputs()) {
llvm::TypeSwitch<mlir::Type>(ty)
.Case<BoxCharType>([&](BoxCharType box) {
if (noCharacterConversion) {
newInTys.push_back(box);
} else {
for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
auto argTy = std::get<mlir::Type>(tup);
llvm::SmallVector<mlir::Type> &vec =
attr.isAppend() ? trailingInTys : newInTys;
vec.push_back(argTy);
}
}
})
.Case<fir::ComplexType>([&](fir::ComplexType ty) {
lowerComplexSignatureArg(ty, newInTys);
})
.Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
lowerComplexSignatureArg(ty, newInTys);
})
.Default([&](mlir::Type ty) { newInTys.push_back(ty); });
}
// append trailing input types
newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
// replace this op with a new one with the updated signature
auto newTy = rewriter->getFunctionType(newInTys, newResTys);
auto newOp =
rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol());
replaceOp(addrOp, newOp.getResult());
}
/// Convert the type signatures on all the functions present in the module.
/// As the type signature is being changed, this must also update the
/// function itself to use any new arguments, etc.

View File

@ -93,3 +93,13 @@ fir.global @name constant : !fir.char<1,9> {
//constant 1
fir.has_value %str : !fir.char<1,9>
}
// Test that we rewrite the fir.address_of operator
// INT32-LABEL: @addrof
// INT64-LABEL: @addrof
func @addrof() {
// INT32: {{.*}} = fir.address_of(@boxcharcallee) : (!fir.ref<!fir.char<1,?>>, i32) -> ()
// INT64: {{.*}} = fir.address_of(@boxcharcallee) : (!fir.ref<!fir.char<1,?>>, i64) -> ()
%f = fir.address_of(@boxcharcallee) : (!fir.boxchar<1>) -> ()
return
}

View File

@ -452,3 +452,23 @@ func private @mlircomplexf32(%z1: complex<f32>, %z2: complex<f32>) -> complex<f3
// PPC: return [[RES]] : tuple<f32, f32>
return %0 : complex<f32>
}
// Test that we rewrite the fir.address_of operator.
// I32-LABEL: func @addrof()
// X64-LABEL: func @addrof()
// AARCH64-LABEL: func @addrof()
// PPC-LABEL: func @addrof()
func @addrof() {
// I32: {{%.*}} = fir.address_of(@returncomplex4) : () -> i64
// X64: {{%.*}} = fir.address_of(@returncomplex4) : () -> !fir.vector<2:!fir.real<4>>
// AARCH64: {{%.*}} = fir.address_of(@returncomplex4) : () -> tuple<!fir.real<4>, !fir.real<4>>
// PPC: {{%.*}} = fir.address_of(@returncomplex4) : () -> tuple<!fir.real<4>, !fir.real<4>>
%r = fir.address_of(@returncomplex4) : () -> !fir.complex<4>
// I32: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.ref<tuple<!fir.real<4>, !fir.real<4>>>) -> ()
// X64: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.vector<2:!fir.real<4>>) -> ()
// AARCH64: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.array<2x!fir.real<4>>) -> ()
// PPC: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.real<4>, !fir.real<4>) -> ()
%p = fir.address_of(@paramcomplex4) : (!fir.complex<4>) -> ()
return
}