mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-23 13:50:11 +00:00
[mlir] Math: add algebraic simplification patterns to math transforms
Reviewed By: bkramer Differential Revision: https://reviews.llvm.org/D106822
This commit is contained in:
parent
9b1bcaea4e
commit
d94426d22a
@ -15,6 +15,8 @@ class RewritePatternSet;
|
||||
|
||||
void populateExpandTanhPattern(RewritePatternSet &patterns);
|
||||
|
||||
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
|
||||
|
||||
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
112
mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
Normal file
112
mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
Normal file
@ -0,0 +1,112 @@
|
||||
//===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements rewrites based on the basic rules of algebra
|
||||
// (Commutativity, associativity, etc...) and strength reductions for math
|
||||
// operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include <climits>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// PowFOp strength reduction.
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
namespace {
|
||||
struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(math::PowFOp op,
|
||||
PatternRewriter &rewriter) const final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
Value x = op.lhs();
|
||||
|
||||
FloatAttr scalarExponent;
|
||||
DenseFPElementsAttr vectorExponent;
|
||||
|
||||
bool isScalar = matchPattern(op.rhs(), m_Constant(&scalarExponent));
|
||||
bool isVector = matchPattern(op.rhs(), m_Constant(&vectorExponent));
|
||||
|
||||
// Returns true if exponent is a constant equal to `value`.
|
||||
auto isExponentValue = [&](double value) -> bool {
|
||||
if (isScalar)
|
||||
return scalarExponent.getValue().isExactlyValue(value);
|
||||
|
||||
if (isVector && vectorExponent.isSplat())
|
||||
return vectorExponent.getSplatValue<FloatAttr>()
|
||||
.getValue()
|
||||
.isExactlyValue(value);
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
// Maybe broadcasts scalar value into vector type compatible with `op`.
|
||||
auto bcast = [&](Value value) -> Value {
|
||||
if (auto vec = op.getType().dyn_cast<VectorType>())
|
||||
return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
|
||||
return value;
|
||||
};
|
||||
|
||||
// Replace `pow(x, 1.0)` with `x`.
|
||||
if (isExponentValue(1.0)) {
|
||||
rewriter.replaceOp(op, x);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Replace `pow(x, 2.0)` with `x * x`.
|
||||
if (isExponentValue(2.0)) {
|
||||
rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, x}));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Replace `pow(x, 2.0)` with `x * x * x`.
|
||||
if (isExponentValue(3.0)) {
|
||||
Value square = rewriter.create<MulFOp>(op.getLoc(), ValueRange({x, x}));
|
||||
rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, square}));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Replace `pow(x, -1.0)` with `1.0 / x`.
|
||||
if (isExponentValue(-1.0)) {
|
||||
Value one = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
|
||||
rewriter.replaceOpWithNewOp<DivFOp>(op, ValueRange({bcast(one), x}));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Replace `pow(x, -2.0)` with `sqrt(x)`.
|
||||
if (isExponentValue(-1.0)) {
|
||||
rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
void mlir::populateMathAlgebraicSimplificationPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<PowFStrengthReduction>(patterns.getContext());
|
||||
}
|
@ -1,4 +1,5 @@
|
||||
add_mlir_dialect_library(MLIRMathTransforms
|
||||
AlgebraicSimplification.cpp
|
||||
ExpandTanh.cpp
|
||||
PolynomialApproximation.cpp
|
||||
|
||||
|
51
mlir/test/Dialect/Math/algebraic-simplification.mlir
Normal file
51
mlir/test/Dialect/Math/algebraic-simplification.mlir
Normal file
@ -0,0 +1,51 @@
|
||||
// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: @pow_noop
|
||||
func @pow_noop(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
|
||||
// CHECK: return %arg0, %arg1
|
||||
%c = constant 1.0 : f32
|
||||
%v = constant dense <1.0> : vector<4xf32>
|
||||
%0 = math.powf %arg0, %c : f32
|
||||
%1 = math.powf %arg1, %v : vector<4xf32>
|
||||
return %0, %1 : f32, vector<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @pow_square
|
||||
func @pow_square(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
|
||||
// CHECK: %[[SCALAR:.*]] = mulf %arg0, %arg0
|
||||
// CHECK: %[[VECTOR:.*]] = mulf %arg1, %arg1
|
||||
// CHECK: return %[[SCALAR]], %[[VECTOR]]
|
||||
%c = constant 2.0 : f32
|
||||
%v = constant dense <2.0> : vector<4xf32>
|
||||
%0 = math.powf %arg0, %c : f32
|
||||
%1 = math.powf %arg1, %v : vector<4xf32>
|
||||
return %0, %1 : f32, vector<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @pow_cube
|
||||
func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
|
||||
// CHECK: %[[TMP_S:.*]] = mulf %arg0, %arg0
|
||||
// CHECK: %[[SCALAR:.*]] = mulf %arg0, %[[TMP_S]]
|
||||
// CHECK: %[[TMP_V:.*]] = mulf %arg1, %arg1
|
||||
// CHECK: %[[VECTOR:.*]] = mulf %arg1, %[[TMP_V]]
|
||||
// CHECK: return %[[SCALAR]], %[[VECTOR]]
|
||||
%c = constant 3.0 : f32
|
||||
%v = constant dense <3.0> : vector<4xf32>
|
||||
%0 = math.powf %arg0, %c : f32
|
||||
%1 = math.powf %arg1, %v : vector<4xf32>
|
||||
return %0, %1 : f32, vector<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @pow_recip
|
||||
func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
|
||||
// CHECK: %[[CST_S:.*]] = constant 1.0{{.*}} : f32
|
||||
// CHECK: %[[CST_V:.*]] = constant dense<1.0{{.*}}> : vector<4xf32>
|
||||
// CHECK: %[[SCALAR:.*]] = divf %[[CST_S]], %arg0
|
||||
// CHECK: %[[VECTOR:.*]] = divf %[[CST_V]], %arg1
|
||||
// CHECK: return %[[SCALAR]], %[[VECTOR]]
|
||||
%c = constant -1.0 : f32
|
||||
%v = constant dense <-1.0> : vector<4xf32>
|
||||
%0 = math.powf %arg0, %c : f32
|
||||
%1 = math.powf %arg1, %v : vector<4xf32>
|
||||
return %0, %1 : f32, vector<4xf32>
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
# Exclude tests from libMLIR.so
|
||||
add_mlir_library(MLIRMathTestPasses
|
||||
TestAlgebraicSimplification.cpp
|
||||
TestExpandTanh.cpp
|
||||
TestPolynomialApproximation.cpp
|
||||
|
||||
|
50
mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp
Normal file
50
mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp
Normal file
@ -0,0 +1,50 @@
|
||||
//===- TestAlgebraicSimplification.cpp - Test algebraic simplification ----===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains test passes for algebraic simplification patterns.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct TestMathAlgebraicSimplificationPass
|
||||
: public PassWrapper<TestMathAlgebraicSimplificationPass, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<vector::VectorDialect, math::MathDialect>();
|
||||
}
|
||||
StringRef getArgument() const final {
|
||||
return "test-math-algebraic-simplification";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test math algebraic simplification";
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void TestMathAlgebraicSimplificationPass::runOnFunction() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateMathAlgebraicSimplificationPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestMathAlgebraicSimplificationPass() {
|
||||
PassRegistration<TestMathAlgebraicSimplificationPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
@ -92,6 +92,7 @@ void registerTestLivenessPass();
|
||||
void registerTestLoopFusion();
|
||||
void registerTestLoopMappingPass();
|
||||
void registerTestLoopUnrollingPass();
|
||||
void registerTestMathAlgebraicSimplificationPass();
|
||||
void registerTestMathPolynomialApproximationPass();
|
||||
void registerTestMemRefDependenceCheck();
|
||||
void registerTestMemRefStrideCalculation();
|
||||
@ -173,6 +174,7 @@ void registerTestPasses() {
|
||||
test::registerTestLoopFusion();
|
||||
test::registerTestLoopMappingPass();
|
||||
test::registerTestLoopUnrollingPass();
|
||||
test::registerTestMathAlgebraicSimplificationPass();
|
||||
test::registerTestMathPolynomialApproximationPass();
|
||||
test::registerTestMemRefDependenceCheck();
|
||||
test::registerTestMemRefStrideCalculation();
|
||||
|
Loading…
Reference in New Issue
Block a user