mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-11 17:08:42 +00:00
[mlir][math] Expand math.floorf to truncate, compares and increments
math.floor ops are currently lowered directly to libm calls, which is a problem in situations where libm is not available. This patch breaks a floorf call down into a truncation followed, for negative values, by an adjustment of -1 where necessary. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D147966
This commit is contained in:
parent
25350a7362
commit
af9eb1e384
@ -17,7 +17,7 @@ void populateExpandCtlzPattern(RewritePatternSet &patterns);
|
||||
void populateExpandTanPattern(RewritePatternSet &patterns);
|
||||
void populateExpandTanhPattern(RewritePatternSet &patterns);
|
||||
void populateExpandFmaFPattern(RewritePatternSet &patterns);
|
||||
|
||||
void populateExpandFloorFPattern(RewritePatternSet &patterns);
|
||||
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
|
||||
|
||||
struct MathPolynomialApproximationOptions {
|
||||
|
@ -102,6 +102,32 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
|
||||
return success();
|
||||
}
|
||||
|
||||
// Converts a floorf() function to the following:
|
||||
// floorf(float x) ->
|
||||
// y = (float)(int) x
|
||||
// if (x < 0) then incr = -1 else incr = 0
|
||||
// y = y + incr <= replace this op with the floorf op.
|
||||
static LogicalResult convertFloorOp(math::FloorOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
||||
Value operand = op.getOperand();
|
||||
Type opType = operand.getType();
|
||||
Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand);
|
||||
Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
|
||||
|
||||
// Creating constants for later use.
|
||||
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
|
||||
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
|
||||
|
||||
Value negCheck =
|
||||
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
|
||||
Value incrValue =
|
||||
b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero);
|
||||
Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
|
||||
rewriter.replaceOp(op, ret);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Converts math.ctlz to scf and arith operations. This is done
|
||||
// by performing a binary search on the bits.
|
||||
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
|
||||
@ -161,3 +187,6 @@ void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
|
||||
// Adds the convertFmaFOp rewrite (which takes a math::FmaOp) to `patterns`.
void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
|
||||
patterns.add(convertFmaFOp);
|
||||
}
|
||||
// Adds the convertFloorOp rewrite (which takes a math::FloorOp and replaces
// it with arith ops) to `patterns`.
void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
|
||||
patterns.add(convertFloorOp);
|
||||
}
|
||||
|
@ -131,3 +131,20 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
|
||||
%ret = math.fma %a, %b, %c : f64
|
||||
return %ret : f64
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Lit test for the math.floor expansion on f64. The directives below pin the
// exact op sequence expected after the rewrite: two constants (0.0 and -1.0),
// a float -> signed int -> float round trip, an ordered less-than compare, a
// select between the two constants, and a final add that is returned.
// NOTE(review): the compare is expected against the zero constant, so for a
// negative integral input such as -2.0 this sequence evaluates to
// -2.0 + -1.0 = -3.0 -- confirm whether the compare operand should be the
// truncated value instead, and keep these expectations in sync with the
// rewrite if it changes.
// CHECK-LABEL: func @floorf_func
|
||||
// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
|
||||
func.func @floorf_func(%a: f64) -> f64 {
|
||||
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
|
||||
// CHECK-DAG: [[CST_0:%.+]] = arith.constant -1.000
|
||||
// CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
|
||||
// CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
|
||||
// CHECK-NEXT: [[COMP:%.+]] = arith.cmpf olt, [[ARG0]], [[CST]]
|
||||
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
|
||||
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[CVTF]], [[INCR]]
|
||||
// CHECK-NEXT: return [[ADDF]]
|
||||
// Input op that the rewrite is expected to replace.
%ret = math.floor %a : f64
|
||||
return %ret : f64
|
||||
}
|
||||
|
@ -40,6 +40,7 @@ void TestExpandMathPass::runOnOperation() {
|
||||
populateExpandTanPattern(patterns);
|
||||
populateExpandTanhPattern(patterns);
|
||||
populateExpandFmaFPattern(patterns);
|
||||
populateExpandFloorFPattern(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
|
@ -610,6 +610,43 @@ func.func @cbrt() {
|
||||
return
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------- //
|
||||
// floor.
|
||||
// -------------------------------------------------------------------------- //
|
||||
// Helper for the floor runtime test: applies math.floor to the f32 argument
// and prints the result with vector.print. It returns nothing; the printed
// value is what the caller's FileCheck directives match against.
func.func @func_floorf32(%a : f32) {
|
||||
%r = math.floor %a : f32
|
||||
vector.print %r : f32
|
||||
return
|
||||
}
|
||||
|
||||
// Runtime test driver for math.floor on f32. Each case materializes a
// constant, passes it to @func_floorf32 (which prints the floored value), and
// is preceded by the FileCheck directive giving the expected printed output.
// The inputs cover positive and negative fractional values, zero, and
// positive and negative values that are already integral.
// NOTE(review): the pass pipeline used to run this file is not visible in
// this hunk -- confirm whether these cases exercise the arith-based floor
// expansion or a different lowering.
func.func @floorf() {
|
||||
// CHECK: 3
|
||||
%a = arith.constant 3.8 : f32
|
||||
call @func_floorf32(%a) : (f32) -> ()
|
||||
|
||||
// CHECK: -4
|
||||
%b = arith.constant -3.8 : f32
|
||||
call @func_floorf32(%b) : (f32) -> ()
|
||||
|
||||
// CHECK: 0
|
||||
%c = arith.constant 0.0 : f32
|
||||
call @func_floorf32(%c) : (f32) -> ()
|
||||
|
||||
// CHECK: -5
|
||||
%d = arith.constant -4.2 : f32
|
||||
call @func_floorf32(%d) : (f32) -> ()
|
||||
|
||||
// Negative input that is already integral: the expected output is the input
// itself, with no extra decrement.
// CHECK: -2
|
||||
%e = arith.constant -2.0 : f32
|
||||
call @func_floorf32(%e) : (f32) -> ()
|
||||
|
||||
// CHECK: 2
|
||||
%f = arith.constant 2.0 : f32
|
||||
call @func_floorf32(%f) : (f32) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func.func @main() {
|
||||
call @tanh(): () -> ()
|
||||
call @log(): () -> ()
|
||||
@ -623,6 +660,7 @@ func.func @main() {
|
||||
call @atan() : () -> ()
|
||||
call @atan2() : () -> ()
|
||||
call @cbrt() : () -> ()
|
||||
call @floorf() : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user