mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-19 01:09:39 +00:00
[mlir][spirv] Add some folders for spv.LogicalAnd/spv.LogicalOr
This commit handles folding spv.LogicalAnd/spv.LogicalOr when one of the operands is constant true/false. Differential Revision: https://reviews.llvm.org/D75195
This commit is contained in:
parent
ca50f09db9
commit
5bc6ff6455
mlir
include/mlir/Dialect/SPIRV
lib/Dialect/SPIRV
test/Dialect/SPIRV
@ -526,6 +526,8 @@ def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd", SPV_Bool, [Commutative]
|
||||
%2 = spv.LogicalAnd %0, %1 : vector<4xi1>
|
||||
```
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -656,6 +658,8 @@ def SPV_LogicalOrOp : SPV_LogicalBinaryOp<"LogicalOr", SPV_Bool, [Commutative]>
|
||||
%2 = spv.LogicalOr %0, %1 : vector<4xi1>
|
||||
```
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -24,6 +24,26 @@ using namespace mlir;
|
||||
// Common utility functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns true if the given `irVal` is a scalar or splat vector constant of
|
||||
/// the given `boolVal`.
|
||||
static bool isScalarOrSplatBoolAttr(Attribute boolAttr, bool boolVal) {
|
||||
if (!boolAttr)
|
||||
return false;
|
||||
|
||||
auto type = boolAttr.getType();
|
||||
if (type.isInteger(1)) {
|
||||
auto attr = boolAttr.cast<BoolAttr>();
|
||||
return attr.getValue() == boolVal;
|
||||
}
|
||||
if (auto vecType = type.cast<VectorType>()) {
|
||||
if (vecType.getElementType().isInteger(1))
|
||||
if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
|
||||
return attr.getSplatValue().template cast<BoolAttr>().getValue() ==
|
||||
boolVal;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Extracts an element from the given `composite` by following the given
|
||||
// `indices`. Returns a null Attribute if error happens.
|
||||
static Attribute extractCompositeElement(Attribute composite,
|
||||
@ -187,6 +207,24 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
|
||||
[](APInt a, APInt b) { return a - b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.LogicalAnd
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
|
||||
|
||||
// x && true = x
|
||||
if (isScalarOrSplatBoolAttr(operands.back(), true))
|
||||
return operand1();
|
||||
|
||||
// x && false = false
|
||||
if (isScalarOrSplatBoolAttr(operands.back(), false))
|
||||
return operands.back();
|
||||
|
||||
return Attribute();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.LogicalNot
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -198,6 +236,24 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
|
||||
ConvertLogicalNotOfLogicalNotEqual>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.LogicalOr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
|
||||
|
||||
// x || true = true
|
||||
if (isScalarOrSplatBoolAttr(operands.back(), true))
|
||||
return operands.back();
|
||||
|
||||
// x || false = x
|
||||
if (isScalarOrSplatBoolAttr(operands.back(), false))
|
||||
return operand1();
|
||||
|
||||
return Attribute();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.selection
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -362,6 +362,36 @@ func @const_fold_vector_isub() -> vector<3xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.LogicalAnd
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: @convert_logical_and_true_false_scalar
|
||||
// CHECK-SAME: %[[ARG:.+]]: i1
|
||||
func @convert_logical_and_true_false_scalar(%arg: i1) -> (i1, i1) {
|
||||
%true = spv.constant true
|
||||
// CHECK: %[[FALSE:.+]] = spv.constant false
|
||||
%false = spv.constant false
|
||||
%0 = spv.LogicalAnd %true, %arg: i1
|
||||
%1 = spv.LogicalAnd %arg, %false: i1
|
||||
// CHECK: return %[[ARG]], %[[FALSE]]
|
||||
return %0, %1: i1, i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @convert_logical_and_true_false_vector
|
||||
// CHECK-SAME: %[[ARG:.+]]: vector<3xi1>
|
||||
func @convert_logical_and_true_false_vector(%arg: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>) {
|
||||
%true = spv.constant dense<true> : vector<3xi1>
|
||||
// CHECK: %[[FALSE:.+]] = spv.constant dense<false>
|
||||
%false = spv.constant dense<false> : vector<3xi1>
|
||||
%0 = spv.LogicalAnd %true, %arg: vector<3xi1>
|
||||
%1 = spv.LogicalAnd %arg, %false: vector<3xi1>
|
||||
// CHECK: return %[[ARG]], %[[FALSE]]
|
||||
return %0, %1: vector<3xi1>, vector<3xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.LogicalNot
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -419,6 +449,36 @@ func @convert_logical_not_to_logical_equal(%arg0: vector<3xi1>, %arg1: vector<3x
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.LogicalOr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: @convert_logical_or_true_false_scalar
|
||||
// CHECK-SAME: %[[ARG:.+]]: i1
|
||||
func @convert_logical_or_true_false_scalar(%arg: i1) -> (i1, i1) {
|
||||
// CHECK: %[[TRUE:.+]] = spv.constant true
|
||||
%true = spv.constant true
|
||||
%false = spv.constant false
|
||||
%0 = spv.LogicalOr %true, %arg: i1
|
||||
%1 = spv.LogicalOr %arg, %false: i1
|
||||
// CHECK: return %[[TRUE]], %[[ARG]]
|
||||
return %0, %1: i1, i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @convert_logical_or_true_false_vector
|
||||
// CHECK-SAME: %[[ARG:.+]]: vector<3xi1>
|
||||
func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>) {
|
||||
// CHECK: %[[TRUE:.+]] = spv.constant dense<true>
|
||||
%true = spv.constant dense<true> : vector<3xi1>
|
||||
%false = spv.constant dense<false> : vector<3xi1>
|
||||
%0 = spv.LogicalOr %true, %arg: vector<3xi1>
|
||||
%1 = spv.LogicalOr %arg, %false: vector<3xi1>
|
||||
// CHECK: return %[[TRUE]], %[[ARG]]
|
||||
return %0, %1: vector<3xi1>, vector<3xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.selection
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Loading…
x
Reference in New Issue
Block a user