[mlir][ub] Add poison support to CommonFolders.h

Return poison from foldBinary/unary if argument(s) is poison. Add ub dialect as dependency to affected dialects (arith, math, spirv, shape).
Add poison materialization to dialects. Add tests for some ops from each dialect.
Not all affected ops are covered as it will involve a huge copypaste.

Differential Revision: https://reviews.llvm.org/D159013
This commit is contained in:
Ivan Butygin 2023-08-28 20:51:37 +02:00
parent 0dd4d3b5cc
commit 5dce74817b
15 changed files with 179 additions and 15 deletions

View File

@ -22,17 +22,35 @@
#include <optional>
namespace mlir {
namespace ub {
class PoisonAttr;
}
/// Performs constant folding `calculate` with element-wise behavior on the two
/// attributes in `operands` and returns the result if possible.
/// Uses `resultType` for the type of the returned attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
class CalculationT = function_ref<
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type resultType,
const CalculationT &calculate) {
CalculationT &&calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
static_assert(
std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
"PoisonAttr is undefined, either add a dependency on UB dialect or pass "
"void as template argument to opt-out from poison semantics.");
if constexpr (!std::is_void_v<PoisonAttr>) {
if (isa_and_nonnull<PoisonAttr>(operands[0]))
return operands[0];
if (isa_and_nonnull<PoisonAttr>(operands[1]))
return operands[1];
}
if (!resultType || !operands[0] || !operands[1])
return {};
@ -95,13 +113,28 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
/// attributes in `operands` and returns the result if possible.
/// Uses the operand element type for the element type of the returned
/// attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
class CalculationT = function_ref<
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
const CalculationT &calculate) {
CalculationT &&calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
static_assert(
std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
"PoisonAttr is undefined, either add a dependency on UB dialect or pass "
"void as template argument to opt-out from poison semantics.");
if constexpr (!std::is_void_v<PoisonAttr>) {
if (isa_and_nonnull<PoisonAttr>(operands[0]))
return operands[0];
if (isa_and_nonnull<PoisonAttr>(operands[1]))
return operands[1];
}
auto getResultType = [](Attribute attr) -> Type {
if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
return typed.getType();
@ -115,18 +148,19 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (lhsType != rhsType)
return {};
return constFoldBinaryOpConditional<AttrElementT, ElementValueT,
CalculationT>(operands, lhsType,
calculate);
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
CalculationT>(
operands, lhsType, std::forward<CalculationT>(calculate));
}
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class PoisonAttr = void,
class CalculationT =
function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
const CalculationT &calculate) {
return constFoldBinaryOpConditional<AttrElementT>(
CalculationT &&calculate) {
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
operands, resultType,
[&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
return calculate(a, b);
@ -135,11 +169,12 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
class CalculationT =
function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
const CalculationT &calculate) {
return constFoldBinaryOpConditional<AttrElementT>(
CalculationT &&calculate) {
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
operands,
[&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
return calculate(a, b);
@ -148,16 +183,28 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
/// Performs constant folding `calculate` with element-wise behavior on the one
/// attributes in `operands` and returns the result if possible.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
class CalculationT =
function_ref<std::optional<ElementValueT>(ElementValueT)>>
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
const CalculationT &&calculate) {
CalculationT &&calculate) {
assert(operands.size() == 1 && "unary op takes one operands");
if (!operands[0])
return {};
static_assert(
std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
"PoisonAttr is undefined, either add a dependency on UB dialect or pass "
"void as template argument to opt-out from poison semantics.");
if constexpr (!std::is_void_v<PoisonAttr>) {
if (isa<PoisonAttr>(operands[0]))
return operands[0];
}
if (isa<AttrElementT>(operands[0])) {
auto op = cast<AttrElementT>(operands[0]);
@ -196,10 +243,11 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
class CalculationT = function_ref<ElementValueT(ElementValueT)>>
Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
const CalculationT &&calculate) {
return constFoldUnaryOpConditional<AttrElementT>(
CalculationT &&calculate) {
return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
return calculate(a);
});
@ -209,13 +257,23 @@ template <
class AttrElementT, class TargetAttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class TargetElementValueT = typename TargetAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
const CalculationT &calculate) {
CalculationT &&calculate) {
assert(operands.size() == 1 && "Cast op takes one operand");
if (!operands[0])
return {};
static_assert(
std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
"PoisonAttr is undefined, either add a dependency on UB dialect or pass "
"void as template argument to opt-out from poison semantics.");
if constexpr (!std::is_void_v<PoisonAttr>) {
if (isa<PoisonAttr>(operands[0]))
return operands[0];
}
if (isa<AttrElementT>(operands[0])) {
auto op = cast<AttrElementT>(operands[0]);
bool castStatus = true;
@ -254,7 +312,6 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
}
return {};
}
} // namespace mlir
#endif // MLIR_DIALECT_COMMONFOLDERS_H

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
@ -49,5 +50,8 @@ void arith::ArithDialect::initialize() {
Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
return builder.create<ub::PoisonOp>(loc, type, poison);
return ConstantOp::materialize(builder, value, type, loc);
}

View File

@ -9,7 +9,6 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

View File

@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRMathDialect
MLIRArithDialect
MLIRDialect
MLIRIR
MLIRUBDialect
)

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;

View File

@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include <optional>
@ -522,5 +523,8 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
return builder.create<ub::PoisonOp>(loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);
}

View File

@ -43,4 +43,5 @@ add_mlir_dialect_library(MLIRSPIRVDialect
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransforms
MLIRUBDialect
)

View File

@ -18,6 +18,7 @@
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"

View File

@ -17,6 +17,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
@ -949,6 +950,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
return builder.create<ub::PoisonOp>(loc, type, poison);
if (!spirv::ConstantOp::isBuildableWith(type))
return nullptr;

View File

@ -23,4 +23,5 @@ add_mlir_dialect_library(MLIRShapeDialect
MLIRIR
MLIRSideEffectInterfaces
MLIRTensorDialect
MLIRUBDialect
)

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
@ -147,6 +148,9 @@ void ShapeDialect::initialize() {
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
return builder.create<ub::PoisonOp>(loc, type, poison);
if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
return builder.create<ConstShapeOp>(
loc, type, llvm::cast<DenseIntElementsAttr>(value));
@ -156,6 +160,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
if (llvm::isa<WitnessType>(type))
return builder.create<ConstWitnessOp>(loc, type,
llvm::cast<BoolAttr>(value));
return arith::ConstantOp::materialize(builder, value, type, loc);
}

View File

@ -2584,3 +2584,58 @@ func.func @selectOfPoison(%cond : i1, %arg: i32) -> (i32, i32, i32, i32) {
%select4 = arith.select %false, %poison, %arg : i32
return %select1, %select2, %select3, %select4 : i32, i32, i32, i32
}
// CHECK-LABEL: @addi_poison1
// CHECK: %[[P:.*]] = ub.poison : i32
// CHECK: return %[[P]]
func.func @addi_poison1(%arg: i32) -> i32 {
%0 = ub.poison : i32
%1 = arith.addi %0, %arg : i32
return %1 : i32
}
// CHECK-LABEL: @addi_poison2
// CHECK: %[[P:.*]] = ub.poison : i32
// CHECK: return %[[P]]
func.func @addi_poison2(%arg: i32) -> i32 {
%0 = ub.poison : i32
%1 = arith.addi %arg, %0 : i32
return %1 : i32
}
// CHECK-LABEL: @addf_poison1
// CHECK: %[[P:.*]] = ub.poison : f32
// CHECK: return %[[P]]
func.func @addf_poison1(%arg: f32) -> f32 {
%0 = ub.poison : f32
%1 = arith.addf %0, %arg : f32
return %1 : f32
}
// CHECK-LABEL: @addf_poison2
// CHECK: %[[P:.*]] = ub.poison : f32
// CHECK: return %[[P]]
func.func @addf_poison2(%arg: f32) -> f32 {
%0 = ub.poison : f32
%1 = arith.addf %arg, %0 : f32
return %1 : f32
}
// CHECK-LABEL: @negf_poison
// CHECK: %[[P:.*]] = ub.poison : f32
// CHECK: return %[[P]]
func.func @negf_poison() -> f32 {
%0 = ub.poison : f32
%1 = arith.negf %0 : f32
return %1 : f32
}
// CHECK-LABEL: @extsi_poison
// CHECK: %[[P:.*]] = ub.poison : i64
// CHECK: return %[[P]]
func.func @extsi_poison() -> i64 {
%0 = ub.poison : i32
%1 = arith.extsi %0 : i32 to i64
return %1 : i64
}

View File

@ -483,3 +483,12 @@ func.func @erf_fold_vec() -> (vector<4xf32>) {
%0 = math.erf %v1 : vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: @abs_poison
// CHECK: %[[P:.*]] = ub.poison : f32
// CHECK: return %[[P]]
func.func @abs_poison() -> f32 {
%0 = ub.poison : f32
%1 = math.absf %0 : f32
return %1 : f32
}

View File

@ -325,6 +325,15 @@ func.func @const_fold_vector_iadd() -> vector<3xi32> {
return %0: vector<3xi32>
}
// CHECK-LABEL: @iadd_poison
// CHECK: %[[P:.*]] = ub.poison : i32
// CHECK: return %[[P]]
func.func @iadd_poison(%arg0: i32) -> i32 {
%0 = ub.poison : i32
%1 = spirv.IAdd %arg0, %0 : i32
return %1: i32
}
// -----
//===----------------------------------------------------------------------===//

View File

@ -1479,3 +1479,16 @@ func.func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
// CHECK: return %[[DIM]]
return %result : index
}
// -----
// CHECK-LABEL: @add_poison
// CHECK: %[[P:.*]] = ub.poison : !shape.siz
// CHECK: return %[[P]]
func.func @add_poison() -> !shape.size {
%1 = shape.const_size 2
%2 = ub.poison : !shape.size
%result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size
return %result : !shape.size
}