mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-12 21:28:48 +00:00
[mlir][ub] Add poison support to CommonFolders.h
Return poison from foldBinary/unary if argument(s) is poison. Add ub dialect as dependency to affected dialects (arith, math, spirv, shape). Add poison materialization to dialects. Add tests for some ops from each dialect. Not all affected ops are covered as it will involve a huge copypaste. Differential Revision: https://reviews.llvm.org/D159013
This commit is contained in:
parent
0dd4d3b5cc
commit
5dce74817b
@ -22,17 +22,35 @@
|
||||
#include <optional>
|
||||
|
||||
namespace mlir {
|
||||
namespace ub {
|
||||
class PoisonAttr;
|
||||
}
|
||||
/// Performs constant folding `calculate` with element-wise behavior on the two
|
||||
/// attributes in `operands` and returns the result if possible.
|
||||
/// Uses `resultType` for the type of the returned attribute.
|
||||
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
|
||||
/// which will be directly propagated to result.
|
||||
template <class AttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class PoisonAttr = ub::PoisonAttr,
|
||||
class CalculationT = function_ref<
|
||||
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
|
||||
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
|
||||
Type resultType,
|
||||
const CalculationT &calculate) {
|
||||
CalculationT &&calculate) {
|
||||
assert(operands.size() == 2 && "binary op takes two operands");
|
||||
static_assert(
|
||||
std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
|
||||
"PoisonAttr is undefined, either add a dependency on UB dialect or pass "
|
||||
"void as template argument to opt-out from poison semantics.");
|
||||
if constexpr (!std::is_void_v<PoisonAttr>) {
|
||||
if (isa_and_nonnull<PoisonAttr>(operands[0]))
|
||||
return operands[0];
|
||||
|
||||
if (isa_and_nonnull<PoisonAttr>(operands[1]))
|
||||
return operands[1];
|
||||
}
|
||||
|
||||
if (!resultType || !operands[0] || !operands[1])
|
||||
return {};
|
||||
|
||||
@ -95,13 +113,28 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
|
||||
/// attributes in `operands` and returns the result if possible.
|
||||
/// Uses the operand element type for the element type of the returned
|
||||
/// attribute.
|
||||
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
|
||||
/// which will be directly propagated to result.
|
||||
template <class AttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class PoisonAttr = ub::PoisonAttr,
|
||||
class CalculationT = function_ref<
|
||||
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
|
||||
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
|
||||
const CalculationT &calculate) {
|
||||
CalculationT &&calculate) {
|
||||
assert(operands.size() == 2 && "binary op takes two operands");
|
||||
static_assert(
|
||||
std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
|
||||
"PoisonAttr is undefined, either add a dependency on UB dialect or pass "
|
||||
"void as template argument to opt-out from poison semantics.");
|
||||
if constexpr (!std::is_void_v<PoisonAttr>) {
|
||||
if (isa_and_nonnull<PoisonAttr>(operands[0]))
|
||||
return operands[0];
|
||||
|
||||
if (isa_and_nonnull<PoisonAttr>(operands[1]))
|
||||
return operands[1];
|
||||
}
|
||||
|
||||
auto getResultType = [](Attribute attr) -> Type {
|
||||
if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
|
||||
return typed.getType();
|
||||
@ -115,18 +148,19 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
|
||||
if (lhsType != rhsType)
|
||||
return {};
|
||||
|
||||
return constFoldBinaryOpConditional<AttrElementT, ElementValueT,
|
||||
CalculationT>(operands, lhsType,
|
||||
calculate);
|
||||
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
|
||||
CalculationT>(
|
||||
operands, lhsType, std::forward<CalculationT>(calculate));
|
||||
}
|
||||
|
||||
template <class AttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class PoisonAttr = void,
|
||||
class CalculationT =
|
||||
function_ref<ElementValueT(ElementValueT, ElementValueT)>>
|
||||
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
|
||||
const CalculationT &calculate) {
|
||||
return constFoldBinaryOpConditional<AttrElementT>(
|
||||
CalculationT &&calculate) {
|
||||
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
|
||||
operands, resultType,
|
||||
[&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
|
||||
return calculate(a, b);
|
||||
@ -135,11 +169,12 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
|
||||
|
||||
template <class AttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class PoisonAttr = ub::PoisonAttr,
|
||||
class CalculationT =
|
||||
function_ref<ElementValueT(ElementValueT, ElementValueT)>>
|
||||
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
|
||||
const CalculationT &calculate) {
|
||||
return constFoldBinaryOpConditional<AttrElementT>(
|
||||
CalculationT &&calculate) {
|
||||
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
|
||||
operands,
|
||||
[&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
|
||||
return calculate(a, b);
|
||||
@ -148,16 +183,28 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
|
||||
|
||||
/// Performs constant folding `calculate` with element-wise behavior on the one
|
||||
/// attributes in `operands` and returns the result if possible.
|
||||
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
|
||||
/// which will be directly propagated to result.
|
||||
template <class AttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class PoisonAttr = ub::PoisonAttr,
|
||||
class CalculationT =
|
||||
function_ref<std::optional<ElementValueT>(ElementValueT)>>
|
||||
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
|
||||
const CalculationT &&calculate) {
|
||||
CalculationT &&calculate) {
|
||||
assert(operands.size() == 1 && "unary op takes one operands");
|
||||
if (!operands[0])
|
||||
return {};
|
||||
|
||||
static_assert(
|
||||
std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
|
||||
"PoisonAttr is undefined, either add a dependency on UB dialect or pass "
|
||||
"void as template argument to opt-out from poison semantics.");
|
||||
if constexpr (!std::is_void_v<PoisonAttr>) {
|
||||
if (isa<PoisonAttr>(operands[0]))
|
||||
return operands[0];
|
||||
}
|
||||
|
||||
if (isa<AttrElementT>(operands[0])) {
|
||||
auto op = cast<AttrElementT>(operands[0]);
|
||||
|
||||
@ -196,10 +243,11 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
|
||||
|
||||
template <class AttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class PoisonAttr = ub::PoisonAttr,
|
||||
class CalculationT = function_ref<ElementValueT(ElementValueT)>>
|
||||
Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
|
||||
const CalculationT &&calculate) {
|
||||
return constFoldUnaryOpConditional<AttrElementT>(
|
||||
CalculationT &&calculate) {
|
||||
return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
|
||||
operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
|
||||
return calculate(a);
|
||||
});
|
||||
@ -209,13 +257,23 @@ template <
|
||||
class AttrElementT, class TargetAttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class TargetElementValueT = typename TargetAttrElementT::ValueType,
|
||||
class PoisonAttr = ub::PoisonAttr,
|
||||
class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
|
||||
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
|
||||
const CalculationT &calculate) {
|
||||
CalculationT &&calculate) {
|
||||
assert(operands.size() == 1 && "Cast op takes one operand");
|
||||
if (!operands[0])
|
||||
return {};
|
||||
|
||||
static_assert(
|
||||
std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
|
||||
"PoisonAttr is undefined, either add a dependency on UB dialect or pass "
|
||||
"void as template argument to opt-out from poison semantics.");
|
||||
if constexpr (!std::is_void_v<PoisonAttr>) {
|
||||
if (isa<PoisonAttr>(operands[0]))
|
||||
return operands[0];
|
||||
}
|
||||
|
||||
if (isa<AttrElementT>(operands[0])) {
|
||||
auto op = cast<AttrElementT>(operands[0]);
|
||||
bool castStatus = true;
|
||||
@ -254,7 +312,6 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_COMMONFOLDERS_H
|
||||
|
@ -7,6 +7,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/UB/IR/UBOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
@ -49,5 +50,8 @@ void arith::ArithDialect::initialize() {
|
||||
Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
|
||||
return builder.create<ub::PoisonOp>(loc, type, poison);
|
||||
|
||||
return ConstantOp::materialize(builder, value, type, loc);
|
||||
}
|
||||
|
@ -9,7 +9,6 @@
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/CommonFolders.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRMathDialect
|
||||
MLIRArithDialect
|
||||
MLIRDialect
|
||||
MLIRIR
|
||||
MLIRUBDialect
|
||||
)
|
||||
|
@ -7,6 +7,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/UB/IR/UBOps.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/CommonFolders.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/UB/IR/UBOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include <optional>
|
||||
|
||||
@ -522,5 +523,8 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
|
||||
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
|
||||
return builder.create<ub::PoisonOp>(loc, type, poison);
|
||||
|
||||
return arith::ConstantOp::materialize(builder, value, type, loc);
|
||||
}
|
||||
|
@ -43,4 +43,5 @@ add_mlir_dialect_library(MLIRSPIRVDialect
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRSupport
|
||||
MLIRTransforms
|
||||
MLIRUBDialect
|
||||
)
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include "mlir/Dialect/CommonFolders.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
||||
#include "mlir/Dialect/UB/IR/UBOps.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
|
||||
#include "mlir/Dialect/UB/IR/UBOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
@ -949,6 +950,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
|
||||
return builder.create<ub::PoisonOp>(loc, type, poison);
|
||||
|
||||
if (!spirv::ConstantOp::isBuildableWith(type))
|
||||
return nullptr;
|
||||
|
||||
|
@ -23,4 +23,5 @@ add_mlir_dialect_library(MLIRShapeDialect
|
||||
MLIRIR
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRTensorDialect
|
||||
MLIRUBDialect
|
||||
)
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include "mlir/Dialect/CommonFolders.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/Dialect/UB/IR/UBOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
@ -147,6 +148,9 @@ void ShapeDialect::initialize() {
|
||||
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
|
||||
return builder.create<ub::PoisonOp>(loc, type, poison);
|
||||
|
||||
if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
|
||||
return builder.create<ConstShapeOp>(
|
||||
loc, type, llvm::cast<DenseIntElementsAttr>(value));
|
||||
@ -156,6 +160,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
|
||||
if (llvm::isa<WitnessType>(type))
|
||||
return builder.create<ConstWitnessOp>(loc, type,
|
||||
llvm::cast<BoolAttr>(value));
|
||||
|
||||
return arith::ConstantOp::materialize(builder, value, type, loc);
|
||||
}
|
||||
|
||||
|
@ -2584,3 +2584,58 @@ func.func @selectOfPoison(%cond : i1, %arg: i32) -> (i32, i32, i32, i32) {
|
||||
%select4 = arith.select %false, %poison, %arg : i32
|
||||
return %select1, %select2, %select3, %select4 : i32, i32, i32, i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @addi_poison1
|
||||
// CHECK: %[[P:.*]] = ub.poison : i32
|
||||
// CHECK: return %[[P]]
|
||||
func.func @addi_poison1(%arg: i32) -> i32 {
|
||||
%0 = ub.poison : i32
|
||||
%1 = arith.addi %0, %arg : i32
|
||||
return %1 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @addi_poison2
|
||||
// CHECK: %[[P:.*]] = ub.poison : i32
|
||||
// CHECK: return %[[P]]
|
||||
func.func @addi_poison2(%arg: i32) -> i32 {
|
||||
%0 = ub.poison : i32
|
||||
%1 = arith.addi %arg, %0 : i32
|
||||
return %1 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @addf_poison1
|
||||
// CHECK: %[[P:.*]] = ub.poison : f32
|
||||
// CHECK: return %[[P]]
|
||||
func.func @addf_poison1(%arg: f32) -> f32 {
|
||||
%0 = ub.poison : f32
|
||||
%1 = arith.addf %0, %arg : f32
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @addf_poison2
|
||||
// CHECK: %[[P:.*]] = ub.poison : f32
|
||||
// CHECK: return %[[P]]
|
||||
func.func @addf_poison2(%arg: f32) -> f32 {
|
||||
%0 = ub.poison : f32
|
||||
%1 = arith.addf %arg, %0 : f32
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: @negf_poison
|
||||
// CHECK: %[[P:.*]] = ub.poison : f32
|
||||
// CHECK: return %[[P]]
|
||||
func.func @negf_poison() -> f32 {
|
||||
%0 = ub.poison : f32
|
||||
%1 = arith.negf %0 : f32
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @extsi_poison
|
||||
// CHECK: %[[P:.*]] = ub.poison : i64
|
||||
// CHECK: return %[[P]]
|
||||
func.func @extsi_poison() -> i64 {
|
||||
%0 = ub.poison : i32
|
||||
%1 = arith.extsi %0 : i32 to i64
|
||||
return %1 : i64
|
||||
}
|
||||
|
@ -483,3 +483,12 @@ func.func @erf_fold_vec() -> (vector<4xf32>) {
|
||||
%0 = math.erf %v1 : vector<4xf32>
|
||||
return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @abs_poison
|
||||
// CHECK: %[[P:.*]] = ub.poison : f32
|
||||
// CHECK: return %[[P]]
|
||||
func.func @abs_poison() -> f32 {
|
||||
%0 = ub.poison : f32
|
||||
%1 = math.absf %0 : f32
|
||||
return %1 : f32
|
||||
}
|
||||
|
@ -325,6 +325,15 @@ func.func @const_fold_vector_iadd() -> vector<3xi32> {
|
||||
return %0: vector<3xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @iadd_poison
|
||||
// CHECK: %[[P:.*]] = ub.poison : i32
|
||||
// CHECK: return %[[P]]
|
||||
func.func @iadd_poison(%arg0: i32) -> i32 {
|
||||
%0 = ub.poison : i32
|
||||
%1 = spirv.IAdd %arg0, %0 : i32
|
||||
return %1: i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1479,3 +1479,16 @@ func.func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
|
||||
// CHECK: return %[[DIM]]
|
||||
return %result : index
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @add_poison
|
||||
// CHECK: %[[P:.*]] = ub.poison : !shape.siz
|
||||
// CHECK: return %[[P]]
|
||||
func.func @add_poison() -> !shape.size {
|
||||
%1 = shape.const_size 2
|
||||
%2 = ub.poison : !shape.size
|
||||
%result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size
|
||||
return %result : !shape.size
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user