mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-15 12:09:51 +00:00
[mlir] Allow data flow analysis of non-control flow branch arguments
This commit adds the visitNonControlFlowArguments method to DataFlowAnalysis, allowing analyses to provide lattice values for the arguments to a RegionSuccessor block that aren't directly tied to an op's inputs. For example, integer range interface can use this method to infer bounds for the step values in loops. This method has a default implementation that keeps the old behavior of assigning a pessimistic fixedpoint state to all such arguments. Reviewed By: Mogball, rriddle Differential Revision: https://reviews.llvm.org/D124021
This commit is contained in:
parent
7164c5f051
commit
d35f7f254f
@ -250,6 +250,15 @@ public:
|
||||
ArrayRef<AbstractLatticeElement *> operands,
|
||||
SmallVectorImpl<RegionSuccessor> &successors) = 0;
|
||||
|
||||
/// Given a operation with successor regions, one of those regions,
|
||||
/// and the lattice elements corresponding to the operation's
|
||||
/// arguments, compute the latice values for block arguments
|
||||
/// that are not accounted for by the branching control flow (ex. the
|
||||
/// bounds of loops).
|
||||
virtual ChangeResult
|
||||
visitNonControlFlowArguments(Operation *op, const RegionSuccessor ®ion,
|
||||
ArrayRef<AbstractLatticeElement *> operands) = 0;
|
||||
|
||||
/// Create a new uninitialized lattice element. An optional value is provided
|
||||
/// which, if valid, should be used to initialize the known conservative state
|
||||
/// of the lattice.
|
||||
@ -347,6 +356,33 @@ protected:
|
||||
branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
|
||||
}
|
||||
|
||||
/// Given a operation with successor regions, one of those regions,
|
||||
/// and the lattice elements corresponding to the operation's
|
||||
/// arguments, compute the latice values for block arguments
|
||||
/// that are not accounted for by the branching control flow (ex. the
|
||||
/// bounds of loops). By default, this method marks all such lattice elements
|
||||
/// as having reached a pessimistic fixpoint. The region in the
|
||||
/// RegionSuccessor and the operand latice elements are guaranteed to be
|
||||
/// non-null.
|
||||
virtual ChangeResult
|
||||
visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor,
|
||||
ArrayRef<LatticeElement<ValueT> *> operands) {
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
Region *region = successor.getSuccessor();
|
||||
ValueRange succArgs = successor.getSuccessorInputs();
|
||||
Block *block = ®ion->front();
|
||||
Block::BlockArgListType arguments = block->getArguments();
|
||||
if (arguments.size() != succArgs.size()) {
|
||||
unsigned firstArgIdx =
|
||||
succArgs.empty() ? 0
|
||||
: succArgs[0].cast<BlockArgument>().getArgNumber();
|
||||
result |= markAllPessimisticFixpoint(arguments.take_front(firstArgIdx));
|
||||
result |= markAllPessimisticFixpoint(
|
||||
arguments.drop_front(firstArgIdx + succArgs.size()));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Type-erased wrappers that convert the abstract lattice operands to derived
|
||||
/// lattices and invoke the virtual hooks operating on the derived lattices.
|
||||
@ -379,6 +415,14 @@ private:
|
||||
branch, sourceIndex,
|
||||
llvm::makeArrayRef(derivedOperandBase, operands.size()), successors);
|
||||
}
|
||||
ChangeResult visitNonControlFlowArguments(
|
||||
Operation *op, const RegionSuccessor ®ion,
|
||||
ArrayRef<detail::AbstractLatticeElement *> operands) final {
|
||||
LatticeElement<ValueT> *const *derivedOperandBase =
|
||||
reinterpret_cast<LatticeElement<ValueT> *const *>(operands.data());
|
||||
return visitNonControlFlowArguments(
|
||||
op, region, llvm::makeArrayRef(derivedOperandBase, operands.size()));
|
||||
}
|
||||
|
||||
/// Create a new uninitialized lattice element. An optional value is provided,
|
||||
/// which if valid, should be used to initialize the known conservative state
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
#include <queue>
|
||||
@ -113,6 +114,7 @@ private:
|
||||
/// the parent operation results.
|
||||
void visitRegionSuccessors(
|
||||
Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
|
||||
ArrayRef<AbstractLatticeElement *> operandLattices,
|
||||
function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
|
||||
|
||||
/// Visit the given terminator operation and compute any necessary lattice
|
||||
@ -460,7 +462,7 @@ void ForwardDataFlowSolver::visitRegionBranchOperation(
|
||||
if (successors.empty())
|
||||
return markAllPessimisticFixpoint(branch, branch->getResults());
|
||||
return visitRegionSuccessors(
|
||||
branch, successors, [&](Optional<unsigned> index) {
|
||||
branch, successors, operandLattices, [&](Optional<unsigned> index) {
|
||||
assert(index && "expected valid region index");
|
||||
return branch.getSuccessorEntryOperands(*index);
|
||||
});
|
||||
@ -468,6 +470,7 @@ void ForwardDataFlowSolver::visitRegionBranchOperation(
|
||||
|
||||
void ForwardDataFlowSolver::visitRegionSuccessors(
|
||||
Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
|
||||
ArrayRef<AbstractLatticeElement *> operandLattices,
|
||||
function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
|
||||
for (const RegionSuccessor &it : regionSuccessors) {
|
||||
Region *region = it.getSuccessor();
|
||||
@ -514,22 +517,25 @@ void ForwardDataFlowSolver::visitRegionSuccessors(
|
||||
if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); }))
|
||||
continue;
|
||||
|
||||
// Mark any arguments that do not receive inputs as having reached a
|
||||
// pessimistic fixpoint, we won't be able to discern if they are constant.
|
||||
// TODO: This isn't exactly ideal. There may be situations in which a
|
||||
// region operation can provide information for certain results that
|
||||
// aren't part of the control flow.
|
||||
if (succArgs.size() != arguments.size()) {
|
||||
if (succArgs.empty()) {
|
||||
markAllPessimisticFixpoint(arguments);
|
||||
continue;
|
||||
if (analysis.visitNonControlFlowArguments(
|
||||
parentOp, it, operandLattices) == ChangeResult::Change) {
|
||||
unsigned firstArgIdx =
|
||||
succArgs.empty() ? 0
|
||||
: succArgs[0].cast<BlockArgument>().getArgNumber();
|
||||
for (Value v : arguments.take_front(firstArgIdx)) {
|
||||
assert(!analysis.getLatticeElement(v).isUninitialized() &&
|
||||
"Non-control flow block arg has no lattice value after "
|
||||
"analysis callback");
|
||||
visitUsers(v);
|
||||
}
|
||||
for (Value v : arguments.drop_front(firstArgIdx + succArgs.size())) {
|
||||
assert(!analysis.getLatticeElement(v).isUninitialized() &&
|
||||
"Non-control flow block arg has no lattice value after "
|
||||
"analysis callback");
|
||||
visitUsers(v);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
|
||||
markAllPessimisticFixpointAndVisitUsers(
|
||||
arguments.take_front(firstArgIdx));
|
||||
markAllPessimisticFixpointAndVisitUsers(
|
||||
arguments.drop_front(firstArgIdx + succArgs.size()));
|
||||
}
|
||||
|
||||
// Update the lattice of arguments that have inputs from the predecessor.
|
||||
@ -573,12 +579,13 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
|
||||
// Try to get "region-like" successor operands if possible in order to
|
||||
// propagate the operand states to the successors.
|
||||
if (isRegionReturnLike(op)) {
|
||||
return visitRegionSuccessors(
|
||||
parentOp, regionSuccessors, [&](Optional<unsigned> regionIndex) {
|
||||
// Determine the individual region successor operands for the given
|
||||
// region index (if any).
|
||||
return *getRegionBranchSuccessorOperands(op, regionIndex);
|
||||
});
|
||||
auto getOperands = [&](Optional<unsigned> regionIndex) {
|
||||
// Determine the individual region successor operands for the given
|
||||
// region index (if any).
|
||||
return *getRegionBranchSuccessorOperands(op, regionIndex);
|
||||
};
|
||||
return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices,
|
||||
getOperands);
|
||||
}
|
||||
|
||||
// If this terminator is not "region-like", conservatively mark all of the
|
||||
|
24
mlir/test/Analysis/test-data-flow.mlir
Normal file
24
mlir/test/Analysis/test-data-flow.mlir
Normal file
@ -0,0 +1,24 @@
|
||||
// RUN: mlir-opt -test-data-flow --allow-unregistered-dialect %s 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: Testing : "loop-arg-pessimistic"
|
||||
module attributes {test.name = "loop-arg-pessimistic"} {
|
||||
func @f() -> index {
|
||||
// CHECK: Visiting : %{{.*}} = arith.constant 0
|
||||
// CHECK-NEXT: Result 0 moved from uninitialized to 1
|
||||
%c0 = arith.constant 0 : index
|
||||
// CHECK: Visiting : %{{.*}} = arith.constant 1
|
||||
// CHECK-NEXT: Result 0 moved from uninitialized to 1
|
||||
%c1 = arith.constant 1 : index
|
||||
// CHECK: Visiting region branch op : %{{.*}} = scf.for
|
||||
// CHECK: Block argument 0 moved from uninitialized to 1
|
||||
%0 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %c0) -> index {
|
||||
// CHECK: Visiting : %{{.*}} = arith.addi %{{.*}}, %{{.*}}
|
||||
// CHECK-NEXT: Arg 0 : 1
|
||||
// CHECK-NEXT: Arg 1 : 1
|
||||
// CHECK-NEXT: Result 0 moved from uninitialized to 1
|
||||
%10 = arith.addi %arg1, %arg2 : index
|
||||
scf.yield %10 : index
|
||||
}
|
||||
return %0 : index
|
||||
}
|
||||
}
|
@ -2,6 +2,7 @@
|
||||
add_mlir_library(MLIRTestAnalysis
|
||||
TestAliasAnalysis.cpp
|
||||
TestCallGraph.cpp
|
||||
TestDataFlow.cpp
|
||||
TestLiveness.cpp
|
||||
TestMatchReduction.cpp
|
||||
TestMemRefBoundCheck.cpp
|
||||
|
127
mlir/test/lib/Analysis/TestDataFlow.cpp
Normal file
127
mlir/test/lib/Analysis/TestDataFlow.cpp
Normal file
@ -0,0 +1,127 @@
|
||||
//===- TestDataFlow.cpp - Test data flow analysis system -------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains test passes for defining and running a dataflow analysis.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct WasAnalyzed {
|
||||
WasAnalyzed(bool wasAnalyzed) : wasAnalyzed(wasAnalyzed) {}
|
||||
|
||||
static WasAnalyzed join(const WasAnalyzed &a, const WasAnalyzed &b) {
|
||||
return a.wasAnalyzed && b.wasAnalyzed;
|
||||
}
|
||||
|
||||
static WasAnalyzed getPessimisticValueState(MLIRContext *context) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static WasAnalyzed getPessimisticValueState(Value v) {
|
||||
return getPessimisticValueState(v.getContext());
|
||||
}
|
||||
|
||||
bool operator==(const WasAnalyzed &other) const {
|
||||
return wasAnalyzed == other.wasAnalyzed;
|
||||
}
|
||||
|
||||
bool wasAnalyzed;
|
||||
};
|
||||
|
||||
struct TestAnalysis : public ForwardDataFlowAnalysis<WasAnalyzed> {
|
||||
using ForwardDataFlowAnalysis<WasAnalyzed>::ForwardDataFlowAnalysis;
|
||||
|
||||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<WasAnalyzed> *> operands) final {
|
||||
ChangeResult ret = ChangeResult::NoChange;
|
||||
llvm::errs() << "Visiting : ";
|
||||
op->print(llvm::errs());
|
||||
llvm::errs() << "\n";
|
||||
|
||||
WasAnalyzed result(true);
|
||||
for (auto &pair : llvm::enumerate(operands)) {
|
||||
LatticeElement<WasAnalyzed> *elem = pair.value();
|
||||
llvm::errs() << "Arg " << pair.index();
|
||||
if (!elem->isUninitialized()) {
|
||||
llvm::errs() << " : " << elem->getValue().wasAnalyzed << "\n";
|
||||
result = WasAnalyzed::join(result, elem->getValue());
|
||||
} else {
|
||||
llvm::errs() << " uninitialized\n";
|
||||
}
|
||||
}
|
||||
for (const auto &pair : llvm::enumerate(op->getResults())) {
|
||||
LatticeElement<WasAnalyzed> &lattice = getLatticeElement(pair.value());
|
||||
llvm::errs() << "Result " << pair.index() << " moved from ";
|
||||
if (lattice.isUninitialized())
|
||||
llvm::errs() << "uninitialized";
|
||||
else
|
||||
llvm::errs() << lattice.getValue().wasAnalyzed;
|
||||
ret |= lattice.join({result});
|
||||
llvm::errs() << " to " << lattice.getValue().wasAnalyzed << "\n";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
ChangeResult visitNonControlFlowArguments(
|
||||
Operation *op, const RegionSuccessor &successor,
|
||||
ArrayRef<LatticeElement<WasAnalyzed> *> operands) final {
|
||||
ChangeResult ret = ChangeResult::NoChange;
|
||||
llvm::errs() << "Visiting region branch op : ";
|
||||
op->print(llvm::errs());
|
||||
llvm::errs() << "\n";
|
||||
|
||||
Region *region = successor.getSuccessor();
|
||||
Block *block = ®ion->front();
|
||||
Block::BlockArgListType arguments = block->getArguments();
|
||||
// Mark all arguments to blocks as analyzed unless they already have
|
||||
// an unanalyzed state.
|
||||
for (const auto &pair : llvm::enumerate(arguments)) {
|
||||
LatticeElement<WasAnalyzed> &lattice = getLatticeElement(pair.value());
|
||||
llvm::errs() << "Block argument " << pair.index() << " moved from ";
|
||||
if (lattice.isUninitialized())
|
||||
llvm::errs() << "uninitialized";
|
||||
else
|
||||
llvm::errs() << lattice.getValue().wasAnalyzed;
|
||||
ret |= lattice.join({true});
|
||||
llvm::errs() << " to " << lattice.getValue().wasAnalyzed << "\n";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct TestDataFlowPass
|
||||
: public PassWrapper<TestDataFlowPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataFlowPass)
|
||||
|
||||
StringRef getArgument() const final { return "test-data-flow"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Print the actions taken during a dataflow analysis.";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
llvm::errs() << "Testing : " << getOperation()->getAttr("test.name")
|
||||
<< "\n";
|
||||
TestAnalysis analysis(getOperation().getContext());
|
||||
analysis.run(getOperation());
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestDataFlowPass() { PassRegistration<TestDataFlowPass>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
@ -70,6 +70,7 @@ void registerTestConstantFold();
|
||||
void registerTestControlFlowSink();
|
||||
void registerTestGpuSerializeToCubinPass();
|
||||
void registerTestGpuSerializeToHsacoPass();
|
||||
void registerTestDataFlowPass();
|
||||
void registerTestDataLayoutQuery();
|
||||
void registerTestDecomposeCallGraphTypes();
|
||||
void registerTestDiagnosticsPass();
|
||||
@ -167,6 +168,7 @@ void registerTestPasses() {
|
||||
mlir::test::registerTestGpuSerializeToHsacoPass();
|
||||
#endif
|
||||
mlir::test::registerTestDecomposeCallGraphTypes();
|
||||
mlir::test::registerTestDataFlowPass();
|
||||
mlir::test::registerTestDataLayoutQuery();
|
||||
mlir::test::registerTestDominancePass();
|
||||
mlir::test::registerTestDynamicPipelinePass();
|
||||
|
Loading…
Reference in New Issue
Block a user