[mlir][StandardToSPIRV] Add support for lowering trunci to SPIR-V to i1 types.

Add a pattern to converting some value to a boolean. spirv.S/UConvert does not
work on i1 types. Thus, the pattern is lowered to cmpi + select.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D96851
This commit is contained in:
Hanhan Wang 2021-02-17 06:55:10 -08:00
parent 8bcc03767e
commit c80484e16e
2 changed files with 56 additions and 3 deletions

View File

@ -512,6 +512,35 @@ public:
}
};
/// Converts std.trunci to spv.Select if the type of result is i1 or vector of
/// i1.
class TruncI1Pattern final : public OpConversionPattern<TruncateIOp> {
public:
using OpConversionPattern<TruncateIOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(TruncateIOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dstType =
this->getTypeConverter()->convertType(op.getResult().getType());
if (!isBoolScalarOrVector(dstType))
return failure();
Location loc = op.getLoc();
auto srcType = operands.front().getType();
// Check if (x & 1) == 1.
Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
Value maskedSrc =
rewriter.create<spirv::BitwiseAndOp>(loc, srcType, operands[0], mask);
Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
return success();
}
};
/// Converts std.uitofp to spv.Select if the type of source is i1 or vector of
/// i1.
class UIToFPI1Pattern final : public OpConversionPattern<UIToFPOp> {
@ -547,10 +576,10 @@ public:
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 1);
auto srcType = operands.front().getType();
if (isBoolScalarOrVector(srcType))
return failure();
auto dstType =
this->getTypeConverter()->convertType(operation.getResult().getType());
if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
return failure();
if (dstType == srcType) {
// Due to type conversion, we are seeing the same source and target type.
// Then we can just erase this operation by forwarding its operand.
@ -1178,7 +1207,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
ReturnOpPattern, SelectOpPattern,
// Type cast patterns
UIToFPI1Pattern, ZeroExtendI1Pattern,
UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern,
TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<UIToFPOp, spirv::ConvertUToFOp>,

View File

@ -744,6 +744,30 @@ func @trunci2(%arg0: i32) -> i16 {
return %0 : i16
}
// CHECK-LABEL: @trunc_to_i1
func @trunc_to_i1(%arg0: i32) -> i1 {
// CHECK: %[[MASK:.*]] = spv.constant 1 : i32
// CHECK: %[[MASKED_SRC:.*]] = spv.BitwiseAnd %{{.*}}, %[[MASK]] : i32
// CHECK: %[[IS_ONE:.*]] = spv.IEqual %[[MASKED_SRC]], %[[MASK]] : i32
// CHECK-DAG: %[[TRUE:.*]] = spv.constant true
// CHECK-DAG: %[[FALSE:.*]] = spv.constant false
// CHECK: spv.Select %[[IS_ONE]], %[[TRUE]], %[[FALSE]] : i1, i1
%0 = std.trunci %arg0 : i32 to i1
return %0 : i1
}
// CHECK-LABEL: @trunc_to_veci1
func @trunc_to_veci1(%arg0: vector<4xi32>) -> vector<4xi1> {
// CHECK: %[[MASK:.*]] = spv.constant dense<1> : vector<4xi32>
// CHECK: %[[MASKED_SRC:.*]] = spv.BitwiseAnd %{{.*}}, %[[MASK]] : vector<4xi32>
// CHECK: %[[IS_ONE:.*]] = spv.IEqual %[[MASKED_SRC]], %[[MASK]] : vector<4xi32>
// CHECK-DAG: %[[TRUE:.*]] = spv.constant dense<true> : vector<4xi1>
// CHECK-DAG: %[[FALSE:.*]] = spv.constant dense<false> : vector<4xi1>
// CHECK: spv.Select %[[IS_ONE]], %[[TRUE]], %[[FALSE]] : vector<4xi1>, vector<4xi1>
%0 = std.trunci %arg0 : vector<4xi32> to vector<4xi1>
return %0 : vector<4xi1>
}
// CHECK-LABEL: @fptosi1
func @fptosi1(%arg0 : f32) -> i32 {
// CHECK: spv.ConvertFToS %{{.*}} : f32 to i32