mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-14 20:22:30 +00:00
[MLIR][SPIRVToLLVM] Implementation of spv.func conversion, and return ops
This patch provides an implementation for `spv.func` conversion. The pattern is populated in a separate method added to the pass. At the moment, the type signature conversion only includes the supported types. The conversion pattern also matches SPIR-V function control attributes to LLVM function attributes. Those are modelled as `passthrough` attributes in LLVM dialect. The following mapping are used: - None: no attributes passed - Inline: `alwaysinline` seems to be the right equivalent (`inlinehint` is semantically weaker in my opinion) - DontInline: `noinline` - Pure and Const: I think those can be modelled as `readonly` and `readnone` attributes respectively. Also, 2 patterns added for return ops conversion (`spv.Return` for void return and `spv.ReturnValue` for a single value return). Differential Revision: https://reviews.llvm.org/D81931
This commit is contained in:
parent
adf7973fd3
commit
a4dc61344f
@ -37,6 +37,12 @@ void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
|
|||||||
LLVMTypeConverter &typeConverter,
|
LLVMTypeConverter &typeConverter,
|
||||||
OwningRewritePatternList &patterns);
|
OwningRewritePatternList &patterns);
|
||||||
|
|
||||||
|
/// Populates the given list with patterns for function conversion from SPIR-V
|
||||||
|
/// to LLVM.
|
||||||
|
void populateSPIRVToLLVMFunctionConversionPatterns(
|
||||||
|
MLIRContext *context, LLVMTypeConverter &typeConverter,
|
||||||
|
OwningRewritePatternList &patterns);
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H
|
#endif // MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H
|
||||||
|
@ -21,6 +21,9 @@
|
|||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "spirv-to-llvm-pattern"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@ -150,6 +153,32 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
|
||||||
|
public:
|
||||||
|
using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
|
||||||
|
ArrayRef<Value>());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
|
||||||
|
public:
|
||||||
|
using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
|
||||||
|
operands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
|
/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
|
||||||
/// puts a restriction on `Shift` and `Base` to have the same bit width,
|
/// puts a restriction on `Shift` and `Base` to have the same bit width,
|
||||||
/// `Shift` is zero or sign extended to match this specification. Cases when
|
/// `Shift` is zero or sign extended to match this specification. Cases when
|
||||||
@ -191,6 +220,64 @@ public:
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// FuncOp conversion
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
|
||||||
|
public:
|
||||||
|
using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
|
// Convert function signature. At the moment LLVMType converter is enough
|
||||||
|
// for currently supported types.
|
||||||
|
auto funcType = funcOp.getType();
|
||||||
|
TypeConverter::SignatureConversion signatureConverter(
|
||||||
|
funcType.getNumInputs());
|
||||||
|
auto llvmType = this->typeConverter.convertFunctionSignature(
|
||||||
|
funcOp.getType(), /*isVariadic=*/false, signatureConverter);
|
||||||
|
|
||||||
|
// Create a new `LLVMFuncOp`
|
||||||
|
Location loc = funcOp.getLoc();
|
||||||
|
StringRef name = funcOp.getName();
|
||||||
|
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
|
||||||
|
|
||||||
|
// Convert SPIR-V Function Control to equivalent LLVM function attribute
|
||||||
|
MLIRContext *context = funcOp.getContext();
|
||||||
|
switch (funcOp.function_control()) {
|
||||||
|
#define DISPATCH(functionControl, llvmAttr) \
|
||||||
|
case functionControl: \
|
||||||
|
newFuncOp.setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \
|
||||||
|
break;
|
||||||
|
|
||||||
|
DISPATCH(spirv::FunctionControl::Inline,
|
||||||
|
StringAttr::get("alwaysinline", context));
|
||||||
|
DISPATCH(spirv::FunctionControl::DontInline,
|
||||||
|
StringAttr::get("noinline", context));
|
||||||
|
DISPATCH(spirv::FunctionControl::Pure,
|
||||||
|
StringAttr::get("readonly", context));
|
||||||
|
DISPATCH(spirv::FunctionControl::Const,
|
||||||
|
StringAttr::get("readnone", context));
|
||||||
|
|
||||||
|
#undef DISPATCH
|
||||||
|
|
||||||
|
// Default: if `spirv::FunctionControl::None`, then no attributes are
|
||||||
|
// needed.
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||||
|
newFuncOp.end());
|
||||||
|
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
|
||||||
|
rewriter.eraseOp(funcOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -263,6 +350,14 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
|
|||||||
// Shift ops
|
// Shift ops
|
||||||
ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
|
ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
|
||||||
ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
|
ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
|
||||||
ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>>(context,
|
ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
|
||||||
typeConverter);
|
|
||||||
|
// Return ops
|
||||||
|
ReturnPattern, ReturnValuePattern>(context, typeConverter);
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
|
||||||
|
MLIRContext *context, LLVMTypeConverter &typeConverter,
|
||||||
|
OwningRewritePatternList &patterns) {
|
||||||
|
patterns.insert<FuncConversionPattern>(context, typeConverter);
|
||||||
}
|
}
|
||||||
|
@ -35,6 +35,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
|
|||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
|
populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
|
||||||
|
populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
|
||||||
|
|
||||||
// Currently pulls in Std to LLVM conversion patterns
|
// Currently pulls in Std to LLVM conversion patterns
|
||||||
// that help with testing. This allows to convert
|
// that help with testing. This allows to convert
|
||||||
|
62
mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir
Normal file
62
mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.Return
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @return() {
|
||||||
|
// CHECK: llvm.return
|
||||||
|
spv.Return
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.ReturnValue
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @return_value(%arg: i32) {
|
||||||
|
// CHECK: llvm.return %{{.*}} : !llvm.i32
|
||||||
|
spv.ReturnValue %arg : i32
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.func
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @none()
|
||||||
|
spv.func @none() -> () "None" {
|
||||||
|
spv.Return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @inline() attributes {passthrough = ["alwaysinline"]}
|
||||||
|
spv.func @inline() -> () "Inline" {
|
||||||
|
spv.Return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @dont_inline() attributes {passthrough = ["noinline"]}
|
||||||
|
spv.func @dont_inline() -> () "DontInline" {
|
||||||
|
spv.Return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @pure() attributes {passthrough = ["readonly"]}
|
||||||
|
spv.func @pure() -> () "Pure" {
|
||||||
|
spv.Return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @const() attributes {passthrough = ["readnone"]}
|
||||||
|
spv.func @const() -> () "Const" {
|
||||||
|
spv.Return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @scalar_types(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm.double, %arg3: !llvm.float)
|
||||||
|
spv.func @scalar_types(%arg0: i32, %arg1: i1, %arg2: f64, %arg3: f32) -> () "None" {
|
||||||
|
spv.Return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @vector_types(%arg0: !llvm<"<2 x i64>">, %arg1: !llvm<"<2 x i64>">) -> !llvm<"<2 x i64>">
|
||||||
|
spv.func @vector_types(%arg0: vector<2xi64>, %arg1: vector<2xi64>) -> vector<2xi64> "None" {
|
||||||
|
%0 = spv.IAdd %arg0, %arg1 : vector<2xi64>
|
||||||
|
spv.ReturnValue %0 : vector<2xi64>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user