[mlir][arith] Add narrowing patterns for addi and muli

These two ops are handled in a very similar way -- the only difference
is in the number of result bits produced.

I checked these transformations with Alive2:
1.  addi + sext: https://alive2.llvm.org/ce/z/3NSs9T
2.  addi + zext: https://alive2.llvm.org/ce/z/t7XHOT
3.  muli + sext: https://alive2.llvm.org/ce/z/-7sfW9
4.  muli + zext: https://alive2.llvm.org/ce/z/h4yntF

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149530
This commit is contained in:
Jakub Kuderski 2023-05-02 10:09:51 -04:00
parent 55678b43b5
commit e2f7563d7c
2 changed files with 272 additions and 2 deletions

View File

@ -216,6 +216,93 @@ FailureOr<unsigned> calculateBitsRequired(Value value,
return calculateBitsRequired(value.getType());
}
/// Base pattern for arith binary ops.
/// Example:
/// ```
/// %lhs = arith.extsi %a : i8 to i32
/// %rhs = arith.extsi %b : i8 to i32
/// %r = arith.addi %lhs, %rhs : i32
/// ==>
/// %lhs = arith.extsi %a : i8 to i16
/// %rhs = arith.extsi %b : i8 to i16
/// %add = arith.addi %lhs, %rhs : i16
/// %r = arith.extsi %add : i16 to i32
/// ```
template <typename BinaryOp>
struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
  using NarrowingPattern<BinaryOp>::NarrowingPattern;

  /// Returns the number of bits required to represent the full result, assuming
  /// that both operands are `operandBits`-wide. Derived classes must implement
  /// this, taking into account `BinaryOp` semantics.
  virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;

  LogicalResult matchAndRewrite(BinaryOp op,
                                PatternRewriter &rewriter) const final {
    Type resultTy = op.getType();
    FailureOr<unsigned> maxResultBits = calculateBitsRequired(resultTy);
    if (failed(maxResultBits))
      return failure();

    // For the optimization to apply, the lhs must be produced by an extension
    // op, and the rhs must either be the same kind of extension or a constant.
    FailureOr<ExtensionOp> extOp = ExtensionOp::from(op.getLhs().getDefiningOp());
    if (failed(extOp))
      return failure();

    FailureOr<unsigned> lhsBits =
        calculateBitsRequired(extOp->getIn(), extOp->getKind());
    if (failed(lhsBits) || *lhsBits >= *maxResultBits)
      return failure();

    FailureOr<unsigned> rhsBits =
        calculateBitsRequired(op.getRhs(), extOp->getKind());
    if (failed(rhsBits) || *rhsBits >= *maxResultBits)
      return failure();

    // Negotiate a common bit requirement for both operands, accounting for the
    // op's result needing more bits than its inputs.
    unsigned requiredBits =
        getResultBitsProduced(std::max(*lhsBits, *rhsBits));
    FailureOr<Type> narrowTy = this->getNarrowType(requiredBits, resultTy);
    if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *maxResultBits)
      return failure();

    // Truncate both operands to the narrow type, perform the op there, and
    // re-extend the result with the original extension kind.
    Location loc = op.getLoc();
    Value narrowLhs =
        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
    Value narrowRhs =
        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
    Value narrowResult = rewriter.create<BinaryOp>(loc, narrowLhs, narrowRhs);
    extOp->recreateAndReplace(rewriter, op, narrowResult);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// AddIOp Pattern
//===----------------------------------------------------------------------===//

struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  /// Adding two `operandBits`-wide values can carry into at most one extra
  /// bit, so the full sum fits in `operandBits + 1` bits.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits + 1;
  }
};
//===----------------------------------------------------------------------===//
// MulIOp Pattern
//===----------------------------------------------------------------------===//

struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  /// Multiplying two `operandBits`-wide values produces a result of up to
  /// twice the operand width.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits * 2;
  }
};
//===----------------------------------------------------------------------===//
// *IToFPOp Patterns
//===----------------------------------------------------------------------===//
@ -538,7 +625,8 @@ void populateArithIntNarrowingPatterns(
ExtensionOverTranspose, ExtensionOverFlatTranspose>(
patterns.getContext(), options, PatternBenefit(2));
patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
patterns.add<AddIPattern, MulIPattern, SIToFPPattern, UIToFPPattern>(
patterns.getContext(), options);
}
} // namespace mlir::arith

View File

@ -1,6 +1,188 @@
// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,24,32" \
// RUN: --verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// arith.addi
//===----------------------------------------------------------------------===//
// Both operands are sign-extended from i8; an i8 add needs at most 9 result
// bits, so the add narrows to the next supported width (i16) and the sum is
// re-extended with the same (signed) extension kind.
// CHECK-LABEL: func.func @addi_extsi_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i16 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @addi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extsi %lhs : i8 to i32
%b = arith.extsi %rhs : i8 to i32
%r = arith.addi %a, %b : i32
return %r : i32
}
// Same as above but with unsigned extensions; the narrowed result is
// re-extended with arith.extui.
// CHECK-LABEL: func.func @addi_extui_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[ADD]] : i16 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extui %lhs : i8 to i32
%b = arith.extui %rhs : i8 to i32
%r = arith.addi %a, %b : i32
return %r : i32
}
// arith.addi produces one more bit of result than the operand bitwidth.
//
// i16 operands need 17 result bits, so the add narrows to i24 (enabled via the
// pass options) rather than i16.
// CHECK-LABEL: func.func @addi_extsi_i24
// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24
// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i24
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @addi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
%a = arith.extsi %lhs : i16 to i32
%b = arith.extsi %rhs : i16 to i32
%r = arith.addi %a, %b : i32
return %r : i32
}
// This case should not get optimized because of mixed extensions.
//
// CHECK-LABEL: func.func @addi_mixed_ext_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i32
// CHECK-NEXT: return %[[ADD]] : i32
func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extsi %lhs : i8 to i32
%b = arith.extui %rhs : i8 to i32
%r = arith.addi %a, %b : i32
return %r : i32
}
// This case should not get optimized because we cannot reduce the bitwidth
// below i16, given the pass options set.
//
// CHECK-LABEL: func.func @addi_extsi_i16
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i16
// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i16
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i16
// CHECK-NEXT: return %[[ADD]] : i16
func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
%a = arith.extsi %lhs : i8 to i16
%b = arith.extsi %rhs : i8 to i16
%r = arith.addi %a, %b : i16
return %r : i16
}
// The rhs may be a constant instead of an extension op; here the vector
// constant is narrowed to i16 in place and the add still applies.
// CHECK-LABEL: func.func @addi_extsi_3xi8_cst
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi8>)
// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[CST]] : vector<3xi16>
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : vector<3xi16> to vector<3xi32>
// CHECK-NEXT: return %[[RET]] : vector<3xi32>
func.func @addi_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
%cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
%a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
%r = arith.addi %a, %cst : vector<3xi32>
return %r : vector<3xi32>
}
//===----------------------------------------------------------------------===//
// arith.muli
//===----------------------------------------------------------------------===//
// i8 operands need at most 16 product bits, so the multiply narrows to i16 and
// the result is re-extended signed.
// CHECK-LABEL: func.func @muli_extsi_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MUL]] : i16 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @muli_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extsi %lhs : i8 to i32
%b = arith.extsi %rhs : i8 to i32
%r = arith.muli %a, %b : i32
return %r : i32
}
// Same as above but with unsigned extensions; the narrowed product is
// re-extended with arith.extui.
// CHECK-LABEL: func.func @muli_extui_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[MUL]] : i16 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extui %lhs : i8 to i32
%b = arith.extui %rhs : i8 to i32
%r = arith.muli %a, %b : i32
return %r : i32
}
// We do not expect this case to be optimized because given n-bit operands,
// arith.muli produces 2n bits of result.
//
// CHECK-LABEL: func.func @muli_extsi_i32
// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
// CHECK-NEXT: %[[LHS:.+]] = arith.extsi %[[ARG0]] : i16 to i32
// CHECK-NEXT: %[[RHS:.+]] = arith.extsi %[[ARG1]] : i16 to i32
// CHECK-NEXT: %[[RET:.+]] = arith.muli %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
%a = arith.extsi %lhs : i16 to i32
%b = arith.extsi %rhs : i16 to i32
%r = arith.muli %a, %b : i32
return %r : i32
}
// This case should not get optimized because of mixed extensions.
//
// CHECK-LABEL: func.func @muli_mixed_ext_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32
// CHECK-NEXT: return %[[MUL]] : i32
func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extsi %lhs : i8 to i32
%b = arith.extui %rhs : i8 to i32
%r = arith.muli %a, %b : i32
return %r : i32
}
// The rhs may be a constant instead of an extension op; here the vector
// constant is narrowed to i16 in place and the multiply still applies.
// CHECK-LABEL: func.func @muli_extsi_3xi8_cst
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi8>)
// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[CST]] : vector<3xi16>
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MUL]] : vector<3xi16> to vector<3xi32>
// CHECK-NEXT: return %[[RET]] : vector<3xi32>
func.func @muli_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
%cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
%a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
%r = arith.muli %a, %cst : vector<3xi32>
return %r : vector<3xi32>
}
//===----------------------------------------------------------------------===//
// arith.*itofp
//===----------------------------------------------------------------------===//