[mlir][math] Expand math.floorf to truncate, compares and increments

Floorf are pushed directly to libm. This is problematic for
situations where libm is not available. This patch will break down
a floorf function to truncate followed by an increment for negative
values, if necessary.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D147966
This commit is contained in:
Balaji V. Iyer 2023-04-10 21:01:26 +00:00 committed by Robert Suderman
parent 25350a7362
commit af9eb1e384
5 changed files with 86 additions and 1 deletions

View File

@ -17,7 +17,7 @@ void populateExpandCtlzPattern(RewritePatternSet &patterns);
void populateExpandTanPattern(RewritePatternSet &patterns);
void populateExpandTanhPattern(RewritePatternSet &patterns);
void populateExpandFmaFPattern(RewritePatternSet &patterns);
void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {

View File

@ -102,6 +102,32 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
return success();
}
// Converts a floorf() function to the following:
// floorf(float x) ->
// y = (float)(int) x
// if (x < 0) then incr = -1 else incr = 0
// y = y + incr <= replace this op with the floorf op.
static LogicalResult convertFloorOp(math::FloorOp op,
PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand);
Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
// Creating constants for later use.
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
Value negCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
Value incrValue =
b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero);
Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
rewriter.replaceOp(op, ret);
return success();
}
// Converts math.ctlz to scf and arith operations. This is done
// by performing a binary search on the bits.
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@ -161,3 +187,6 @@ void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
patterns.add(convertFmaFOp);
}
void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
patterns.add(convertFloorOp);
}

View File

@ -131,3 +131,20 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
%ret = math.fma %a, %b, %c : f64
return %ret : f64
}
// -----
// CHECK-LABEL: func @floorf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
func.func @floorf_func(%a: f64) -> f64 {
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
// CHECK-DAG: [[CST_0:%.+]] = arith.constant -1.000
// CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
// CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
// CHECK-NEXT: [[COMP:%.+]] = arith.cmpf olt, [[ARG0]], [[CST]]
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[CVTF]], [[INCR]]
// CHECK-NEXT: return [[ADDF]]
%ret = math.floor %a : f64
return %ret : f64
}

View File

@ -40,6 +40,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandTanPattern(patterns);
populateExpandTanhPattern(patterns);
populateExpandFmaFPattern(patterns);
populateExpandFloorFPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

View File

@ -610,6 +610,43 @@ func.func @cbrt() {
return
}
// -------------------------------------------------------------------------- //
// floor.
// -------------------------------------------------------------------------- //
func.func @func_floorf32(%a : f32) {
%r = math.floor %a : f32
vector.print %r : f32
return
}
func.func @floorf() {
// CHECK: 3
%a = arith.constant 3.8 : f32
call @func_floorf32(%a) : (f32) -> ()
// CHECK: -4
%b = arith.constant -3.8 : f32
call @func_floorf32(%b) : (f32) -> ()
// CHECK: 0
%c = arith.constant 0.0 : f32
call @func_floorf32(%c) : (f32) -> ()
// CHECK: -5
%d = arith.constant -4.2 : f32
call @func_floorf32(%d) : (f32) -> ()
// CHECK: -2
%e = arith.constant -2.0 : f32
call @func_floorf32(%e) : (f32) -> ()
// CHECK: 2
%f = arith.constant 2.0 : f32
call @func_floorf32(%f) : (f32) -> ()
return
}
func.func @main() {
call @tanh(): () -> ()
call @log(): () -> ()
@ -623,6 +660,7 @@ func.func @main() {
call @atan() : () -> ()
call @atan2() : () -> ()
call @cbrt() : () -> ()
call @floorf() : () -> ()
return
}