mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-24 06:10:12 +00:00
[mlir][arith] Add narrowing patterns for addi
and muli
These two ops are handled in a very similar way -- the only difference is in the number of result bits produced. I checked these transformations with Alive2: 1. addi + sext: https://alive2.llvm.org/ce/z/3NSs9T 2. addi + zext: https://alive2.llvm.org/ce/z/t7XHOT 3. muli + sext: https://alive2.llvm.org/ce/z/-7sfW9 4. muli + zext: https://alive2.llvm.org/ce/z/h4yntF Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D149530
This commit is contained in:
parent
55678b43b5
commit
e2f7563d7c
@ -216,6 +216,93 @@ FailureOr<unsigned> calculateBitsRequired(Value value,
|
||||
return calculateBitsRequired(value.getType());
|
||||
}
|
||||
|
||||
/// Base pattern for arith binary ops.
|
||||
/// Example:
|
||||
/// ```
|
||||
/// %lhs = arith.extsi %a : i8 to i32
|
||||
/// %rhs = arith.extsi %b : i8 to i32
|
||||
/// %r = arith.addi %lhs, %rhs : i32
|
||||
/// ==>
|
||||
/// %lhs = arith.extsi %a : i8 to i16
|
||||
/// %rhs = arith.extsi %b : i8 to i16
|
||||
/// %add = arith.addi %lhs, %rhs : i16
|
||||
/// %r = arith.extsi %add : i16 to i32
|
||||
/// ```
|
||||
template <typename BinaryOp>
|
||||
struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
|
||||
using NarrowingPattern<BinaryOp>::NarrowingPattern;
|
||||
|
||||
/// Returns the number of bits required to represent the full result, assuming
|
||||
/// that both operands are `operandBits`-wide. Derived classes must implement
|
||||
/// this, taking into account `BinaryOp` semantics.
|
||||
virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;
|
||||
|
||||
LogicalResult matchAndRewrite(BinaryOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Type origTy = op.getType();
|
||||
FailureOr<unsigned> resultBits = calculateBitsRequired(origTy);
|
||||
if (failed(resultBits))
|
||||
return failure();
|
||||
|
||||
// For the optimization to apply, we expect the lhs to be an extension op,
|
||||
// and for the rhs to either be the same extension op or a constant.
|
||||
FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
|
||||
if (failed(ext))
|
||||
return failure();
|
||||
|
||||
FailureOr<unsigned> lhsBitsRequired =
|
||||
calculateBitsRequired(ext->getIn(), ext->getKind());
|
||||
if (failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits)
|
||||
return failure();
|
||||
|
||||
FailureOr<unsigned> rhsBitsRequired =
|
||||
calculateBitsRequired(op.getRhs(), ext->getKind());
|
||||
if (failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits)
|
||||
return failure();
|
||||
|
||||
// Negotiate a common bit requirements for both lhs and rhs, accounting for
|
||||
// the result requiring more bits than the operands.
|
||||
unsigned commonBitsRequired =
|
||||
getResultBitsProduced(std::max(*lhsBitsRequired, *rhsBitsRequired));
|
||||
FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
|
||||
if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits)
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value newLhs =
|
||||
rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
|
||||
Value newRhs =
|
||||
rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
|
||||
Value newAdd = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
|
||||
ext->recreateAndReplace(rewriter, op, newAdd);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
// AddIOp Pattern
//===----------------------------------------------------------------------===//

struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // Adding two `operandBits`-wide values can carry into one extra bit, so the
  // exact result requires `operandBits + 1` bits.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits + 1;
  }
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
// MulIOp Pattern
//===----------------------------------------------------------------------===//

struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // Multiplying two `operandBits`-wide values can require up to twice as many
  // bits to represent the exact product.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return 2 * operandBits;
  }
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// *IToFPOp Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -538,7 +625,8 @@ void populateArithIntNarrowingPatterns(
|
||||
ExtensionOverTranspose, ExtensionOverFlatTranspose>(
|
||||
patterns.getContext(), options, PatternBenefit(2));
|
||||
|
||||
patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
|
||||
patterns.add<AddIPattern, MulIPattern, SIToFPPattern, UIToFPPattern>(
|
||||
patterns.getContext(), options);
|
||||
}
|
||||
|
||||
} // namespace mlir::arith
|
||||
|
@ -1,6 +1,188 @@
|
||||
// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
|
||||
// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,24,32" \
|
||||
// RUN: --verify-diagnostics %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// arith.addi
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// addi of two sign-extended i8 values needs only 9 bits, so the op is
// narrowed to the next supported width, i16.
//
// CHECK-LABEL: func.func @addi_extsi_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i16 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @addi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
  %a = arith.extsi %lhs : i8 to i32
  %b = arith.extsi %rhs : i8 to i32
  %r = arith.addi %a, %b : i32
  return %r : i32
}
|
||||
|
||||
// Same as above, but with zero-extension; the pattern recreates the same
// extension kind (extui) after the narrowed add.
//
// CHECK-LABEL: func.func @addi_extui_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[ADD]] : i16 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
  %a = arith.extui %lhs : i8 to i32
  %b = arith.extui %rhs : i8 to i32
  %r = arith.addi %a, %b : i32
  return %r : i32
}
|
||||
|
||||
// arith.addi produces one more bit of result than the operand bitwidth, so
// 16-bit operands require a 17-bit result, which lands in the supported
// i24 type here.
//
// CHECK-LABEL: func.func @addi_extsi_i24
// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24
// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i24
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @addi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
  %a = arith.extsi %lhs : i16 to i32
  %b = arith.extsi %rhs : i16 to i32
  %r = arith.addi %a, %b : i32
  return %r : i32
}
|
||||
|
||||
// This case should not get optimized because of mixed extensions (extsi on
// one operand, extui on the other).
//
// CHECK-LABEL: func.func @addi_mixed_ext_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i32
// CHECK-NEXT: return %[[ADD]] : i32
func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
  %a = arith.extsi %lhs : i8 to i32
  %b = arith.extui %rhs : i8 to i32
  %r = arith.addi %a, %b : i32
  return %r : i32
}
|
||||
|
||||
// This case should not get optimized because we cannot reduce the bitwidth
// below i16, given the pass options set.
//
// CHECK-LABEL: func.func @addi_extsi_i16
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i16
// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i16
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i16
// CHECK-NEXT: return %[[ADD]] : i16
func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
  %a = arith.extsi %lhs : i8 to i16
  %b = arith.extsi %rhs : i8 to i16
  %r = arith.addi %a, %b : i16
  return %r : i16
}
|
||||
|
||||
// Vector case with a constant rhs: only the lhs comes from an extension op;
// the constant is narrowed directly to the new element type.
//
// CHECK-LABEL: func.func @addi_extsi_3xi8_cst
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi8>)
// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[CST]] : vector<3xi16>
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : vector<3xi16> to vector<3xi32>
// CHECK-NEXT: return %[[RET]] : vector<3xi32>
func.func @addi_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
  %cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
  %a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
  %r = arith.addi %a, %cst : vector<3xi32>
  return %r : vector<3xi32>
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// arith.muli
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// muli of two sign-extended i8 values needs 16 bits, so the op is narrowed
// to i16.
//
// CHECK-LABEL: func.func @muli_extsi_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MUL]] : i16 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @muli_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
  %a = arith.extsi %lhs : i8 to i32
  %b = arith.extsi %rhs : i8 to i32
  %r = arith.muli %a, %b : i32
  return %r : i32
}
|
||||
|
||||
// Same as above, but with zero-extension; extui is recreated after the
// narrowed multiply.
//
// CHECK-LABEL: func.func @muli_extui_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[MUL]] : i16 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
  %a = arith.extui %lhs : i8 to i32
  %b = arith.extui %rhs : i8 to i32
  %r = arith.muli %a, %b : i32
  return %r : i32
}
|
||||
|
||||
// We do not expect this case to be optimized because given n-bit operands,
// arith.muli produces 2n bits of result: i16 operands already need the full
// i32 result width.
//
// CHECK-LABEL: func.func @muli_extsi_i32
// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
// CHECK-NEXT: %[[LHS:.+]] = arith.extsi %[[ARG0]] : i16 to i32
// CHECK-NEXT: %[[RHS:.+]] = arith.extsi %[[ARG1]] : i16 to i32
// CHECK-NEXT: %[[RET:.+]] = arith.muli %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
  %a = arith.extsi %lhs : i16 to i32
  %b = arith.extsi %rhs : i16 to i32
  %r = arith.muli %a, %b : i32
  return %r : i32
}
|
||||
|
||||
// This case should not get optimized because of mixed extensions (extsi on
// one operand, extui on the other).
//
// CHECK-LABEL: func.func @muli_mixed_ext_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32
// CHECK-NEXT: return %[[MUL]] : i32
func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
  %a = arith.extsi %lhs : i8 to i32
  %b = arith.extui %rhs : i8 to i32
  %r = arith.muli %a, %b : i32
  return %r : i32
}
|
||||
|
||||
// Vector case with a constant rhs: only the lhs comes from an extension op;
// the constant is narrowed directly to the new element type.
//
// CHECK-LABEL: func.func @muli_extsi_3xi8_cst
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi8>)
// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[CST]] : vector<3xi16>
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MUL]] : vector<3xi16> to vector<3xi32>
// CHECK-NEXT: return %[[RET]] : vector<3xi32>
func.func @muli_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
  %cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
  %a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
  %r = arith.muli %a, %cst : vector<3xi32>
  return %r : vector<3xi32>
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// arith.*itofp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Loading…
Reference in New Issue
Block a user