llvm-mirror/unittests/Analysis/SparsePropagation.cpp
Matthew Simpson b85949d044 [SparsePropagation] Enable interprocedural analysis
This patch adds the ability to perform IPSCCP-like interprocedural analysis to
the generic sparse propagation solver. The patch gives clients the ability to
define their own custom LatticeKey types that the generic solver maps to custom
LatticeVal types. The custom lattice keys can be used, for example, to
distinguish among mappings for regular values, values returned from functions,
and values stored in global variables. Clients are responsible for defining how
to convert between LatticeKeys and LLVM Values by providing a specialization of
the LatticeKeyInfo template.

The added unit tests demonstrate how the generic solver can be used to perform
a simplified version of interprocedural constant propagation.

Differential Revision: https://reviews.llvm.org/D37353

llvm-svn: 315919
2017-10-16 17:44:17 +00:00

545 lines
21 KiB
C++

//===- SparsePropagation.cpp - Unit tests for the generic solver ----------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/SparsePropagation.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/IRBuilder.h"
#include "gtest/gtest.h"
using namespace llvm;
namespace {
/// To enable interprocedural analysis, we assign LLVM values to the following
/// groups. The register group represents SSA registers, the return group
/// represents the return values of functions, and the memory group represents
/// in-memory values. An LLVM Value can technically be in more than one group.
/// It's necessary to distinguish these groups so we can, for example, track a
/// global variable separately from the value stored at its location.
enum class IPOGrouping { Register, Return, Memory };
/// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings.
/// The PointerIntPair header provides a DenseMapInfo specialization, so using
/// these as LatticeKeys is fine.
using TestLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>;
} // namespace
namespace llvm {
/// A specialization of LatticeKeyInfo for TestLatticeKeys. The generic solver
/// must translate between LatticeKeys and LLVM Values when adding Values to
/// its work list and inspecting the state of control-flow related values.
template <> struct LatticeKeyInfo<TestLatticeKey> {
static inline Value *getValueFromLatticeKey(TestLatticeKey Key) {
return Key.getPointer();
}
static inline TestLatticeKey getLatticeKeyFromValue(Value *V) {
return TestLatticeKey(V, IPOGrouping::Register);
}
};
} // namespace llvm
namespace {
/// This class defines a simple test lattice value that could be used for
/// solving problems similar to constant propagation. The value is maintained
/// as a PointerIntPair.
class TestLatticeVal {
public:
/// The states of the lattices value. Only the ConstantVal state is
/// interesting; the rest are special states used by the generic solver. The
/// UntrackedVal state differs from the other three in that the generic
/// solver uses it to avoid doing unnecessary work. In particular, when a
/// value moves to the UntrackedVal state, it's users are not notified.
enum TestLatticeStateTy {
UndefinedVal,
ConstantVal,
OverdefinedVal,
UntrackedVal
};
TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {}
TestLatticeVal(Constant *C, TestLatticeStateTy State)
: LatticeVal(C, State) {}
/// Return true if this lattice value is in the Constant state. This is used
/// for checking the solver results.
bool isConstant() const { return LatticeVal.getInt() == ConstantVal; }
/// Return true if this lattice value is in the Overdefined state. This is
/// used for checking the solver results.
bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; }
bool operator==(const TestLatticeVal &RHS) const {
return LatticeVal == RHS.LatticeVal;
}
bool operator!=(const TestLatticeVal &RHS) const {
return LatticeVal != RHS.LatticeVal;
}
private:
/// A simple lattice value type for problems similar to constant propagation.
/// It holds the constant value and the lattice state.
PointerIntPair<const Constant *, 2, TestLatticeStateTy> LatticeVal;
};
/// This class defines a simple test lattice function that could be used for
/// solving problems similar to constant propagation. The test lattice differs
/// from a "real" lattice in a few ways. First, it initializes all return
/// values, values stored in global variables, and arguments in the undefined
/// state. This means that there are no limitations on what we can track
/// interprocedurally. For simplicity, all global values in the tests will be
/// given internal linkage, since this is not something this lattice function
/// tracks. Second, it only handles the few instructions necessary for the
/// tests.
class TestLatticeFunc
: public AbstractLatticeFunction<TestLatticeKey, TestLatticeVal> {
public:
/// Construct a new test lattice function with special values for the
/// Undefined, Overdefined, and Untracked states.
TestLatticeFunc()
: AbstractLatticeFunction(
TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal),
TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal),
TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {}
/// Compute and return a TestLatticeVal for the given TestLatticeKey. For the
/// test analysis, a LatticeKey will begin in the undefined state, unless it
/// represents an LLVM Constant in the register grouping.
TestLatticeVal ComputeLatticeVal(TestLatticeKey Key) override {
if (Key.getInt() == IPOGrouping::Register)
if (auto *C = dyn_cast<Constant>(Key.getPointer()))
return TestLatticeVal(C, TestLatticeVal::ConstantVal);
return getUndefVal();
}
/// Merge the two given lattice values. This merge should be equivalent to
/// what is done for constant propagation. That is, the resulting lattice
/// value is constant only if the two given lattice values are constant and
/// hold the same value.
TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override {
if (X == getUntrackedVal() || Y == getUntrackedVal())
return getUntrackedVal();
if (X == getOverdefinedVal() || Y == getOverdefinedVal())
return getOverdefinedVal();
if (X == getUndefVal() && Y == getUndefVal())
return getUndefVal();
if (X == getUndefVal())
return Y;
if (Y == getUndefVal())
return X;
if (X == Y)
return X;
return getOverdefinedVal();
}
/// Compute the lattice values that change as a result of executing the given
/// instruction. We only handle the few instructions needed for the tests.
void ComputeInstructionState(
Instruction &I, DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
SparseSolver<TestLatticeKey, TestLatticeVal> &SS) override {
switch (I.getOpcode()) {
case Instruction::Call:
return visitCallSite(cast<CallInst>(&I), ChangedValues, SS);
case Instruction::Ret:
return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS);
case Instruction::Store:
return visitStore(*cast<StoreInst>(&I), ChangedValues, SS);
default:
return visitInst(I, ChangedValues, SS);
}
}
private:
/// Handle call sites. The state of a called function's argument is the merge
/// of the current formal argument state with the call site's corresponding
/// actual argument state. The call site state is the merge of the call site
/// state with the returned value state of the called function.
void visitCallSite(CallSite CS,
DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
Function *F = CS.getCalledFunction();
Instruction *I = CS.getInstruction();
auto RegI = TestLatticeKey(I, IPOGrouping::Register);
if (!F) {
ChangedValues[RegI] = getOverdefinedVal();
return;
}
SS.MarkBlockExecutable(&F->front());
for (Argument &A : F->args()) {
auto RegFormal = TestLatticeKey(&A, IPOGrouping::Register);
auto RegActual =
TestLatticeKey(CS.getArgument(A.getArgNo()), IPOGrouping::Register);
ChangedValues[RegFormal] =
MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual));
}
auto RetF = TestLatticeKey(F, IPOGrouping::Return);
ChangedValues[RegI] =
MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
}
/// Handle return instructions. The function's return state is the merge of
/// the returned value state and the function's current return state.
void visitReturn(ReturnInst &I,
DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
Function *F = I.getParent()->getParent();
if (F->getReturnType()->isVoidTy())
return;
auto RegR = TestLatticeKey(I.getReturnValue(), IPOGrouping::Register);
auto RetF = TestLatticeKey(F, IPOGrouping::Return);
ChangedValues[RetF] =
MergeValues(SS.getValueState(RegR), SS.getValueState(RetF));
}
/// Handle store instructions. If the pointer operand of the store is a
/// global variable, we attempt to track the value. The global variable state
/// is the merge of the stored value state with the current global variable
/// state.
void visitStore(StoreInst &I,
DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand());
if (!GV)
return;
auto RegVal = TestLatticeKey(I.getValueOperand(), IPOGrouping::Register);
auto MemPtr = TestLatticeKey(GV, IPOGrouping::Memory);
ChangedValues[MemPtr] =
MergeValues(SS.getValueState(RegVal), SS.getValueState(MemPtr));
}
/// Handle all other instructions. All other instructions are marked
/// overdefined.
void visitInst(Instruction &I,
DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
auto RegI = TestLatticeKey(&I, IPOGrouping::Register);
ChangedValues[RegI] = getOverdefinedVal();
}
};
/// This class defines the common data used for all of the tests. The tests
/// should add code to the module and then run the solver.
class SparsePropagationTest : public testing::Test {
protected:
LLVMContext Context;
Module M;
IRBuilder<> Builder;
TestLatticeFunc Lattice;
SparseSolver<TestLatticeKey, TestLatticeVal> Solver;
public:
SparsePropagationTest()
: M("", Context), Builder(Context), Solver(&Lattice) {}
};
} // namespace
/// Test that we mark discovered functions executable.
///
/// define internal void @f() {
/// call void @g()
/// ret void
/// }
///
/// define internal void @g() {
/// call void @f()
/// ret void
/// }
///
/// For this test, we initially mark "f" executable, and the solver discovers
/// "g" because of the call in "f". The mutually recursive call in "g" also
/// tests that we don't add a block to the basic block work list if it is
/// already executable. Doing so would put the solver into an infinite loop.
TEST_F(SparsePropagationTest, MarkBlockExecutable) {
Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "f", &M);
Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "g", &M);
BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
Builder.SetInsertPoint(FEntry);
Builder.CreateCall(G);
Builder.CreateRetVoid();
Builder.SetInsertPoint(GEntry);
Builder.CreateCall(F);
Builder.CreateRetVoid();
Solver.MarkBlockExecutable(FEntry);
Solver.Solve();
EXPECT_TRUE(Solver.isBlockExecutable(GEntry));
}
/// Test that we propagate information through global variables.
///
/// @gv = internal global i64
///
/// define internal void @f() {
/// store i64 1, i64* @gv
/// ret void
/// }
///
/// define internal void @g() {
/// store i64 1, i64* @gv
/// ret void
/// }
///
/// For this test, we initially mark both "f" and "g" executable, and the
/// solver computes the lattice state of the global variable as constant.
TEST_F(SparsePropagationTest, GlobalVariableConstant) {
Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "f", &M);
Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "g", &M);
GlobalVariable *GV =
new GlobalVariable(M, Builder.getInt64Ty(), false,
GlobalValue::InternalLinkage, nullptr, "gv");
BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
Builder.SetInsertPoint(FEntry);
Builder.CreateStore(Builder.getInt64(1), GV);
Builder.CreateRetVoid();
Builder.SetInsertPoint(GEntry);
Builder.CreateStore(Builder.getInt64(1), GV);
Builder.CreateRetVoid();
Solver.MarkBlockExecutable(FEntry);
Solver.MarkBlockExecutable(GEntry);
Solver.Solve();
auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
EXPECT_TRUE(Solver.getExistingValueState(MemGV).isConstant());
}
/// Test that we propagate information through global variables.
///
/// @gv = internal global i64
///
/// define internal void @f() {
/// store i64 0, i64* @gv
/// ret void
/// }
///
/// define internal void @g() {
/// store i64 1, i64* @gv
/// ret void
/// }
///
/// For this test, we initially mark both "f" and "g" executable, and the
/// solver computes the lattice state of the global variable as overdefined.
TEST_F(SparsePropagationTest, GlobalVariableOverDefined) {
Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "f", &M);
Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "g", &M);
GlobalVariable *GV =
new GlobalVariable(M, Builder.getInt64Ty(), false,
GlobalValue::InternalLinkage, nullptr, "gv");
BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
Builder.SetInsertPoint(FEntry);
Builder.CreateStore(Builder.getInt64(0), GV);
Builder.CreateRetVoid();
Builder.SetInsertPoint(GEntry);
Builder.CreateStore(Builder.getInt64(1), GV);
Builder.CreateRetVoid();
Solver.MarkBlockExecutable(FEntry);
Solver.MarkBlockExecutable(GEntry);
Solver.Solve();
auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
EXPECT_TRUE(Solver.getExistingValueState(MemGV).isOverdefined());
}
/// Test that we propagate information through function returns.
///
/// define internal i64 @f(i1* %cond) {
/// if:
/// %0 = load i1, i1* %cond
/// br i1 %0, label %then, label %else
///
/// then:
/// ret i64 1
///
/// else:
/// ret i64 1
/// }
///
/// For this test, we initially mark "f" executable, and the solver computes
/// the return value of the function as constant.
TEST_F(SparsePropagationTest, FunctionDefined) {
Function *F =
Function::Create(FunctionType::get(Builder.getInt64Ty(),
{Type::getInt1PtrTy(Context)}, false),
GlobalValue::InternalLinkage, "f", &M);
BasicBlock *If = BasicBlock::Create(Context, "if", F);
BasicBlock *Then = BasicBlock::Create(Context, "then", F);
BasicBlock *Else = BasicBlock::Create(Context, "else", F);
F->arg_begin()->setName("cond");
Builder.SetInsertPoint(If);
LoadInst *Cond = Builder.CreateLoad(F->arg_begin());
Builder.CreateCondBr(Cond, Then, Else);
Builder.SetInsertPoint(Then);
Builder.CreateRet(Builder.getInt64(1));
Builder.SetInsertPoint(Else);
Builder.CreateRet(Builder.getInt64(1));
Solver.MarkBlockExecutable(If);
Solver.Solve();
auto RetF = TestLatticeKey(F, IPOGrouping::Return);
EXPECT_TRUE(Solver.getExistingValueState(RetF).isConstant());
}
/// Test that we propagate information through function returns.
///
/// define internal i64 @f(i1* %cond) {
/// if:
/// %0 = load i1, i1* %cond
/// br i1 %0, label %then, label %else
///
/// then:
/// ret i64 0
///
/// else:
/// ret i64 1
/// }
///
/// For this test, we initially mark "f" executable, and the solver computes
/// the return value of the function as overdefined.
TEST_F(SparsePropagationTest, FunctionOverDefined) {
Function *F =
Function::Create(FunctionType::get(Builder.getInt64Ty(),
{Type::getInt1PtrTy(Context)}, false),
GlobalValue::InternalLinkage, "f", &M);
BasicBlock *If = BasicBlock::Create(Context, "if", F);
BasicBlock *Then = BasicBlock::Create(Context, "then", F);
BasicBlock *Else = BasicBlock::Create(Context, "else", F);
F->arg_begin()->setName("cond");
Builder.SetInsertPoint(If);
LoadInst *Cond = Builder.CreateLoad(F->arg_begin());
Builder.CreateCondBr(Cond, Then, Else);
Builder.SetInsertPoint(Then);
Builder.CreateRet(Builder.getInt64(0));
Builder.SetInsertPoint(Else);
Builder.CreateRet(Builder.getInt64(1));
Solver.MarkBlockExecutable(If);
Solver.Solve();
auto RetF = TestLatticeKey(F, IPOGrouping::Return);
EXPECT_TRUE(Solver.getExistingValueState(RetF).isOverdefined());
}
/// Test that we propagate information through arguments.
///
/// define internal void @f() {
/// call void @g(i64 0, i64 1)
/// call void @g(i64 1, i64 1)
/// ret void
/// }
///
/// define internal void @g(i64 %a, i64 %b) {
/// ret void
/// }
///
/// For this test, we initially mark "f" executable, and the solver discovers
/// "g" because of the calls in "f". The solver computes the state of argument
/// "a" as overdefined and the state of "b" as constant.
///
/// In addition, this test demonstrates that ComputeInstructionState can alter
/// the state of multiple lattice values, in addition to the one associated
/// with the instruction definition. Each call instruction in this test updates
/// the state of arguments "a" and "b".
TEST_F(SparsePropagationTest, ComputeInstructionState) {
Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "f", &M);
Function *G = Function::Create(
FunctionType::get(Builder.getVoidTy(),
{Builder.getInt64Ty(), Builder.getInt64Ty()}, false),
GlobalValue::InternalLinkage, "g", &M);
Argument *A = G->arg_begin();
Argument *B = std::next(G->arg_begin());
A->setName("a");
B->setName("b");
BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
Builder.SetInsertPoint(FEntry);
Builder.CreateCall(G, {Builder.getInt64(0), Builder.getInt64(1)});
Builder.CreateCall(G, {Builder.getInt64(1), Builder.getInt64(1)});
Builder.CreateRetVoid();
Builder.SetInsertPoint(GEntry);
Builder.CreateRetVoid();
Solver.MarkBlockExecutable(FEntry);
Solver.Solve();
auto RegA = TestLatticeKey(A, IPOGrouping::Register);
auto RegB = TestLatticeKey(B, IPOGrouping::Register);
EXPECT_TRUE(Solver.getExistingValueState(RegA).isOverdefined());
EXPECT_TRUE(Solver.getExistingValueState(RegB).isConstant());
}
/// Test that we can handle exceptional terminator instructions.
///
/// declare internal void @p()
///
/// declare internal void @g()
///
/// define internal void @f() personality i8* bitcast (void ()* @p to i8*) {
/// entry:
/// invoke void @g()
/// to label %exit unwind label %catch.pad
///
/// catch.pad:
/// %0 = catchswitch within none [label %catch.body] unwind to caller
///
/// catch.body:
/// %1 = catchpad within %0 []
/// catchret from %1 to label %exit
///
/// exit:
/// ret void
/// }
///
/// For this test, we initially mark the entry block executable. The solver
/// then discovers the rest of the blocks in the function are executable.
TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) {
Function *P = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "p", &M);
Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "g", &M);
Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "f", &M);
Constant *C =
ConstantExpr::getCast(Instruction::BitCast, P, Builder.getInt8PtrTy());
F->setPersonalityFn(C);
BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
BasicBlock *Pad = BasicBlock::Create(Context, "catch.pad", F);
BasicBlock *Body = BasicBlock::Create(Context, "catch.body", F);
BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
Builder.SetInsertPoint(Entry);
Builder.CreateInvoke(G, Exit, Pad);
Builder.SetInsertPoint(Pad);
CatchSwitchInst *CatchSwitch =
Builder.CreateCatchSwitch(ConstantTokenNone::get(Context), nullptr, 1);
CatchSwitch->addHandler(Body);
Builder.SetInsertPoint(Body);
CatchPadInst *CatchPad = Builder.CreateCatchPad(CatchSwitch, {});
Builder.CreateCatchRet(CatchPad, Exit);
Builder.SetInsertPoint(Exit);
Builder.CreateRetVoid();
Solver.MarkBlockExecutable(Entry);
Solver.Solve();
EXPECT_TRUE(Solver.isBlockExecutable(Pad));
EXPECT_TRUE(Solver.isBlockExecutable(Body));
EXPECT_TRUE(Solver.isBlockExecutable(Exit));
}