mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-04 03:44:59 +00:00
Fix handling of special and large vals in expand pattern for round
The current expand pattern for `math.round` does not handle the special values -0.0, +/-inf, and +/-nan correctly. It also does not properly handle values with magnitude |x| >= 2^23. Lastly, the pattern generates invalid IR when the input to `math.round` is a vector. This patch fixes these issues.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D148398
This commit is contained in:
parent
bbc983d33a
commit
ab2fc9521e
@ -48,9 +48,14 @@ static Value createIntConst(Location loc, Type type, int64_t value,
|
||||
|
||||
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
|
||||
Type opType = operand.getType();
|
||||
Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand);
|
||||
Type i64Ty = b.getI64Type();
|
||||
if (auto shapedTy = dyn_cast<ShapedType>(opType))
|
||||
i64Ty = shapedTy.clone(i64Ty);
|
||||
Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
|
||||
Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
|
||||
return fpFixedConvert;
|
||||
// The truncation does not preserve the sign when the truncated
|
||||
// value is -0. So here the sign is copied again.
|
||||
return b.create<math::CopySignOp>(fpFixedConvert, operand);
|
||||
}
|
||||
|
||||
/// Expands tanh op into
|
||||
@ -189,23 +194,59 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
|
||||
|
||||
static LogicalResult convertRoundOp(math::RoundOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
||||
Location loc = op.getLoc();
|
||||
ImplicitLocOpBuilder b(loc, rewriter);
|
||||
Value operand = op.getOperand();
|
||||
Type opType = operand.getType();
|
||||
Type opEType = getElementTypeOrSelf(opType);
|
||||
|
||||
// Creating constants for later use.
|
||||
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
|
||||
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
|
||||
Value negHalf = createFloatConst(op->getLoc(), opType, -0.5, rewriter);
|
||||
if (!opEType.isF32()) {
|
||||
return rewriter.notifyMatchFailure(op, "not a round of f32.");
|
||||
}
|
||||
|
||||
Value posCheck =
|
||||
b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, zero);
|
||||
Value incrValue =
|
||||
b.create<arith::SelectOp>(op->getLoc(), posCheck, half, negHalf);
|
||||
Type i32Ty = b.getI32Type();
|
||||
if (auto shapedTy = dyn_cast<ShapedType>(opType))
|
||||
i32Ty = shapedTy.clone(i32Ty);
|
||||
|
||||
Value half = createFloatConst(loc, opType, 0.5, b);
|
||||
Value c23 = createIntConst(loc, i32Ty, 23, b);
|
||||
Value c127 = createIntConst(loc, i32Ty, 127, b);
|
||||
Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
|
||||
|
||||
Value incrValue = b.create<math::CopySignOp>(half, operand);
|
||||
Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
|
||||
|
||||
Value fpFixedConvert = createTruncatedFPValue(add, b);
|
||||
rewriter.replaceOp(op, fpFixedConvert);
|
||||
|
||||
// There are three cases where adding 0.5 to the value and truncating by
|
||||
// converting to an i64 does not result in the correct behavior:
|
||||
//
|
||||
// 1. Special values: +-inf and +-nan
|
||||
// Casting these special values to i64 has undefined behavior. To identify
|
||||
// these values, we use the fact that these values are the only float
|
||||
// values with the maximum possible biased exponent.
|
||||
//
|
||||
// 2. Large values: 2^23 <= |x| <= INT64_MAX
|
||||
// Adding 0.5 to a float larger than or equal to 2^23 results in precision
|
||||
// errors that sometimes round the value up and sometimes round the value
|
||||
// down. For example:
|
||||
// 8388608.0 + 0.5 = 8388608.0
|
||||
// 8388609.0 + 0.5 = 8388610.0
|
||||
//
|
||||
// 3. Very large values: |x| > INT64_MAX
|
||||
// Casting to i64 a value greater than the max i64 value will overflow the
|
||||
// i64 leading to wrong outputs.
|
||||
//
|
||||
// All three cases satisfy the property `operandBiasedExp >= 23`, where `operandBiasedExp` is the 8-bit exponent field of the f32 minus the bias of 127.
|
||||
Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
|
||||
Value operandExp = b.create<arith::AndIOp>(
|
||||
b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
|
||||
Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
|
||||
Value isSpecialValOrLargeVal =
|
||||
b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
|
||||
|
||||
Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
|
||||
fpFixedConvert);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -141,9 +141,10 @@ func.func @floorf_func(%a: f64) -> f64 {
|
||||
// CHECK-DAG: [[CST_0:%.+]] = arith.constant -1.000
|
||||
// CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
|
||||
// CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
|
||||
// CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
|
||||
// CHECK-NEXT: [[COMP:%.+]] = arith.cmpf olt, [[ARG0]], [[CST]]
|
||||
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
|
||||
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[CVTF]], [[INCR]]
|
||||
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
|
||||
// CHECK-NEXT: return [[ADDF]]
|
||||
%ret = math.floor %a : f64
|
||||
return %ret : f64
|
||||
@ -158,9 +159,10 @@ func.func @ceilf_func(%a: f64) -> f64 {
|
||||
// CHECK-DAG: [[CST_0:%.+]] = arith.constant 1.000
|
||||
// CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
|
||||
// CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
|
||||
// CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[CVTF]]
|
||||
// CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
|
||||
// CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
|
||||
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
|
||||
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[CVTF]], [[INCR]]
|
||||
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
|
||||
// CHECK-NEXT: return [[ADDF]]
|
||||
%ret = math.ceil %a : f64
|
||||
return %ret : f64
|
||||
@ -193,19 +195,26 @@ func.func @exp2f_func_tensor(%a: tensor<1xf32>) -> tensor<1xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @roundf_func
|
||||
// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
|
||||
func.func @roundf_func(%a: f64) -> f64 {
|
||||
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
|
||||
// CHECK-DAG: [[CST_0:%.+]] = arith.constant 5.000000e-01
|
||||
// CHECK-DAG: [[CST_1:%.+]] = arith.constant -5.000000e-01
|
||||
// CHECK-DAG: [[COMP:%.+]] = arith.cmpf oge, [[ARG0]], [[CST]]
|
||||
// CHECK-DAG: [[SEL:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST_1]]
|
||||
// CHECK-DAG: [[ADDF:%.+]] = arith.addf [[ARG0]], [[SEL]]
|
||||
// CHECK-DAG: [[CVTI:%.+]] = arith.fptosi [[ADDF]]
|
||||
// CHECK-DAG: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
|
||||
// CHECK: return [[CVTF]]
|
||||
%ret = math.round %a : f64
|
||||
return %ret : f64
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32
|
||||
func.func @roundf_func(%a: f32) -> f32 {
|
||||
// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01
|
||||
// CHECK-DAG: %[[C23:.*]] = arith.constant 23
|
||||
// CHECK-DAG: %[[C127:.*]] = arith.constant 127
|
||||
// CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 255
|
||||
// CHECK-DAG: %[[SHIFT:.*]] = math.copysign %[[HALF]], %[[ARG0]]
|
||||
// CHECK-DAG: %[[ARG_SHIFTED:.*]] = arith.addf %[[ARG0]], %[[SHIFT]]
|
||||
// CHECK-DAG: %[[FIXED_CONVERT:.*]] = arith.fptosi %[[ARG_SHIFTED]]
|
||||
// CHECK-DAG: %[[FP_FIXED_CONVERT_0:.*]] = arith.sitofp %[[FIXED_CONVERT]]
|
||||
// CHECK-DAG: %[[FP_FIXED_CONVERT_1:.*]] = math.copysign %[[FP_FIXED_CONVERT_0]], %[[ARG_SHIFTED]]
|
||||
// CHECK-DAG: %[[ARG_BITCAST:.*]] = arith.bitcast %[[ARG0]] : f32 to i32
|
||||
// CHECK-DAG: %[[ARG_BITCAST_SHIFTED:.*]] = arith.shrui %[[ARG_BITCAST]], %[[C23]]
|
||||
// CHECK-DAG: %[[ARG_EXP:.*]] = arith.andi %[[ARG_BITCAST_SHIFTED]], %[[EXP_MASK]]
|
||||
// CHECK-DAG: %[[ARG_BIASED_EXP:.*]] = arith.subi %[[ARG_EXP]], %[[C127]]
|
||||
// CHECK-DAG: %[[IS_SPECIAL_VAL:.*]] = arith.cmpi sge, %[[ARG_BIASED_EXP]], %[[C23]]
|
||||
// CHECK-DAG: %[[RESULT:.*]] = arith.select %[[IS_SPECIAL_VAL]], %[[ARG0]], %[[FP_FIXED_CONVERT_1]]
|
||||
// CHECK: return %[[RESULT]]
|
||||
%ret = math.round %a : f32
|
||||
return %ret : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -19,37 +19,37 @@ func.func @exp2f() {
|
||||
%a = arith.constant 1.0 : f64
|
||||
call @func_exp2f(%a) : (f64) -> ()
|
||||
|
||||
// CHECK: 4
|
||||
// CHECK-NEXT: 4
|
||||
%b = arith.constant 2.0 : f64
|
||||
call @func_exp2f(%b) : (f64) -> ()
|
||||
|
||||
// CHECK: 5.65685
|
||||
// CHECK-NEXT: 5.65685
|
||||
%c = arith.constant 2.5 : f64
|
||||
call @func_exp2f(%c) : (f64) -> ()
|
||||
|
||||
// CHECK: 0.29730
|
||||
// CHECK-NEXT: 0.29730
|
||||
%d = arith.constant -1.75 : f64
|
||||
call @func_exp2f(%d) : (f64) -> ()
|
||||
|
||||
// CHECK: 1.09581
|
||||
// CHECK-NEXT: 1.09581
|
||||
%e = arith.constant 0.132 : f64
|
||||
call @func_exp2f(%e) : (f64) -> ()
|
||||
|
||||
// CHECK: inf
|
||||
// CHECK-NEXT: inf
|
||||
%f1 = arith.constant 0.00 : f64
|
||||
%f2 = arith.constant 1.00 : f64
|
||||
%f = arith.divf %f2, %f1 : f64
|
||||
call @func_exp2f(%f) : (f64) -> ()
|
||||
|
||||
// CHECK: inf
|
||||
// CHECK-NEXT: inf
|
||||
%g = arith.constant 5038939.0 : f64
|
||||
call @func_exp2f(%g) : (f64) -> ()
|
||||
|
||||
// CHECK: 0
|
||||
// CHECK-NEXT: 0
|
||||
%neg_inf = arith.constant 0xff80000000000000 : f64
|
||||
call @func_exp2f(%neg_inf) : (f64) -> ()
|
||||
|
||||
// CHECK: inf
|
||||
// CHECK-NEXT: inf
|
||||
%i = arith.constant 0x7fc0000000000000 : f64
|
||||
call @func_exp2f(%i) : (f64) -> ()
|
||||
return
|
||||
@ -64,39 +64,113 @@ func.func @func_roundf(%a : f32) {
|
||||
return
|
||||
}
|
||||
|
||||
func.func @func_roundf$bitcast_result_to_int(%a : f32) {
|
||||
%b = math.round %a : f32
|
||||
%c = arith.bitcast %b : f32 to i32
|
||||
vector.print %c : i32
|
||||
return
|
||||
}
|
||||
|
||||
func.func @func_roundf$vector(%a : vector<1xf32>) {
|
||||
%b = math.round %a : vector<1xf32>
|
||||
vector.print %b : vector<1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func.func @roundf() {
|
||||
// CHECK: 4
|
||||
// CHECK-NEXT: 4
|
||||
%a = arith.constant 3.8 : f32
|
||||
call @func_roundf(%a) : (f32) -> ()
|
||||
|
||||
// CHECK: -4
|
||||
// CHECK-NEXT: -4
|
||||
%b = arith.constant -3.8 : f32
|
||||
call @func_roundf(%b) : (f32) -> ()
|
||||
|
||||
// CHECK: 0
|
||||
%c = arith.constant 0.0 : f32
|
||||
// CHECK-NEXT: -4
|
||||
%c = arith.constant -4.2 : f32
|
||||
call @func_roundf(%c) : (f32) -> ()
|
||||
|
||||
// CHECK: -4
|
||||
%d = arith.constant -4.2 : f32
|
||||
// CHECK-NEXT: -495
|
||||
%d = arith.constant -495.0 : f32
|
||||
call @func_roundf(%d) : (f32) -> ()
|
||||
|
||||
// CHECK: -495
|
||||
%e = arith.constant -495.0 : f32
|
||||
// CHECK-NEXT: 495
|
||||
%e = arith.constant 495.0 : f32
|
||||
call @func_roundf(%e) : (f32) -> ()
|
||||
|
||||
// CHECK: 495
|
||||
%f = arith.constant 495.0 : f32
|
||||
// CHECK-NEXT: 9
|
||||
%f = arith.constant 8.5 : f32
|
||||
call @func_roundf(%f) : (f32) -> ()
|
||||
|
||||
// CHECK: 9
|
||||
%g = arith.constant 8.5 : f32
|
||||
// CHECK-NEXT: -9
|
||||
%g = arith.constant -8.5 : f32
|
||||
call @func_roundf(%g) : (f32) -> ()
|
||||
|
||||
// CHECK: -9
|
||||
%h = arith.constant -8.5 : f32
|
||||
// CHECK-NEXT: -0
|
||||
%h = arith.constant -0.4 : f32
|
||||
call @func_roundf(%h) : (f32) -> ()
|
||||
|
||||
// Special values: 0, -0, inf, -inf, nan, -nan
|
||||
%cNeg0 = arith.constant -0.0 : f32
|
||||
%c0 = arith.constant 0.0 : f32
|
||||
%cInfInt = arith.constant 0x7f800000 : i32
|
||||
%cInf = arith.bitcast %cInfInt : i32 to f32
|
||||
%cNegInfInt = arith.constant 0xff800000 : i32
|
||||
%cNegInf = arith.bitcast %cNegInfInt : i32 to f32
|
||||
%cNanInt = arith.constant 0x7fc00000 : i32
|
||||
%cNan = arith.bitcast %cNanInt : i32 to f32
|
||||
%cNegNanInt = arith.constant 0xffc00000 : i32
|
||||
%cNegNan = arith.bitcast %cNegNanInt : i32 to f32
|
||||
|
||||
// CHECK-NEXT: -0
|
||||
call @func_roundf(%cNeg0) : (f32) -> ()
|
||||
// CHECK-NEXT: 0
|
||||
call @func_roundf(%c0) : (f32) -> ()
|
||||
// CHECK-NEXT: inf
|
||||
call @func_roundf(%cInf) : (f32) -> ()
|
||||
// CHECK-NEXT: -inf
|
||||
call @func_roundf(%cNegInf) : (f32) -> ()
|
||||
// CHECK-NEXT: nan
|
||||
call @func_roundf(%cNan) : (f32) -> ()
|
||||
// CHECK-NEXT: -nan
|
||||
call @func_roundf(%cNegNan) : (f32) -> ()
|
||||
|
||||
// Very large values (greater than INT_64_MAX)
|
||||
%c2To100 = arith.constant 1.268e30 : f32 // 2^100
|
||||
// CHECK-NEXT: 1.268e+30
|
||||
call @func_roundf(%c2To100) : (f32) -> ()
|
||||
|
||||
// Values above and below 2^23 = 8388608
|
||||
%c8388606_5 = arith.constant 8388606.5 : f32
|
||||
%c8388607 = arith.constant 8388607.0 : f32
|
||||
%c8388607_5 = arith.constant 8388607.5 : f32
|
||||
%c8388608 = arith.constant 8388608.0 : f32
|
||||
%c8388609 = arith.constant 8388609.0 : f32
|
||||
|
||||
// Bitcast result to int to avoid printing in scientific notation,
|
||||
// which does not display all significant digits.
|
||||
|
||||
// CHECK-NEXT: 1258291198
|
||||
// hex: 0x4AFFFFFE
|
||||
call @func_roundf$bitcast_result_to_int(%c8388606_5) : (f32) -> ()
|
||||
// CHECK-NEXT: 1258291198
|
||||
// hex: 0x4AFFFFFE
|
||||
call @func_roundf$bitcast_result_to_int(%c8388607) : (f32) -> ()
|
||||
// CHECK-NEXT: 1258291200
|
||||
// hex: 0x4B000000
|
||||
call @func_roundf$bitcast_result_to_int(%c8388607_5) : (f32) -> ()
|
||||
// CHECK-NEXT: 1258291200
|
||||
// hex: 0x4B000000
|
||||
call @func_roundf$bitcast_result_to_int(%c8388608) : (f32) -> ()
|
||||
// CHECK-NEXT: 1258291201
|
||||
// hex: 0x4B000001
|
||||
call @func_roundf$bitcast_result_to_int(%c8388609) : (f32) -> ()
|
||||
|
||||
// Check that vector type works
|
||||
%cVec = arith.constant dense<[0.5]> : vector<1xf32>
|
||||
// CHECK-NEXT: ( 1 )
|
||||
call @func_roundf$vector(%cVec) : (vector<1xf32>) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -110,52 +184,52 @@ func.func @func_powff64(%a : f64, %b : f64) {
|
||||
}
|
||||
|
||||
func.func @powf() {
|
||||
// CHECK: 16
|
||||
// CHECK-NEXT: 16
|
||||
%a = arith.constant 4.0 : f64
|
||||
%a_p = arith.constant 2.0 : f64
|
||||
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
|
||||
|
||||
// CHECK: nan
|
||||
// CHECK-NEXT: nan
|
||||
%b = arith.constant -3.0 : f64
|
||||
%b_p = arith.constant 3.0 : f64
|
||||
call @func_powff64(%b, %b_p) : (f64, f64) -> ()
|
||||
|
||||
// CHECK: 2.343
|
||||
// CHECK-NEXT: 2.343
|
||||
%c = arith.constant 2.343 : f64
|
||||
%c_p = arith.constant 1.000 : f64
|
||||
call @func_powff64(%c, %c_p) : (f64, f64) -> ()
|
||||
|
||||
// CHECK: 0.176171
|
||||
// CHECK-NEXT: 0.176171
|
||||
%d = arith.constant 4.25 : f64
|
||||
%d_p = arith.constant -1.2 : f64
|
||||
call @func_powff64(%d, %d_p) : (f64, f64) -> ()
|
||||
|
||||
// CHECK: 1
|
||||
// CHECK-NEXT: 1
|
||||
%e = arith.constant 4.385 : f64
|
||||
%e_p = arith.constant 0.00 : f64
|
||||
call @func_powff64(%e, %e_p) : (f64, f64) -> ()
|
||||
|
||||
// CHECK: 6.62637
|
||||
// CHECK-NEXT: 6.62637
|
||||
%f = arith.constant 4.835 : f64
|
||||
%f_p = arith.constant 1.2 : f64
|
||||
call @func_powff64(%f, %f_p) : (f64, f64) -> ()
|
||||
|
||||
// CHECK: nan
|
||||
// CHECK-NEXT: nan
|
||||
%g = arith.constant 0xff80000000000000 : f64
|
||||
call @func_powff64(%g, %g) : (f64, f64) -> ()
|
||||
|
||||
// CHECK: nan
|
||||
// CHECK-NEXT: nan
|
||||
%h = arith.constant 0x7fffffffffffffff : f64
|
||||
call @func_powff64(%h, %h) : (f64, f64) -> ()
|
||||
|
||||
// CHECK: nan
|
||||
// CHECK-NEXT: nan
|
||||
%i = arith.constant 1.0 : f64
|
||||
call @func_powff64(%i, %h) : (f64, f64) -> ()
|
||||
|
||||
// CHECK: inf
|
||||
// CHECK-NEXT: inf
|
||||
%j = arith.constant 29385.0 : f64
|
||||
%j_p = arith.constant 23598.0 : f64
|
||||
call @func_powff64(%j, %j_p) : (f64, f64) -> ()
|
||||
call @func_powff64(%j, %j_p) : (f64, f64) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user