[mlir] Allow data flow analysis of non-control flow branch arguments

This commit adds the visitNonControlFlowArguments method to
DataFlowAnalysis, allowing analyses to provide lattice values for the
arguments to a RegionSuccessor block that aren't directly tied to an
op's inputs. For example, integer range interface can use this method
to infer bounds for the step values in loops.

This method has a default implementation that keeps the old behavior
of assigning a pessimistic fixedpoint state to all such arguments.

Reviewed By: Mogball, rriddle

Differential Revision: https://reviews.llvm.org/D124021
This commit is contained in:
Krzysztof Drewniak 2022-04-19 17:10:31 +00:00
parent 7164c5f051
commit d35f7f254f
6 changed files with 226 additions and 21 deletions

View File

@ -250,6 +250,15 @@ public:
ArrayRef<AbstractLatticeElement *> operands,
SmallVectorImpl<RegionSuccessor> &successors) = 0;
/// Given a operation with successor regions, one of those regions,
/// and the lattice elements corresponding to the operation's
/// arguments, compute the latice values for block arguments
/// that are not accounted for by the branching control flow (ex. the
/// bounds of loops).
virtual ChangeResult
visitNonControlFlowArguments(Operation *op, const RegionSuccessor &region,
ArrayRef<AbstractLatticeElement *> operands) = 0;
/// Create a new uninitialized lattice element. An optional value is provided
/// which, if valid, should be used to initialize the known conservative state
/// of the lattice.
@ -347,6 +356,33 @@ protected:
branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
}
/// Given a operation with successor regions, one of those regions,
/// and the lattice elements corresponding to the operation's
/// arguments, compute the latice values for block arguments
/// that are not accounted for by the branching control flow (ex. the
/// bounds of loops). By default, this method marks all such lattice elements
/// as having reached a pessimistic fixpoint. The region in the
/// RegionSuccessor and the operand latice elements are guaranteed to be
/// non-null.
virtual ChangeResult
visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor,
ArrayRef<LatticeElement<ValueT> *> operands) {
ChangeResult result = ChangeResult::NoChange;
Region *region = successor.getSuccessor();
ValueRange succArgs = successor.getSuccessorInputs();
Block *block = &region->front();
Block::BlockArgListType arguments = block->getArguments();
if (arguments.size() != succArgs.size()) {
unsigned firstArgIdx =
succArgs.empty() ? 0
: succArgs[0].cast<BlockArgument>().getArgNumber();
result |= markAllPessimisticFixpoint(arguments.take_front(firstArgIdx));
result |= markAllPessimisticFixpoint(
arguments.drop_front(firstArgIdx + succArgs.size()));
}
return result;
}
private:
/// Type-erased wrappers that convert the abstract lattice operands to derived
/// lattices and invoke the virtual hooks operating on the derived lattices.
@ -379,6 +415,14 @@ private:
branch, sourceIndex,
llvm::makeArrayRef(derivedOperandBase, operands.size()), successors);
}
ChangeResult visitNonControlFlowArguments(
Operation *op, const RegionSuccessor &region,
ArrayRef<detail::AbstractLatticeElement *> operands) final {
LatticeElement<ValueT> *const *derivedOperandBase =
reinterpret_cast<LatticeElement<ValueT> *const *>(operands.data());
return visitNonControlFlowArguments(
op, region, llvm::makeArrayRef(derivedOperandBase, operands.size()));
}
/// Create a new uninitialized lattice element. An optional value is provided,
/// which if valid, should be used to initialize the known conservative state

View File

@ -10,6 +10,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include <queue>
@ -113,6 +114,7 @@ private:
/// the parent operation results.
void visitRegionSuccessors(
Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
ArrayRef<AbstractLatticeElement *> operandLattices,
function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
/// Visit the given terminator operation and compute any necessary lattice
@ -460,7 +462,7 @@ void ForwardDataFlowSolver::visitRegionBranchOperation(
if (successors.empty())
return markAllPessimisticFixpoint(branch, branch->getResults());
return visitRegionSuccessors(
branch, successors, [&](Optional<unsigned> index) {
branch, successors, operandLattices, [&](Optional<unsigned> index) {
assert(index && "expected valid region index");
return branch.getSuccessorEntryOperands(*index);
});
@ -468,6 +470,7 @@ void ForwardDataFlowSolver::visitRegionBranchOperation(
void ForwardDataFlowSolver::visitRegionSuccessors(
Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
ArrayRef<AbstractLatticeElement *> operandLattices,
function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
for (const RegionSuccessor &it : regionSuccessors) {
Region *region = it.getSuccessor();
@ -514,22 +517,25 @@ void ForwardDataFlowSolver::visitRegionSuccessors(
if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); }))
continue;
// Mark any arguments that do not receive inputs as having reached a
// pessimistic fixpoint, we won't be able to discern if they are constant.
// TODO: This isn't exactly ideal. There may be situations in which a
// region operation can provide information for certain results that
// aren't part of the control flow.
if (succArgs.size() != arguments.size()) {
if (succArgs.empty()) {
markAllPessimisticFixpoint(arguments);
continue;
if (analysis.visitNonControlFlowArguments(
parentOp, it, operandLattices) == ChangeResult::Change) {
unsigned firstArgIdx =
succArgs.empty() ? 0
: succArgs[0].cast<BlockArgument>().getArgNumber();
for (Value v : arguments.take_front(firstArgIdx)) {
assert(!analysis.getLatticeElement(v).isUninitialized() &&
"Non-control flow block arg has no lattice value after "
"analysis callback");
visitUsers(v);
}
for (Value v : arguments.drop_front(firstArgIdx + succArgs.size())) {
assert(!analysis.getLatticeElement(v).isUninitialized() &&
"Non-control flow block arg has no lattice value after "
"analysis callback");
visitUsers(v);
}
}
unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
markAllPessimisticFixpointAndVisitUsers(
arguments.take_front(firstArgIdx));
markAllPessimisticFixpointAndVisitUsers(
arguments.drop_front(firstArgIdx + succArgs.size()));
}
// Update the lattice of arguments that have inputs from the predecessor.
@ -573,12 +579,13 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
// Try to get "region-like" successor operands if possible in order to
// propagate the operand states to the successors.
if (isRegionReturnLike(op)) {
return visitRegionSuccessors(
parentOp, regionSuccessors, [&](Optional<unsigned> regionIndex) {
// Determine the individual region successor operands for the given
// region index (if any).
return *getRegionBranchSuccessorOperands(op, regionIndex);
});
auto getOperands = [&](Optional<unsigned> regionIndex) {
// Determine the individual region successor operands for the given
// region index (if any).
return *getRegionBranchSuccessorOperands(op, regionIndex);
};
return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices,
getOperands);
}
// If this terminator is not "region-like", conservatively mark all of the

View File

@ -0,0 +1,24 @@
// RUN: mlir-opt -test-data-flow --allow-unregistered-dialect %s 2>&1 | FileCheck %s
// CHECK-LABEL: Testing : "loop-arg-pessimistic"
module attributes {test.name = "loop-arg-pessimistic"} {
func @f() -> index {
// CHECK: Visiting : %{{.*}} = arith.constant 0
// CHECK-NEXT: Result 0 moved from uninitialized to 1
%c0 = arith.constant 0 : index
// CHECK: Visiting : %{{.*}} = arith.constant 1
// CHECK-NEXT: Result 0 moved from uninitialized to 1
%c1 = arith.constant 1 : index
// CHECK: Visiting region branch op : %{{.*}} = scf.for
// CHECK: Block argument 0 moved from uninitialized to 1
%0 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %c0) -> index {
// CHECK: Visiting : %{{.*}} = arith.addi %{{.*}}, %{{.*}}
// CHECK-NEXT: Arg 0 : 1
// CHECK-NEXT: Arg 1 : 1
// CHECK-NEXT: Result 0 moved from uninitialized to 1
%10 = arith.addi %arg1, %arg2 : index
scf.yield %10 : index
}
return %0 : index
}
}

View File

@ -2,6 +2,7 @@
add_mlir_library(MLIRTestAnalysis
TestAliasAnalysis.cpp
TestCallGraph.cpp
TestDataFlow.cpp
TestLiveness.cpp
TestMatchReduction.cpp
TestMemRefBoundCheck.cpp

View File

@ -0,0 +1,127 @@
//===- TestDataFlow.cpp - Test data flow analysis system -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains test passes for defining and running a dataflow analysis.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
namespace {
struct WasAnalyzed {
WasAnalyzed(bool wasAnalyzed) : wasAnalyzed(wasAnalyzed) {}
static WasAnalyzed join(const WasAnalyzed &a, const WasAnalyzed &b) {
return a.wasAnalyzed && b.wasAnalyzed;
}
static WasAnalyzed getPessimisticValueState(MLIRContext *context) {
return false;
}
static WasAnalyzed getPessimisticValueState(Value v) {
return getPessimisticValueState(v.getContext());
}
bool operator==(const WasAnalyzed &other) const {
return wasAnalyzed == other.wasAnalyzed;
}
bool wasAnalyzed;
};
struct TestAnalysis : public ForwardDataFlowAnalysis<WasAnalyzed> {
using ForwardDataFlowAnalysis<WasAnalyzed>::ForwardDataFlowAnalysis;
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<WasAnalyzed> *> operands) final {
ChangeResult ret = ChangeResult::NoChange;
llvm::errs() << "Visiting : ";
op->print(llvm::errs());
llvm::errs() << "\n";
WasAnalyzed result(true);
for (auto &pair : llvm::enumerate(operands)) {
LatticeElement<WasAnalyzed> *elem = pair.value();
llvm::errs() << "Arg " << pair.index();
if (!elem->isUninitialized()) {
llvm::errs() << " : " << elem->getValue().wasAnalyzed << "\n";
result = WasAnalyzed::join(result, elem->getValue());
} else {
llvm::errs() << " uninitialized\n";
}
}
for (const auto &pair : llvm::enumerate(op->getResults())) {
LatticeElement<WasAnalyzed> &lattice = getLatticeElement(pair.value());
llvm::errs() << "Result " << pair.index() << " moved from ";
if (lattice.isUninitialized())
llvm::errs() << "uninitialized";
else
llvm::errs() << lattice.getValue().wasAnalyzed;
ret |= lattice.join({result});
llvm::errs() << " to " << lattice.getValue().wasAnalyzed << "\n";
}
return ret;
}
ChangeResult visitNonControlFlowArguments(
Operation *op, const RegionSuccessor &successor,
ArrayRef<LatticeElement<WasAnalyzed> *> operands) final {
ChangeResult ret = ChangeResult::NoChange;
llvm::errs() << "Visiting region branch op : ";
op->print(llvm::errs());
llvm::errs() << "\n";
Region *region = successor.getSuccessor();
Block *block = &region->front();
Block::BlockArgListType arguments = block->getArguments();
// Mark all arguments to blocks as analyzed unless they already have
// an unanalyzed state.
for (const auto &pair : llvm::enumerate(arguments)) {
LatticeElement<WasAnalyzed> &lattice = getLatticeElement(pair.value());
llvm::errs() << "Block argument " << pair.index() << " moved from ";
if (lattice.isUninitialized())
llvm::errs() << "uninitialized";
else
llvm::errs() << lattice.getValue().wasAnalyzed;
ret |= lattice.join({true});
llvm::errs() << " to " << lattice.getValue().wasAnalyzed << "\n";
}
return ret;
}
};
struct TestDataFlowPass
: public PassWrapper<TestDataFlowPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataFlowPass)
StringRef getArgument() const final { return "test-data-flow"; }
StringRef getDescription() const final {
return "Print the actions taken during a dataflow analysis.";
}
void runOnOperation() override {
llvm::errs() << "Testing : " << getOperation()->getAttr("test.name")
<< "\n";
TestAnalysis analysis(getOperation().getContext());
analysis.run(getOperation());
}
};
} // namespace
namespace mlir {
namespace test {
void registerTestDataFlowPass() { PassRegistration<TestDataFlowPass>(); }
} // namespace test
} // namespace mlir

View File

@ -70,6 +70,7 @@ void registerTestConstantFold();
void registerTestControlFlowSink();
void registerTestGpuSerializeToCubinPass();
void registerTestGpuSerializeToHsacoPass();
void registerTestDataFlowPass();
void registerTestDataLayoutQuery();
void registerTestDecomposeCallGraphTypes();
void registerTestDiagnosticsPass();
@ -167,6 +168,7 @@ void registerTestPasses() {
mlir::test::registerTestGpuSerializeToHsacoPass();
#endif
mlir::test::registerTestDecomposeCallGraphTypes();
mlir::test::registerTestDataFlowPass();
mlir::test::registerTestDataLayoutQuery();
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();