From 5dce74817b71a1f646fb2857c037b3a66f41c7cd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 28 Aug 2023 20:51:37 +0200 Subject: [PATCH] [mlir][ub] Add poison support to CommonFolders.h Return poison from foldBinary/unary if argument(s) is poison. Add ub dialect as dependency to affected dialects (arith, math, spirv, shape). Add poison materialization to dialects. Add tests for some ops from each dialect. Not all affected ops are covered as it will involve a huge copypaste. Differential Revision: https://reviews.llvm.org/D159013 --- mlir/include/mlir/Dialect/CommonFolders.h | 85 ++++++++++++++++--- mlir/lib/Dialect/Arith/IR/ArithDialect.cpp | 4 + .../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 1 - mlir/lib/Dialect/Math/IR/CMakeLists.txt | 1 + mlir/lib/Dialect/Math/IR/MathDialect.cpp | 1 + mlir/lib/Dialect/Math/IR/MathOps.cpp | 4 + mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt | 1 + .../SPIRV/IR/SPIRVCanonicalization.cpp | 1 + mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 4 + mlir/lib/Dialect/Shape/IR/CMakeLists.txt | 1 + mlir/lib/Dialect/Shape/IR/Shape.cpp | 5 ++ mlir/test/Dialect/Arith/canonicalize.mlir | 55 ++++++++++++ mlir/test/Dialect/Math/canonicalize.mlir | 9 ++ .../SPIRV/Transforms/canonicalize.mlir | 9 ++ mlir/test/Dialect/Shape/canonicalize.mlir | 13 +++ 15 files changed, 179 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h index d3fbc723632a..6257e4a60188 100644 --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -22,17 +22,35 @@ #include namespace mlir { +namespace ub { +class PoisonAttr; +} /// Performs constant folding `calculate` with element-wise behavior on the two /// attributes in `operands` and returns the result if possible. /// Uses `resultType` for the type of the returned attribute. +/// Optional PoisonAttr template argument allows to specify 'poison' attribute +/// which will be directly propagated to result. template (ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, Type resultType, - const CalculationT &calculate) { + CalculationT &&calculate) { assert(operands.size() == 2 && "binary op takes two operands"); + static_assert( + std::is_void_v || !llvm::is_incomplete_v, + "PoisonAttr is undefined, either add a dependency on UB dialect or pass " + "void as template argument to opt-out from poison semantics."); + if constexpr (!std::is_void_v) { + if (isa_and_nonnull(operands[0])) + return operands[0]; + + if (isa_and_nonnull(operands[1])) + return operands[1]; + } + if (!resultType || !operands[0] || !operands[1]) return {}; @@ -95,13 +113,28 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, /// attributes in `operands` and returns the result if possible. /// Uses the operand element type for the element type of the returned /// attribute. +/// Optional PoisonAttr template argument allows to specify 'poison' attribute +/// which will be directly propagated to result. template (ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, - const CalculationT &calculate) { + CalculationT &&calculate) { assert(operands.size() == 2 && "binary op takes two operands"); + static_assert( + std::is_void_v || !llvm::is_incomplete_v, + "PoisonAttr is undefined, either add a dependency on UB dialect or pass " + "void as template argument to opt-out from poison semantics."); + if constexpr (!std::is_void_v) { + if (isa_and_nonnull(operands[0])) + return operands[0]; + + if (isa_and_nonnull(operands[1])) + return operands[1]; + } + auto getResultType = [](Attribute attr) -> Type { if (auto typed = dyn_cast_or_null(attr)) return typed.getType(); @@ -115,18 +148,19 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, if (lhsType != rhsType) return {}; - return constFoldBinaryOpConditional(operands, lhsType, - calculate); + return constFoldBinaryOpConditional( + operands, lhsType, std::forward(calculate)); } template > Attribute constFoldBinaryOp(ArrayRef operands, Type resultType, - const CalculationT &calculate) { - return constFoldBinaryOpConditional( + CalculationT &&calculate) { + return constFoldBinaryOpConditional( operands, resultType, [&](ElementValueT a, ElementValueT b) -> std::optional { return calculate(a, b); @@ -135,11 +169,12 @@ Attribute constFoldBinaryOp(ArrayRef operands, Type resultType, template > Attribute constFoldBinaryOp(ArrayRef operands, - const CalculationT &calculate) { - return constFoldBinaryOpConditional( + CalculationT &&calculate) { + return constFoldBinaryOpConditional( operands, [&](ElementValueT a, ElementValueT b) -> std::optional { return calculate(a, b); @@ -148,16 +183,28 @@ Attribute constFoldBinaryOp(ArrayRef operands, /// Performs constant folding `calculate` with element-wise behavior on the one /// attributes in `operands` and returns the result if possible. +/// Optional PoisonAttr template argument allows to specify 'poison' attribute +/// which will be directly propagated to result. template (ElementValueT)>> Attribute constFoldUnaryOpConditional(ArrayRef operands, - const CalculationT &&calculate) { + CalculationT &&calculate) { assert(operands.size() == 1 && "unary op takes one operands"); if (!operands[0]) return {}; + static_assert( + std::is_void_v || !llvm::is_incomplete_v, + "PoisonAttr is undefined, either add a dependency on UB dialect or pass " + "void as template argument to opt-out from poison semantics."); + if constexpr (!std::is_void_v) { + if (isa(operands[0])) + return operands[0]; + } + if (isa(operands[0])) { auto op = cast(operands[0]); @@ -196,10 +243,11 @@ Attribute constFoldUnaryOpConditional(ArrayRef operands, template > Attribute constFoldUnaryOp(ArrayRef operands, - const CalculationT &&calculate) { - return constFoldUnaryOpConditional( + CalculationT &&calculate) { + return constFoldUnaryOpConditional( operands, [&](ElementValueT a) -> std::optional { return calculate(a); }); @@ -209,13 +257,23 @@ template < class AttrElementT, class TargetAttrElementT, class ElementValueT = typename AttrElementT::ValueType, class TargetElementValueT = typename TargetAttrElementT::ValueType, + class PoisonAttr = ub::PoisonAttr, class CalculationT = function_ref> Attribute constFoldCastOp(ArrayRef operands, Type resType, - const CalculationT &calculate) { + CalculationT &&calculate) { assert(operands.size() == 1 && "Cast op takes one operand"); if (!operands[0]) return {}; + static_assert( + std::is_void_v || !llvm::is_incomplete_v, + "PoisonAttr is undefined, either add a dependency on UB dialect or pass " + "void as template argument to opt-out from poison semantics."); + if constexpr (!std::is_void_v) { + if (isa(operands[0])) + return operands[0]; + } + if (isa(operands[0])) { auto op = cast(operands[0]); bool castStatus = true; @@ -254,7 +312,6 @@ Attribute constFoldCastOp(ArrayRef operands, Type resType, } return {}; } - } // namespace mlir #endif // MLIR_DIALECT_COMMONFOLDERS_H diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp index 7f2d79355fe0..ed4b91cbe516 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" @@ -49,5 +50,8 @@ void arith::ArithDialect::initialize() { Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + return ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index 0a86d8f15b0d..fab6f3416999 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -9,7 +9,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/CommonFolders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" diff --git a/mlir/lib/Dialect/Math/IR/CMakeLists.txt b/mlir/lib/Dialect/Math/IR/CMakeLists.txt index 3b7b65e58143..ed95bf846cde 100644 --- a/mlir/lib/Dialect/Math/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/IR/CMakeLists.txt @@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRMathDialect MLIRArithDialect MLIRDialect MLIRIR + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/Math/IR/MathDialect.cpp b/mlir/lib/Dialect/Math/IR/MathDialect.cpp index 54a8cc1d697b..9cf47ac71306 100644 --- a/mlir/lib/Dialect/Math/IR/MathDialect.cpp +++ b/mlir/lib/Dialect/Math/IR/MathDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Transforms/InliningUtils.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index ae9dc08c745b..28d1c062f235 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include @@ -522,5 +523,8 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) { Operation *math::MathDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt index 0189e79ea12f..2b5cedafae1e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -43,4 +43,5 @@ add_mlir_dialect_library(MLIRSPIRVDialect MLIRSideEffectInterfaces MLIRSupport MLIRTransforms + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index def62b4467ce..9acd982dc95a 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 76e703946428..a51d77dda78b 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -949,6 +950,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + if (!spirv::ConstantOp::isBuildableWith(type)) return nullptr; diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt index ba41f1aec8d9..32a86b483a49 100644 --- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt @@ -23,4 +23,5 @@ add_mlir_dialect_library(MLIRShapeDialect MLIRIR MLIRSideEffectInterfaces MLIRTensorDialect + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index e4efa0931677..2444556a4563 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -147,6 +148,9 @@ void ShapeDialect::initialize() { Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + if (llvm::isa(type) || isExtentTensorType(type)) return builder.create( loc, type, llvm::cast(value)); @@ -156,6 +160,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder, if (llvm::isa(type)) return builder.create(loc, type, llvm::cast(value)); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 0c8e0974b017..347b6346b786 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2584,3 +2584,58 @@ func.func @selectOfPoison(%cond : i1, %arg: i32) -> (i32, i32, i32, i32) { %select4 = arith.select %false, %poison, %arg : i32 return %select1, %select2, %select3, %select4 : i32, i32, i32, i32 } + +// CHECK-LABEL: @addi_poison1 +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @addi_poison1(%arg: i32) -> i32 { + %0 = ub.poison : i32 + %1 = arith.addi %0, %arg : i32 + return %1 : i32 +} + +// CHECK-LABEL: @addi_poison2 +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @addi_poison2(%arg: i32) -> i32 { + %0 = ub.poison : i32 + %1 = arith.addi %arg, %0 : i32 + return %1 : i32 +} + +// CHECK-LABEL: @addf_poison1 +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @addf_poison1(%arg: f32) -> f32 { + %0 = ub.poison : f32 + %1 = arith.addf %0, %arg : f32 + return %1 : f32 +} + +// CHECK-LABEL: @addf_poison2 +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @addf_poison2(%arg: f32) -> f32 { + %0 = ub.poison : f32 + %1 = arith.addf %arg, %0 : f32 + return %1 : f32 +} + + +// CHECK-LABEL: @negf_poison +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @negf_poison() -> f32 { + %0 = ub.poison : f32 + %1 = arith.negf %0 : f32 + return %1 : f32 +} + +// CHECK-LABEL: @extsi_poison +// CHECK: %[[P:.*]] = ub.poison : i64 +// CHECK: return %[[P]] +func.func @extsi_poison() -> i64 { + %0 = ub.poison : i32 + %1 = arith.extsi %0 : i32 to i64 + return %1 : i64 +} diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir index 7a5194b89a5c..d24f7649269f 100644 --- a/mlir/test/Dialect/Math/canonicalize.mlir +++ b/mlir/test/Dialect/Math/canonicalize.mlir @@ -483,3 +483,12 @@ func.func @erf_fold_vec() -> (vector<4xf32>) { %0 = math.erf %v1 : vector<4xf32> return %0 : vector<4xf32> } + +// CHECK-LABEL: @abs_poison +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @abs_poison() -> f32 { + %0 = ub.poison : f32 + %1 = math.absf %0 : f32 + return %1 : f32 +} diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir index 52607d726785..0200805a4443 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -325,6 +325,15 @@ func.func @const_fold_vector_iadd() -> vector<3xi32> { return %0: vector<3xi32> } +// CHECK-LABEL: @iadd_poison +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @iadd_poison(%arg0: i32) -> i32 { + %0 = ub.poison : i32 + %1 = spirv.IAdd %arg0, %0 : i32 + return %1: i32 +} + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index aec5f3202c9b..8edbae3baf52 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1479,3 +1479,16 @@ func.func @extract_shapeof(%arg0 : tensor) -> index { // CHECK: return %[[DIM]] return %result : index } + + +// ----- + +// CHECK-LABEL: @add_poison +// CHECK: %[[P:.*]] = ub.poison : !shape.siz +// CHECK: return %[[P]] +func.func @add_poison() -> !shape.size { + %1 = shape.const_size 2 + %2 = ub.poison : !shape.size + %result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size + return %result : !shape.size +}