[mlir] Math: add algebraic simplification patterns to math transforms

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D106822
This commit is contained in:
Eugene Zhulenev 2021-07-27 09:17:31 -07:00
parent 9b1bcaea4e
commit d94426d22a
7 changed files with 219 additions and 0 deletions

View File

@ -15,6 +15,8 @@ class RewritePatternSet;
void populateExpandTanhPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns);
} // namespace mlir

View File

@ -0,0 +1,112 @@
//===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrites based on the basic rules of algebra
// (Commutativity, associativity, etc...) and strength reductions for math
// operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include <climits>
using namespace mlir;
//----------------------------------------------------------------------------//
// PowFOp strength reduction.
//----------------------------------------------------------------------------//
namespace {
struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(math::PowFOp op,
PatternRewriter &rewriter) const final;
};
} // namespace
LogicalResult
PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value x = op.lhs();
FloatAttr scalarExponent;
DenseFPElementsAttr vectorExponent;
bool isScalar = matchPattern(op.rhs(), m_Constant(&scalarExponent));
bool isVector = matchPattern(op.rhs(), m_Constant(&vectorExponent));
// Returns true if exponent is a constant equal to `value`.
auto isExponentValue = [&](double value) -> bool {
if (isScalar)
return scalarExponent.getValue().isExactlyValue(value);
if (isVector && vectorExponent.isSplat())
return vectorExponent.getSplatValue<FloatAttr>()
.getValue()
.isExactlyValue(value);
return false;
};
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&](Value value) -> Value {
if (auto vec = op.getType().dyn_cast<VectorType>())
return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
return value;
};
// Replace `pow(x, 1.0)` with `x`.
if (isExponentValue(1.0)) {
rewriter.replaceOp(op, x);
return success();
}
// Replace `pow(x, 2.0)` with `x * x`.
if (isExponentValue(2.0)) {
rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, x}));
return success();
}
// Replace `pow(x, 2.0)` with `x * x * x`.
if (isExponentValue(3.0)) {
Value square = rewriter.create<MulFOp>(op.getLoc(), ValueRange({x, x}));
rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, square}));
return success();
}
// Replace `pow(x, -1.0)` with `1.0 / x`.
if (isExponentValue(-1.0)) {
Value one = rewriter.create<ConstantOp>(
loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
rewriter.replaceOpWithNewOp<DivFOp>(op, ValueRange({bcast(one), x}));
return success();
}
// Replace `pow(x, -2.0)` with `sqrt(x)`.
if (isExponentValue(-1.0)) {
rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
return success();
}
return failure();
}
//----------------------------------------------------------------------------//
void mlir::populateMathAlgebraicSimplificationPatterns(
RewritePatternSet &patterns) {
patterns.add<PowFStrengthReduction>(patterns.getContext());
}

View File

@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandTanh.cpp
PolynomialApproximation.cpp

View File

@ -0,0 +1,51 @@
// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s --dump-input=always
// CHECK-LABEL: @pow_noop
func @pow_noop(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK: return %arg0, %arg1
%c = constant 1.0 : f32
%v = constant dense <1.0> : vector<4xf32>
%0 = math.powf %arg0, %c : f32
%1 = math.powf %arg1, %v : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}
// CHECK-LABEL: @pow_square
func @pow_square(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK: %[[SCALAR:.*]] = mulf %arg0, %arg0
// CHECK: %[[VECTOR:.*]] = mulf %arg1, %arg1
// CHECK: return %[[SCALAR]], %[[VECTOR]]
%c = constant 2.0 : f32
%v = constant dense <2.0> : vector<4xf32>
%0 = math.powf %arg0, %c : f32
%1 = math.powf %arg1, %v : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}
// CHECK-LABEL: @pow_cube
func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK: %[[TMP_S:.*]] = mulf %arg0, %arg0
// CHECK: %[[SCALAR:.*]] = mulf %arg0, %[[TMP_S]]
// CHECK: %[[TMP_V:.*]] = mulf %arg1, %arg1
// CHECK: %[[VECTOR:.*]] = mulf %arg1, %[[TMP_V]]
// CHECK: return %[[SCALAR]], %[[VECTOR]]
%c = constant 3.0 : f32
%v = constant dense <3.0> : vector<4xf32>
%0 = math.powf %arg0, %c : f32
%1 = math.powf %arg1, %v : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}
// CHECK-LABEL: @pow_recip
func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK: %[[CST_S:.*]] = constant 1.0{{.*}} : f32
// CHECK: %[[CST_V:.*]] = constant dense<1.0{{.*}}> : vector<4xf32>
// CHECK: %[[SCALAR:.*]] = divf %[[CST_S]], %arg0
// CHECK: %[[VECTOR:.*]] = divf %[[CST_V]], %arg1
// CHECK: return %[[SCALAR]], %[[VECTOR]]
%c = constant -1.0 : f32
%v = constant dense <-1.0> : vector<4xf32>
%0 = math.powf %arg0, %c : f32
%1 = math.powf %arg1, %v : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}

View File

@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMathTestPasses
TestAlgebraicSimplification.cpp
TestExpandTanh.cpp
TestPolynomialApproximation.cpp

View File

@ -0,0 +1,50 @@
//===- TestAlgebraicSimplification.cpp - Test algebraic simplification ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains test passes for algebraic simplification patterns.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
struct TestMathAlgebraicSimplificationPass
: public PassWrapper<TestMathAlgebraicSimplificationPass, FunctionPass> {
void runOnFunction() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect, math::MathDialect>();
}
StringRef getArgument() const final {
return "test-math-algebraic-simplification";
}
StringRef getDescription() const final {
return "Test math algebraic simplification";
}
};
} // end anonymous namespace
void TestMathAlgebraicSimplificationPass::runOnFunction() {
RewritePatternSet patterns(&getContext());
populateMathAlgebraicSimplificationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
namespace mlir {
namespace test {
void registerTestMathAlgebraicSimplificationPass() {
PassRegistration<TestMathAlgebraicSimplificationPass>();
}
} // namespace test
} // namespace mlir

View File

@ -92,6 +92,7 @@ void registerTestLivenessPass();
void registerTestLoopFusion();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
void registerTestMathAlgebraicSimplificationPass();
void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
@ -173,6 +174,7 @@ void registerTestPasses() {
test::registerTestLoopFusion();
test::registerTestLoopMappingPass();
test::registerTestLoopUnrollingPass();
test::registerTestMathAlgebraicSimplificationPass();
test::registerTestMathPolynomialApproximationPass();
test::registerTestMemRefDependenceCheck();
test::registerTestMemRefStrideCalculation();