[mlir][spirv] Add some folders for spv.LogicalAnd/spv.LogicalOr

This commit handles folding spv.LogicalAnd/spv.LogicalOr when
one of the operands is constant true/false.

Differential Revision: https://reviews.llvm.org/D75195
This commit is contained in:
Lei Zhang 2020-02-26 12:47:02 -05:00
parent ca50f09db9
commit 5bc6ff6455
3 changed files with 120 additions and 0 deletions
mlir
include/mlir/Dialect/SPIRV
lib/Dialect/SPIRV
test/Dialect/SPIRV

@ -526,6 +526,8 @@ def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd", SPV_Bool, [Commutative]
%2 = spv.LogicalAnd %0, %1 : vector<4xi1>
```
}];
let hasFolder = 1;
}
// -----
@ -656,6 +658,8 @@ def SPV_LogicalOrOp : SPV_LogicalBinaryOp<"LogicalOr", SPV_Bool, [Commutative]>
%2 = spv.LogicalOr %0, %1 : vector<4xi1>
```
}];
let hasFolder = 1;
}
// -----

@ -24,6 +24,26 @@ using namespace mlir;
// Common utility functions
//===----------------------------------------------------------------------===//
/// Returns true if the given `irVal` is a scalar or splat vector constant of
/// the given `boolVal`.
static bool isScalarOrSplatBoolAttr(Attribute boolAttr, bool boolVal) {
if (!boolAttr)
return false;
auto type = boolAttr.getType();
if (type.isInteger(1)) {
auto attr = boolAttr.cast<BoolAttr>();
return attr.getValue() == boolVal;
}
if (auto vecType = type.cast<VectorType>()) {
if (vecType.getElementType().isInteger(1))
if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
return attr.getSplatValue().template cast<BoolAttr>().getValue() ==
boolVal;
}
return false;
}
// Extracts an element from the given `composite` by following the given
// `indices`. Returns a null Attribute if error happens.
static Attribute extractCompositeElement(Attribute composite,
@ -187,6 +207,24 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
[](APInt a, APInt b) { return a - b; });
}
//===----------------------------------------------------------------------===//
// spv.LogicalAnd
//===----------------------------------------------------------------------===//
OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
// x && true = x
if (isScalarOrSplatBoolAttr(operands.back(), true))
return operand1();
// x && false = false
if (isScalarOrSplatBoolAttr(operands.back(), false))
return operands.back();
return Attribute();
}
//===----------------------------------------------------------------------===//
// spv.LogicalNot
//===----------------------------------------------------------------------===//
@ -198,6 +236,24 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
ConvertLogicalNotOfLogicalNotEqual>(context);
}
//===----------------------------------------------------------------------===//
// spv.LogicalOr
//===----------------------------------------------------------------------===//
OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
// x || true = true
if (isScalarOrSplatBoolAttr(operands.back(), true))
return operands.back();
// x || false = x
if (isScalarOrSplatBoolAttr(operands.back(), false))
return operand1();
return Attribute();
}
//===----------------------------------------------------------------------===//
// spv.selection
//===----------------------------------------------------------------------===//

@ -362,6 +362,36 @@ func @const_fold_vector_isub() -> vector<3xi32> {
// -----
//===----------------------------------------------------------------------===//
// spv.LogicalAnd
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @convert_logical_and_true_false_scalar
// CHECK-SAME: %[[ARG:.+]]: i1
func @convert_logical_and_true_false_scalar(%arg: i1) -> (i1, i1) {
%true = spv.constant true
// CHECK: %[[FALSE:.+]] = spv.constant false
%false = spv.constant false
%0 = spv.LogicalAnd %true, %arg: i1
%1 = spv.LogicalAnd %arg, %false: i1
// CHECK: return %[[ARG]], %[[FALSE]]
return %0, %1: i1, i1
}
// CHECK-LABEL: @convert_logical_and_true_false_vector
// CHECK-SAME: %[[ARG:.+]]: vector<3xi1>
func @convert_logical_and_true_false_vector(%arg: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>) {
%true = spv.constant dense<true> : vector<3xi1>
// CHECK: %[[FALSE:.+]] = spv.constant dense<false>
%false = spv.constant dense<false> : vector<3xi1>
%0 = spv.LogicalAnd %true, %arg: vector<3xi1>
%1 = spv.LogicalAnd %arg, %false: vector<3xi1>
// CHECK: return %[[ARG]], %[[FALSE]]
return %0, %1: vector<3xi1>, vector<3xi1>
}
// -----
//===----------------------------------------------------------------------===//
// spv.LogicalNot
//===----------------------------------------------------------------------===//
@ -419,6 +449,36 @@ func @convert_logical_not_to_logical_equal(%arg0: vector<3xi1>, %arg1: vector<3x
// -----
//===----------------------------------------------------------------------===//
// spv.LogicalOr
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @convert_logical_or_true_false_scalar
// CHECK-SAME: %[[ARG:.+]]: i1
func @convert_logical_or_true_false_scalar(%arg: i1) -> (i1, i1) {
// CHECK: %[[TRUE:.+]] = spv.constant true
%true = spv.constant true
%false = spv.constant false
%0 = spv.LogicalOr %true, %arg: i1
%1 = spv.LogicalOr %arg, %false: i1
// CHECK: return %[[TRUE]], %[[ARG]]
return %0, %1: i1, i1
}
// CHECK-LABEL: @convert_logical_or_true_false_vector
// CHECK-SAME: %[[ARG:.+]]: vector<3xi1>
func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>) {
// CHECK: %[[TRUE:.+]] = spv.constant dense<true>
%true = spv.constant dense<true> : vector<3xi1>
%false = spv.constant dense<false> : vector<3xi1>
%0 = spv.LogicalOr %true, %arg: vector<3xi1>
%1 = spv.LogicalOr %arg, %false: vector<3xi1>
// CHECK: return %[[TRUE]], %[[ARG]]
return %0, %1: vector<3xi1>, vector<3xi1>
}
// -----
//===----------------------------------------------------------------------===//
// spv.selection
//===----------------------------------------------------------------------===//