diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index 77f25ac935a8..2775f778719a 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -512,6 +512,35 @@ public: } }; +/// Converts std.trunci to spv.Select if the type of result is i1 or vector of +/// i1. +class TruncI1Pattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TruncateIOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!isBoolScalarOrVector(dstType)) + return failure(); + + Location loc = op.getLoc(); + auto srcType = operands.front().getType(); + // Check if (x & 1) == 1. + Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); + Value maskedSrc = + rewriter.create(loc, srcType, operands[0], mask); + Value isOne = rewriter.create(loc, maskedSrc, mask); + + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.replaceOpWithNewOp(op, dstType, isOne, one, zero); + return success(); + } +}; + /// Converts std.uitofp to spv.Select if the type of source is i1 or vector of /// i1. class UIToFPI1Pattern final : public OpConversionPattern { @@ -547,10 +576,10 @@ public: ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 1); auto srcType = operands.front().getType(); - if (isBoolScalarOrVector(srcType)) - return failure(); auto dstType = this->getTypeConverter()->convertType(operation.getResult().getType()); + if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) + return failure(); if (dstType == srcType) { // Due to type conversion, we are seeing the same source and target type. // Then we can just erase this operation by forwarding its operand. @@ -1178,7 +1207,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context, ReturnOpPattern, SelectOpPattern, // Type cast patterns - UIToFPI1Pattern, ZeroExtendI1Pattern, + UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index 1dc3678bfebf..6bb6d78d56af 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -744,6 +744,30 @@ func @trunci2(%arg0: i32) -> i16 { return %0 : i16 } +// CHECK-LABEL: @trunc_to_i1 +func @trunc_to_i1(%arg0: i32) -> i1 { + // CHECK: %[[MASK:.*]] = spv.constant 1 : i32 + // CHECK: %[[MASKED_SRC:.*]] = spv.BitwiseAnd %{{.*}}, %[[MASK]] : i32 + // CHECK: %[[IS_ONE:.*]] = spv.IEqual %[[MASKED_SRC]], %[[MASK]] : i32 + // CHECK-DAG: %[[TRUE:.*]] = spv.constant true + // CHECK-DAG: %[[FALSE:.*]] = spv.constant false + // CHECK: spv.Select %[[IS_ONE]], %[[TRUE]], %[[FALSE]] : i1, i1 + %0 = std.trunci %arg0 : i32 to i1 + return %0 : i1 +} + +// CHECK-LABEL: @trunc_to_veci1 +func @trunc_to_veci1(%arg0: vector<4xi32>) -> vector<4xi1> { + // CHECK: %[[MASK:.*]] = spv.constant dense<1> : vector<4xi32> + // CHECK: %[[MASKED_SRC:.*]] = spv.BitwiseAnd %{{.*}}, %[[MASK]] : vector<4xi32> + // CHECK: %[[IS_ONE:.*]] = spv.IEqual %[[MASKED_SRC]], %[[MASK]] : vector<4xi32> + // CHECK-DAG: %[[TRUE:.*]] = spv.constant dense : vector<4xi1> + // CHECK-DAG: %[[FALSE:.*]] = spv.constant dense : vector<4xi1> + // CHECK: spv.Select %[[IS_ONE]], %[[TRUE]], %[[FALSE]] : vector<4xi1>, vector<4xi1> + %0 = std.trunci %arg0 : vector<4xi32> to vector<4xi1> + return %0 : vector<4xi1> +} + // CHECK-LABEL: @fptosi1 func @fptosi1(%arg0 : f32) -> i32 { // CHECK: spv.ConvertFToS %{{.*}} : f32 to i32