[mlir][math] Expand math.round to truncate, compare and increment.

Round functions are pushed directly to libm. This is problematic for
situations where libm is not available. This patch decomposes the
roundf function by adding 0.5 to a positive input (subtracting 0.5 for
a negative one), followed by a truncate.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D148026
This commit is contained in:
Balaji V. Iyer 2023-04-13 17:58:14 +00:00 committed by Robert Suderman
parent 43c42d6d7a
commit be9115788c
5 changed files with 95 additions and 2 deletions

View File

@ -20,6 +20,7 @@ void populateExpandFmaFPattern(RewritePatternSet &patterns);
void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateExpandCeilFPattern(RewritePatternSet &patterns);
void populateExpandExp2FPattern(RewritePatternSet &patterns);
void populateExpandRoundFPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {

View File

@ -174,6 +174,28 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
return success();
}
// Expands math.round into arith operations so that no libm call is needed.
// The operand is biased by +0.5 when it compares ordered-greater-or-equal to
// zero and by -0.5 otherwise, and the biased value is then truncated via
// createTruncatedFPValue. The original op is replaced by the truncated value
// and the rewrite always reports success.
// NOTE(review): the truncation helper appears to round-trip through an
// integer type; operands outside that integer's range are presumably not
// handled -- confirm against createTruncatedFPValue.
static LogicalResult convertRoundOp(math::RoundOp op,
                                    PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);
  Value operand = op.getOperand();
  Type opType = operand.getType();

  // Constants: the comparison threshold and the two signed biases.
  Value zero = createFloatConst(loc, opType, 0.00, rewriter);
  Value half = createFloatConst(loc, opType, 0.5, rewriter);
  Value negHalf = createFloatConst(loc, opType, -0.5, rewriter);

  // Choose the bias from the sign of the operand. `b` already carries the
  // location, so create<>() is not given an explicit one.
  Value posCheck =
      b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, zero);
  Value incrValue = b.create<arith::SelectOp>(posCheck, half, negHalf);
  Value add = b.create<arith::AddFOp>(opType, operand, incrValue);

  // Truncate the biased value and use it as the result of the round.
  Value fpFixedConvert = createTruncatedFPValue(add, b);
  rewriter.replaceOp(op, fpFixedConvert);
  return success();
}
// Converts math.ctlz to scf and arith operations. This is done
// by performing a binary search on the bits.
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@ -242,6 +264,10 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
patterns.add(convertExp2fOp);
}
// Adds the math.round expansion (convertRoundOp) to `patterns`.
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
patterns.add(convertRoundOp);
}
// Adds the floor expansion (convertFloorOp) to `patterns`.
void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
patterns.add(convertFloorOp);
}

View File

@ -189,3 +189,21 @@ func.func @exp2f_func_tensor(%a: tensor<1xf32>) -> tensor<1xf32> {
%ret = math.exp2 %a : tensor<1xf32>
return %ret : tensor<1xf32>
}
// -----
// Verifies the expansion of math.round on an f64 operand: three constants
// (0.0, 0.5, -0.5), an `oge` compare of the argument against zero, a select
// between the two halves, an add of the selected half to the argument, and a
// float -> signed int -> float conversion whose result is returned.
// CHECK-LABEL: func @roundf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
func.func @roundf_func(%a: f64) -> f64 {
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
// CHECK-DAG: [[CST_0:%.+]] = arith.constant 5.000000e-01
// CHECK-DAG: [[CST_1:%.+]] = arith.constant -5.000000e-01
// CHECK-DAG: [[COMP:%.+]] = arith.cmpf oge, [[ARG0]], [[CST]]
// CHECK-DAG: [[SEL:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST_1]]
// CHECK-DAG: [[ADDF:%.+]] = arith.addf [[ARG0]], [[SEL]]
// CHECK-DAG: [[CVTI:%.+]] = arith.fptosi [[ADDF]]
// CHECK-DAG: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
// CHECK: return [[CVTF]]
%ret = math.round %a : f64
return %ret : f64
}

View File

@ -43,6 +43,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandFmaFPattern(patterns);
populateExpandFloorFPattern(patterns);
populateExpandCeilFPattern(patterns);
populateExpandRoundFPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

View File

@ -55,7 +55,54 @@ func.func @exp2f() {
return
}
func.func @main() {
call @exp2f() : () -> ()
// -------------------------------------------------------------------------- //
// round.
// -------------------------------------------------------------------------- //
// Helper: applies math.round to the f32 argument and prints the result.
func.func @func_roundf(%a : f32) {
%r = math.round %a : f32
vector.print %r : f32
return
}
// Test driver: calls @func_roundf on a series of f32 constants. The expected
// printed value is listed immediately before each constant.
func.func @roundf() {
// CHECK: 4
%a = arith.constant 3.8 : f32
call @func_roundf(%a) : (f32) -> ()
// CHECK: -4
%b = arith.constant -3.8 : f32
call @func_roundf(%b) : (f32) -> ()
// CHECK: 0
%c = arith.constant 0.0 : f32
call @func_roundf(%c) : (f32) -> ()
// CHECK: -4
%d = arith.constant -4.2 : f32
call @func_roundf(%d) : (f32) -> ()
// Inputs that are already whole numbers are expected to print unchanged.
// CHECK: -495
%e = arith.constant -495.0 : f32
call @func_roundf(%e) : (f32) -> ()
// CHECK: 495
%f = arith.constant 495.0 : f32
call @func_roundf(%f) : (f32) -> ()
// Halfway inputs: 8.5 and -8.5 are expected to print 9 and -9.
// CHECK: 9
%g = arith.constant 8.5 : f32
call @func_roundf(%g) : (f32) -> ()
// CHECK: -9
%h = arith.constant -8.5 : f32
call @func_roundf(%h) : (f32) -> ()
return
}
// Entry point: runs the exp2f test driver followed by the roundf test driver.
func.func @main() {
call @exp2f() : () -> ()
call @roundf() : () -> ()
return
}