[Flang] Add the FIR LLVMPointer Type

Add a fir.llvm_ptr type to allow any level of indirections

Currently, fir pointer types (fir.ref, fir.ptr, and fir.heap) carry
a special Fortran semantics, and cannot be freely combined/nested.

When implementing some features, lowering sometimes needs more liberty
regarding the number of indirection levels. Add a fir.llvm_ptr that has
no constraints.

Allow its usage in fir.coordinate_op, fir.load, and fir.store.

Convert the FIR LLVMPointer to an LLVMPointer in the LLVM dialect.

Reviewed By: clementval

Differential Revision: https://reviews.llvm.org/D113755

Co-authored-by: Jean Perier <jperier@nvidia.com>
This commit is contained in:
Kiran Chandramohan 2021-11-13 18:02:16 +00:00
parent 0e738323a9
commit 49c08a22ed
9 changed files with 77 additions and 21 deletions

View File

@ -1661,7 +1661,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
TypeAttr:$baseType
);
let results = (outs fir_ReferenceType);
let results = (outs RefOrLLVMPtr);
let parser = "return parseCoordinateCustom(parser, result);";
let printer = "::print(p, *this);";

View File

@ -59,7 +59,8 @@ bool isa_fir_or_std_type(mlir::Type t);
/// Is `t` a FIR dialect type that implies a memory (de)reference?
inline bool isa_ref_type(mlir::Type t) {
return t.isa<ReferenceType>() || t.isa<PointerType>() || t.isa<HeapType>();
return t.isa<ReferenceType>() || t.isa<PointerType>() || t.isa<HeapType>() ||
t.isa<fir::LLVMPointerType>();
}
/// Is `t` a boxed type?

View File

@ -224,6 +224,22 @@ def fir_LogicalType : FIR_Type<"Logical", "logical"> {
}];
}
def fir_LLVMPointerType : FIR_Type<"LLVMPointer", "llvm_ptr"> {
let summary = "Like LLVM pointer type";
let description = [{
A pointer type that does not have any of the constraints and semantics
of other FIR pointer types and that translates to llvm pointer types.
It is meant to implement indirection that cannot be expressed directly
in Fortran, but are needed to implement some Fortran features (e.g,
double indirections).
}];
let parameters = (ins "mlir::Type":$eleTy);
let assemblyFormat = "`<` $eleTy `>`";
}
def fir_PointerType : FIR_Type<"Pointer", "ptr"> {
let summary = "Reference to a POINTER attribute type";
@ -516,7 +532,11 @@ def AnyCompositeLike : TypeConstraint<Or<[fir_RecordType.predicate,
// Reference types
def AnyReferenceLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
fir_HeapType.predicate, fir_PointerType.predicate]>, "any reference">;
fir_HeapType.predicate, fir_PointerType.predicate,
fir_LLVMPointerType.predicate]>, "any reference">;
def RefOrLLVMPtr : TypeConstraint<Or<[fir_ReferenceType.predicate,
fir_LLVMPointerType.predicate]>, "fir.ref or fir.llvm_ptr">;
def AnyBoxLike : TypeConstraint<Or<[fir_BoxType.predicate,
fir_BoxCharType.predicate, fir_BoxProcType.predicate]>, "any box">;

View File

@ -65,6 +65,9 @@ public:
return mlir::IntegerType::get(
&getContext(), kindMapping.getLogicalBitsize(boolTy.getFKind()));
});
addConversion([&](fir::LLVMPointerType pointer) {
return convertPointerLike(pointer);
});
addConversion(
[&](fir::PointerType pointer) { return convertPointerLike(pointer); });
addConversion(

View File

@ -824,8 +824,9 @@ bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) {
bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) {
return ty.isa<fir::ReferenceType>() || ty.isa<fir::PointerType>() ||
ty.isa<fir::HeapType>() || ty.isa<mlir::MemRefType>() ||
ty.isa<mlir::FunctionType>() || ty.isa<fir::TypeDescType>();
ty.isa<fir::HeapType>() || ty.isa<fir::LLVMPointerType>() ||
ty.isa<mlir::MemRefType>() || ty.isa<mlir::FunctionType>() ||
ty.isa<fir::TypeDescType>();
}
static mlir::LogicalResult verify(fir::ConvertOp &op) {
@ -1755,16 +1756,8 @@ void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
result.addTypes(eleTy);
}
/// Get the element type of a reference like type; otherwise null
static mlir::Type elementTypeOf(mlir::Type ref) {
return llvm::TypeSwitch<mlir::Type, mlir::Type>(ref)
.Case<ReferenceType, PointerType, HeapType>(
[](auto type) { return type.getEleTy(); })
.Default([](mlir::Type) { return mlir::Type{}; });
}
mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) {
if ((ele = elementTypeOf(ref)))
if ((ele = fir::dyn_cast_ptrEleTy(ref)))
return mlir::success();
return mlir::failure();
}

View File

@ -200,15 +200,15 @@ bool isa_fir_or_std_type(mlir::Type t) {
mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
.Case<fir::ReferenceType, fir::PointerType, fir::HeapType>(
[](auto p) { return p.getEleTy(); })
.Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
.Default([](mlir::Type) { return mlir::Type{}; });
}
mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
.Case<fir::ReferenceType, fir::PointerType, fir::HeapType>(
[](auto p) { return p.getEleTy(); })
.Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
.Case<fir::BoxType>([](auto p) {
auto eleTy = p.getEleTy();
if (auto ty = fir::dyn_cast_ptrEleTy(eleTy))
@ -864,7 +864,7 @@ bool fir::VectorType::isValidElementType(mlir::Type t) {
void FIROpsDialect::registerTypes() {
addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
PointerType, RealType, RecordType, ReferenceType, SequenceType,
ShapeType, ShapeShiftType, ShiftType, SliceType, TypeDescType,
fir::VectorType>();
LLVMPointerType, PointerType, RealType, RecordType, ReferenceType,
SequenceType, ShapeType, ShapeShiftType, ShiftType, SliceType,
TypeDescType, fir::VectorType>();
}

View File

@ -707,3 +707,18 @@ func @slice_substr() {
// CHECK: fir.slice %{{.*}}, %{{.*}}, %{{.*}} substr %{{.*}}, %{{.*}} : (index, index, index, index, index) -> !fir.slice<1>
return
}
// Test load, store, coordinate_of with llvmptr type
// CHECK-LABEL: llvm_ptr_load_store_coordinate
// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref<tuple<!fir.ref<!fir.box<!fir.ptr<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>>>, %[[ARG1:.*]]: !fir.ref<!fir.box<!fir.ptr<f32>>>)
func @llvm_ptr_load_store_coordinate(%arg0: !fir.ref<tuple<!fir.ref<!fir.box<!fir.ptr<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>>>, %arg1: !fir.ref<!fir.box<!fir.ptr<f32>>>) -> !fir.ref<!fir.box<!fir.ptr<f32>>> {
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : i32
%c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %[[LLVMPTR:.*]] = fir.coordinate_of %[[ARG0]], %[[C0]] : (!fir.ref<tuple<!fir.ref<!fir.box<!fir.ptr<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>>>, i32) -> !fir.llvm_ptr<!fir.ref<!fir.box<!fir.ptr<f32>>>>
%0 = fir.coordinate_of %arg0, %c0_i32 : (!fir.ref<tuple<!fir.ref<!fir.box<!fir.ptr<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>>>, i32) -> !fir.llvm_ptr<!fir.ref<!fir.box<!fir.ptr<f32>>>>
// CHECK-NEXT: fir.store %[[ARG1]] to %[[LLVMPTR]] : !fir.llvm_ptr<!fir.ref<!fir.box<!fir.ptr<f32>>>>
fir.store %arg1 to %0 : !fir.llvm_ptr<!fir.ref<!fir.box<!fir.ptr<f32>>>>
// CHECK-NEXT: fir.load %[[LLVMPTR]] : !fir.llvm_ptr<!fir.ref<!fir.box<!fir.ptr<f32>>>>
%1 = fir.load %0 : !fir.llvm_ptr<!fir.ref<!fir.box<!fir.ptr<f32>>>>
return %1 : !fir.ref<!fir.box<!fir.ptr<f32>>>
}

View File

@ -55,10 +55,14 @@ func private @arr7() -> !fir.array<1x2x?x4x5x6x7x8x9xf32>
// CHECK-LABEL: func private @mem2() -> !fir.ptr<i32>
// CHECK-LABEL: func private @mem3() -> !fir.heap<i32>
// CHECK-LABEL: func private @mem4() -> !fir.ref<() -> ()>
// CHECK-LABEL: func private @mem5() -> !fir.llvm_ptr<!fir.ref<f32>>
// CHECK-LABEL: func private @mem6() -> !fir.llvm_ptr<i8>
func private @mem1() -> !fir.ref<i32>
func private @mem2() -> !fir.ptr<i32>
func private @mem3() -> !fir.heap<i32>
func private @mem4() -> !fir.ref<() -> ()>
func private @mem5() -> !fir.llvm_ptr<!fir.ref<f32>>
func private @mem6() -> !fir.llvm_ptr<i8>
// FIR box types (descriptors)
// CHECK-LABEL: func private @box1() -> !fir.box<!fir.array<?xf32>>

View File

@ -176,6 +176,26 @@ func private @foo4(%arg0: !fir.logical<16>)
// -----
// Test `!fir.llvm_ptr` conversion.
func private @foo0(%arg0: !fir.llvm_ptr<i8>)
// CHECK-LABEL: foo0
// CHECK-SAME: !llvm.ptr<i8>
func private @foo1(%arg0: !fir.llvm_ptr<!fir.ref<f32>>)
// CHECK-LABEL: foo1
// CHECK-SAME: !llvm.ptr<ptr<f32>>
func private @foo2(%arg0: !fir.llvm_ptr<!fir.ref<!fir.box<!fir.ptr<i32>>>>)
// CHECK-LABEL: foo2
// CHECK-SAME: !llvm.ptr<ptr<struct<(ptr<i32>, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}})>>>
func private @foo3(%arg0: !fir.llvm_ptr<!fir.ptr<f32>>)
// CHECK-LABEL: foo3
// CHECK-SAME: !llvm.ptr<ptr<f32>>
// -----
// Test `!fir.complex<KIND>` conversion.
func private @foo0(%arg0: !fir.complex<2>)