[flang] Moving common polymorphic code into utility files

Differential revision: https://reviews.llvm.org/D145530
This commit is contained in:
Renaud-K 2023-03-07 16:09:23 -08:00
parent da79d6e177
commit 0538bfe774
6 changed files with 87 additions and 53 deletions

View File

@ -261,6 +261,24 @@ inline fir::SequenceType unwrapUntilSeqType(mlir::Type t) {
}
}
/// Unwrap the referential and sequential outer types (if any). Returns the
/// the element if type is fir::RecordType
inline fir::RecordType unwrapIfDerived(fir::BaseBoxType boxTy) {
return fir::unwrapSequenceType(fir::unwrapRefType(boxTy.getEleTy()))
.template dyn_cast<fir::RecordType>();
}
/// Return true iff `boxTy` wraps a fir::RecordType with length parameters
inline bool isDerivedTypeWithLenParams(fir::BaseBoxType boxTy) {
auto recTy = unwrapIfDerived(boxTy);
return recTy && recTy.getNumLenParams() > 0;
}
/// Return true iff `boxTy` wraps a fir::RecordType
inline bool isDerivedType(fir::BaseBoxType boxTy) {
return static_cast<bool>(unwrapIfDerived(boxTy));
}
#ifndef NDEBUG
// !fir.ptr<X> and !fir.heap<X> where X is !fir.ptr, !fir.heap, or !fir.ref
// is undefined and disallowed.
@ -300,6 +318,13 @@ bool isPolymorphicType(mlir::Type ty);
/// value.
bool isUnlimitedPolymorphicType(mlir::Type ty);
/// Return true iff `boxTy` wraps a record type or an unlimited polymorphic
/// entity. Polymorphic entities with intrinsic type spec do not have addendum
inline bool boxHasAddendum(fir::BaseBoxType boxTy) {
return static_cast<bool>(unwrapIfDerived(boxTy)) ||
fir::isUnlimitedPolymorphicType(boxTy);
}
/// Return the inner type of the given type.
mlir::Type unwrapInnerType(mlir::Type ty);

View File

@ -15,8 +15,12 @@
#include "flang/Common/default-kinds.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
namespace fir {
/// Return the integer value of a arith::ConstantOp.
@ -24,6 +28,11 @@ inline std::int64_t toInt(mlir::arith::ConstantOp cop) {
return cop.getValue().cast<mlir::IntegerAttr>().getValue().getSExtValue();
}
// Reconstruct binding tables for dynamic dispatch.
using BindingTable = llvm::DenseMap<llvm::StringRef, unsigned>;
using BindingTables = llvm::DenseMap<llvm::StringRef, BindingTable>;
void buildBindingTables(BindingTables &, mlir::ModuleOp mod);
// Translate front-end KINDs for use in the IR and code gen.
inline std::vector<fir::KindTy>
fromDefaultKinds(const Fortran::common::IntrinsicTypeDefaultKinds &defKinds) {

View File

@ -16,8 +16,10 @@
#include "flang/ISO_Fortran_binding.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Support/TypeCode.h"
#include "flang/Optimizer/Support/Utils.h"
#include "flang/Semantics/runtime-type-info.h"
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
@ -50,9 +52,6 @@ namespace fir {
// fir::LLVMTypeConverter for converting to LLVM IR dialect types.
#include "TypeConverter.h"
using BindingTable = llvm::DenseMap<llvm::StringRef, unsigned>;
using BindingTables = llvm::DenseMap<llvm::StringRef, BindingTable>;
// TODO: This should really be recovered from the specified target.
static constexpr unsigned defaultAlign = 8;
@ -106,7 +105,7 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
public:
explicit FIROpConversion(fir::LLVMTypeConverter &lowering,
const fir::FIRToLLVMPassOptions &options,
const BindingTables &bindingTables)
const fir::BindingTables &bindingTables)
: mlir::ConvertOpToLLVMPattern<FromOp>(lowering), options(options),
bindingTables(bindingTables) {}
@ -359,7 +358,7 @@ protected:
}
const fir::FIRToLLVMPassOptions &options;
const BindingTables &bindingTables;
const fir::BindingTables &bindingTables;
};
/// FIR conversion pattern template
@ -993,7 +992,7 @@ struct DispatchOpConversion : public FIROpConversion<fir::DispatchOp> {
<< "cannot find binding table for " << recordType.getName();
// Lookup for the binding.
const BindingTable &bindingTable = bindingsIter->second;
const fir::BindingTable &bindingTable = bindingsIter->second;
auto bindingIter = bindingTable.find(dispatch.getMethod());
if (bindingIter == bindingTable.end())
return emitError(loc)
@ -1336,22 +1335,6 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
return CFI_attribute_other;
}
static fir::RecordType unwrapIfDerived(fir::BaseBoxType boxTy) {
return fir::unwrapSequenceType(fir::dyn_cast_ptrOrBoxEleTy(boxTy))
.template dyn_cast<fir::RecordType>();
}
static bool isDerivedTypeWithLenParams(fir::BaseBoxType boxTy) {
auto recTy = unwrapIfDerived(boxTy);
return recTy && recTy.getNumLenParams() > 0;
}
static bool isDerivedType(fir::BaseBoxType boxTy) {
return static_cast<bool>(unwrapIfDerived(boxTy));
}
static bool hasAddendum(fir::BaseBoxType boxTy) {
return static_cast<bool>(unwrapIfDerived(boxTy)) ||
fir::isUnlimitedPolymorphicType(boxTy);
}
// Get the element size and CFI type code of the boxed value.
std::tuple<mlir::Value, mlir::Value> getSizeAndTypeCode(
mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
@ -1571,7 +1554,7 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
descriptor =
insertField(rewriter, loc, descriptor, {kAttributePosInBox},
this->genI32Constant(loc, rewriter, getCFIAttr(boxTy)));
const bool hasAddendum = isDerivedType(boxTy) || isUnlimitedPolymorphic;
const bool hasAddendum = fir::boxHasAddendum(boxTy);
descriptor =
insertField(rewriter, loc, descriptor, {kF18AddendumPosInBox},
this->genI32Constant(loc, rewriter, hasAddendum ? 1 : 0));
@ -1591,8 +1574,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
loc, ::getVoidPtrType(mod.getContext()));
}
} else {
typeDesc =
getTypeDescriptor(mod, rewriter, loc, unwrapIfDerived(boxTy));
typeDesc = getTypeDescriptor(mod, rewriter, loc,
fir::unwrapIfDerived(boxTy));
}
}
if (typeDesc)
@ -1674,7 +1657,7 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
// TODO: For initial box that are unlimited polymorphic entities, this
// code must be made conditional because unlimited polymorphic entities
// with intrinsic type spec does not have addendum.
if (hasAddendum(inputBoxTy))
if (fir::boxHasAddendum(inputBoxTy))
typeDesc = this->loadTypeDescAddress(loc, box.getBox().getType(),
loweredBox, rewriter);
}
@ -1826,7 +1809,7 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
/*rank=*/0, /*lenParams=*/operands.drop_front(1), sourceBox,
sourceBoxType);
dest = insertBaseAddress(rewriter, embox.getLoc(), dest, operands[0]);
if (isDerivedTypeWithLenParams(boxTy)) {
if (fir::isDerivedTypeWithLenParams(boxTy)) {
TODO(embox.getLoc(),
"fir.embox codegen of derived with length parameters");
return mlir::failure();
@ -2010,7 +1993,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
fieldIndices, substringOffset);
}
dest = insertBaseAddress(rewriter, loc, dest, base);
if (isDerivedTypeWithLenParams(boxTy))
if (fir::isDerivedTypeWithLenParams(boxTy))
TODO(loc, "fir.embox codegen of derived with length parameters");
mlir::Value result =
@ -3670,7 +3653,7 @@ template <typename FromOp>
struct MustBeDeadConversion : public FIROpConversion<FromOp> {
explicit MustBeDeadConversion(fir::LLVMTypeConverter &lowering,
const fir::FIRToLLVMPassOptions &options,
const BindingTables &bindingTables)
const fir::BindingTables &bindingTables)
: FIROpConversion<FromOp>(lowering, options, bindingTables) {}
using OpAdaptor = typename FromOp::Adaptor;
@ -3781,24 +3764,8 @@ public:
if (mlir::failed(runPipeline(mathConvertionPM, mod)))
return signalPassFailure();
// Reconstruct binding tables for dynamic dispatch. The binding tables
// are defined in FIR from lowering as fir.dispatch_table operation.
// Go through each binding tables and store the procedure name
// and binding index for later use by the fir.dispatch conversion pattern.
BindingTables bindingTables;
for (auto dispatchTableOp : mod.getOps<fir::DispatchTableOp>()) {
unsigned bindingIdx = 0;
BindingTable bindings;
if (dispatchTableOp.getRegion().empty()) {
bindingTables[dispatchTableOp.getSymName()] = bindings;
continue;
}
for (auto dtEntry : dispatchTableOp.getBlock().getOps<fir::DTEntryOp>()) {
bindings[dtEntry.getMethod()] = bindingIdx;
++bindingIdx;
}
bindingTables[dispatchTableOp.getSymName()] = bindings;
}
fir::BindingTables bindingTables;
fir::buildBindingTables(bindingTables, mod);
auto *context = getModule().getContext();
fir::LLVMTypeConverter typeConverter{getModule(),

View File

@ -219,12 +219,8 @@ mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
.Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
.Case<fir::BaseBoxType>([](auto p) {
auto eleTy = p.getEleTy();
if (auto ty = fir::dyn_cast_ptrEleTy(eleTy))
return ty;
return eleTy;
})
.Case<fir::BaseBoxType>(
[](auto p) { return unwrapRefType(p.getEleTy()); })
.Default([](mlir::Type) { return mlir::Type{}; });
}

View File

@ -5,6 +5,7 @@ add_flang_library(FIRSupport
InitFIR.cpp
InternalNames.cpp
KindMapping.cpp
Utils.cpp
DEPENDS
FIROpsIncGen

View File

@ -0,0 +1,36 @@
//===-- Utils.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Support/Utils.h"
#include "flang/Optimizer/Dialect/FIROps.h"
namespace fir {
void buildBindingTables(BindingTables &bindingTables, mlir::ModuleOp mod) {
// The binding tables are defined in FIR from lowering as fir.dispatch_table
// operation. Go through each binding tables and store the procedure name and
// binding index for later use by the fir.dispatch conversion pattern.
for (auto dispatchTableOp : mod.getOps<fir::DispatchTableOp>()) {
unsigned bindingIdx = 0;
BindingTable bindings;
if (dispatchTableOp.getRegion().empty()) {
bindingTables[dispatchTableOp.getSymName()] = bindings;
continue;
}
for (auto dtEntry : dispatchTableOp.getBlock().getOps<fir::DTEntryOp>()) {
bindings[dtEntry.getMethod()] = bindingIdx;
++bindingIdx;
}
bindingTables[dispatchTableOp.getSymName()] = bindings;
}
}
} // namespace fir