mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-04 03:44:59 +00:00
[mlir][math] Expand math.round to truncate, compare and increment.
Round functions are pushed directly to libm. This is problematic in situations where libm is not available. This patch decomposes the roundf function by adding 0.5 to the input when it is positive (subtracting 0.5 when it is negative), followed by a truncation. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D148026
This commit is contained in:
parent
43c42d6d7a
commit
be9115788c
@ -20,6 +20,7 @@ void populateExpandFmaFPattern(RewritePatternSet &patterns);
|
||||
void populateExpandFloorFPattern(RewritePatternSet &patterns);
|
||||
void populateExpandCeilFPattern(RewritePatternSet &patterns);
|
||||
void populateExpandExp2FPattern(RewritePatternSet &patterns);
|
||||
void populateExpandRoundFPattern(RewritePatternSet &patterns);
|
||||
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
|
||||
|
||||
struct MathPolynomialApproximationOptions {
|
||||
|
@ -174,6 +174,28 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Expands math.round into arith operations. The operand is offset by 0.5 in
// the direction of its sign (+0.5 when it compares ordered-greater-or-equal
// to zero, -0.5 otherwise); the sum is then handed to createTruncatedFPValue
// and the value it returns replaces the original op. Always reports success.
static LogicalResult convertRoundOp(math::RoundOp op,
                                    PatternRewriter &rewriter) {
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  Value input = op.getOperand();
  Type inputType = input.getType();

  // Materialize the floating-point constants needed by the expansion.
  Value zeroCst = createFloatConst(op->getLoc(), inputType, 0.00, rewriter);
  Value plusHalf = createFloatConst(op->getLoc(), inputType, 0.5, rewriter);
  Value minusHalf = createFloatConst(op->getLoc(), inputType, -0.5, rewriter);

  // Pick the offset: +0.5 when `input >= 0` (ordered compare), else -0.5.
  Value nonNeg =
      b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, input, zeroCst);
  // NOTE(review): an explicit location is passed here even though `b` already
  // carries one, unlike the neighbouring create calls -- confirm intended.
  Value offset =
      b.create<arith::SelectOp>(op->getLoc(), nonNeg, plusHalf, minusHalf);
  Value shifted = b.create<arith::AddFOp>(inputType, input, offset);

  // Truncate the shifted value and substitute it for the round op.
  rewriter.replaceOp(op, createTruncatedFPValue(shifted, b));
  return success();
}
|
||||
|
||||
// Converts math.ctlz to scf and arith operations. This is done
|
||||
// by performing a binary search on the bits.
|
||||
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
|
||||
@ -242,6 +264,10 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
|
||||
patterns.add(convertExp2fOp);
|
||||
}
|
||||
|
||||
// Adds the math.round expansion (convertRoundOp) to `patterns`.
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
|
||||
patterns.add(convertRoundOp);
|
||||
}
|
||||
|
||||
// Adds convertFloorOp to `patterns`.
void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
|
||||
patterns.add(convertFloorOp);
|
||||
}
|
||||
|
@ -189,3 +189,21 @@ func.func @exp2f_func_tensor(%a: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%ret = math.exp2 %a : tensor<1xf32>
|
||||
return %ret : tensor<1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Verifies the expansion of math.round on a scalar f64: three constants
// (0.0, 0.5, -0.5), an `oge` compare of the argument against zero, a select
// of +/-0.5, an add, and an fptosi/sitofp pair whose result is returned.
// CHECK-LABEL: func @roundf_func
|
||||
// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
|
||||
func.func @roundf_func(%a: f64) -> f64 {
|
||||
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
|
||||
// CHECK-DAG: [[CST_0:%.+]] = arith.constant 5.000000e-01
|
||||
// CHECK-DAG: [[CST_1:%.+]] = arith.constant -5.000000e-01
|
||||
// CHECK-DAG: [[COMP:%.+]] = arith.cmpf oge, [[ARG0]], [[CST]]
|
||||
// CHECK-DAG: [[SEL:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST_1]]
|
||||
// CHECK-DAG: [[ADDF:%.+]] = arith.addf [[ARG0]], [[SEL]]
|
||||
// CHECK-DAG: [[CVTI:%.+]] = arith.fptosi [[ADDF]]
|
||||
// CHECK-DAG: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
|
||||
// CHECK: return [[CVTF]]
|
||||
%ret = math.round %a : f64
|
||||
return %ret : f64
|
||||
}
|
||||
|
@ -43,6 +43,7 @@ void TestExpandMathPass::runOnOperation() {
|
||||
populateExpandFmaFPattern(patterns);
|
||||
populateExpandFloorFPattern(patterns);
|
||||
populateExpandCeilFPattern(patterns);
|
||||
populateExpandRoundFPattern(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,54 @@ func.func @exp2f() {
|
||||
return
|
||||
}
|
||||
|
||||
func.func @main() {
|
||||
call @exp2f() : () -> ()
|
||||
// -------------------------------------------------------------------------- //
|
||||
// round.
|
||||
// -------------------------------------------------------------------------- //
|
||||
// Helper: applies math.round to the f32 argument and prints the result with
// vector.print so the runner output can be matched by the expectations in
// the callers.
func.func @func_roundf(%a : f32) {
|
||||
%r = math.round %a : f32
|
||||
vector.print %r : f32
|
||||
return
|
||||
}
|
||||
|
||||
// Driver for the round test: feeds a series of f32 constants to @func_roundf.
// The expectation comment placed above each constant states the value that
// call should print.
func.func @roundf() {
|
||||
// CHECK: 4
|
||||
%a = arith.constant 3.8 : f32
|
||||
call @func_roundf(%a) : (f32) -> ()
|
||||

||||
// CHECK: -4
|
||||
%b = arith.constant -3.8 : f32
|
||||
call @func_roundf(%b) : (f32) -> ()
|
||||

||||
// CHECK: 0
|
||||
%c = arith.constant 0.0 : f32
|
||||
call @func_roundf(%c) : (f32) -> ()
|
||||

||||
// CHECK: -4
|
||||
%d = arith.constant -4.2 : f32
|
||||
call @func_roundf(%d) : (f32) -> ()
|
||||

||||
// CHECK: -495
|
||||
%e = arith.constant -495.0 : f32
|
||||
call @func_roundf(%e) : (f32) -> ()
|
||||

||||
// CHECK: 495
|
||||
%f = arith.constant 495.0 : f32
|
||||
call @func_roundf(%f) : (f32) -> ()
|
||||

||||
// Halfway inputs (+/-8.5) are expected to round away from zero.
// CHECK: 9
|
||||
%g = arith.constant 8.5 : f32
|
||||
call @func_roundf(%g) : (f32) -> ()
|
||||

||||
// CHECK: -9
|
||||
%h = arith.constant -8.5 : f32
|
||||
call @func_roundf(%h) : (f32) -> ()
|
||||

||||
return
|
||||
}
|
||||
|
||||
|
||||
// Entry point: runs the exp2 driver followed by the round driver.
func.func @main() {
|
||||
call @exp2f() : () -> ()
|
||||
call @roundf() : () -> ()
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user