[mlir] add an interface to support custom types in LLVM dialect pointers

This may be necessary in partial multi-stage conversion when a container type
from dialect A containing types from dialect B goes through the conversion
where only dialect A is converted to the LLVM dialect. We will need to keep a
pointer-to-non-LLVM type in the IR until a further conversion can convert
dialect B types to LLVM types.

Reviewed By: wsmoses

Differential Revision: https://reviews.llvm.org/D106076
This commit is contained in:
Alex Zinenko 2021-07-15 18:16:07 +02:00
parent 9769535efd
commit e4b79a542e
11 changed files with 86 additions and 6 deletions

View File

@ -16,7 +16,12 @@ add_public_tablegen_target(MLIRLLVMOpsIncGen)
add_mlir_doc(LLVMOps LLVMOps Dialects/ -gen-op-doc)
add_mlir_interface(LLVMOpsInterfaces)
set(LLVM_TARGET_DEFINITIONS LLVMOpsInterfaces.td)
mlir_tablegen(LLVMOpsInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(LLVMOpsInterfaces.cpp.inc -gen-op-interface-defs)
mlir_tablegen(LLVMTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(LLVMTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRLLVMOpsInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)

View File

@ -108,10 +108,18 @@ def LLVM_OpaqueStruct : Type<
And<[LLVM_AnyStruct.predicate,
CPred<"$_self.cast<::mlir::LLVM::LLVMStructType>().isOpaque()">]>>;
// Type constraint accepting types that implement that pointer element
// interface.
def LLVM_PointerElementType : Type<
CPred<"$_self.isa<::mlir::LLVM::PointerElementTypeInterface>()">,
"LLVM-compatible pointer element type">;
// Type constraint accepting any LLVM type that can be loaded or stored, i.e. a
// type that has size (not void, function or opaque struct type).
def LLVM_LoadableType : Type<
And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>]>,
Or<[And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>]>,
LLVM_PointerElementType.predicate]>,
"LLVM type with size">;
// Type constraint accepting any LLVM aggregate type, i.e. structure or array.

View File

@ -331,7 +331,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
OptionalAttr<SymbolRefArrayAttr>:$access_groups,
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
UnitAttr:$nontemporal);
let results = (outs LLVM_Type:$res);
let results = (outs LLVM_LoadableType:$res);
string llvmBuilder = [{
auto *inst = builder.CreateLoad(
$addr->getType()->getPointerElementType(), $addr, $volatile_);

View File

@ -23,8 +23,38 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
let cppNamespace = "::mlir::LLVM";
let methods = [
InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", "fastmathFlags">,
InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags",
"fastmathFlags">,
];
}
//===----------------------------------------------------------------------===//
// LLVM dialect type interfaces.
//===----------------------------------------------------------------------===//
// An interface for LLVM pointer element types.
def LLVM_PointerElementTypeInterface
: TypeInterface<"PointerElementTypeInterface"> {
let cppNamespace = "::mlir::LLVM";
let description = [{
An interface for types that are allowed as elements of LLVM pointer type.
Such types must have a size.
}];
let methods = [
InterfaceMethod<
/*description=*/"Returns the size of the type in bytes.",
/*retTy=*/"unsigned",
/*methodName=*/"getSizeInBytes",
/*args=*/(ins "const DataLayout &":$dataLayout),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return dataLayout.getTypeSize($_type);
}]
>
];
}
#endif // LLVM_OPS_INTERFACES

View File

@ -36,6 +36,13 @@ struct LLVMPointerTypeStorage;
struct LLVMStructTypeStorage;
struct LLVMTypeAndSizeStorage;
} // namespace detail
} // namespace LLVM
} // namespace mlir
#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.h.inc"
namespace mlir {
namespace LLVM {
//===----------------------------------------------------------------------===//
// Trivial types.

View File

@ -120,8 +120,9 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
//===----------------------------------------------------------------------===//
bool LLVMPointerType::isValidElementType(Type type) {
return !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
LLVMLabelType>();
return isCompatibleType(type) ? !type.isa<LLVMVoidType, LLVMTokenType,
LLVMMetadataType, LLVMLabelType>()
: type.isa<PointerElementTypeInterface>();
}
LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
@ -607,3 +608,5 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
return llvm::TypeSize::Fixed(0);
});
}
#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"

View File

@ -176,6 +176,14 @@ func @verbose() {
return
}
// CHECK-LABEL: @ptr_elem_interface
// CHECK-COUNT-3: !llvm.ptr<!test.smpla>
func @ptr_elem_interface(%arg0: !llvm.ptr<!test.smpla>) {
%0 = llvm.load %arg0 : !llvm.ptr<!test.smpla>
llvm.store %0, %arg0 : !llvm.ptr<!test.smpla>
return
}
// -----
// Check that type aliases can be used inside LLVM dialect types. Note that

View File

@ -62,6 +62,7 @@ add_mlir_library(MLIRTestDialect
MLIRIR
MLIRInferTypeOpInterface
MLIRLinalgTransforms
MLIRLLVMIR
MLIRPass
MLIRReduce
MLIRStandard

View File

@ -13,6 +13,7 @@
#include "TestTypes.h"
#include "TestDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Types.h"
@ -222,11 +223,19 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
// TestDialect
//===----------------------------------------------------------------------===//
namespace {
struct PtrElementModel
: public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel,
SimpleAType> {};
} // namespace
void TestDialect::registerTypes() {
addTypes<TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
>();
SimpleAType::attachInterface<PtrElementModel>(*getContext());
}
static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,

View File

@ -2443,6 +2443,14 @@ gentbl_cc_library(
["-gen-op-interface-defs"],
"include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc",
),
(
["-gen-type-interface-decls"],
"include/mlir/Dialect/LLVMIR/LLVMTypeInterfaces.h.inc",
),
(
["-gen-type-interface-defs"],
"include/mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td",

View File

@ -227,6 +227,7 @@ cc_library(
"//mlir:Dialect",
"//mlir:IR",
"//mlir:InferTypeOpInterface",
"//mlir:LLVMDialect",
"//mlir:Pass",
"//mlir:Reducer",
"//mlir:SideEffects",