Remove remaining usages of OperationInst in lib/Transforms.

PiperOrigin-RevId: 232323671
River Riddle 2019-02-04 10:38:47 -08:00 committed by jpienaar
parent 44e040dd63
commit b499277fb6
20 changed files with 250 additions and 316 deletions
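
The change is mechanical across all 20 files: APIs, walkers, and containers that previously trafficked in OperationInst * now take Instruction * directly, since Instruction itself exposes isa/dyn_cast/cast to ops. A minimal before/after sketch of the recurring pattern (the lambda body mirrors the forOp-collecting walks in the hunks below, but is illustrative rather than verbatim from any one file):

    // Before: generic walks handed out OperationInst, often after a cast.
    f->walk([&](OperationInst *opInst) {
      if (auto forOp = opInst->dyn_cast<AffineForOp>())
        forOps.push_back(forOp);
    });

    // After: the callback takes Instruction * directly, so the intermediate
    // type (and the cast<OperationInst> at call sites) disappears.
    f->walk([&](Instruction *inst) {
      if (auto forOp = inst->dyn_cast<AffineForOp>())
        forOps.push_back(forOp);
    });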

View File

@ -39,10 +39,10 @@ using namespace mlir;
namespace {
// TODO(riverriddle) Handle commutative operations.
struct SimpleOperationInfo : public llvm::DenseMapInfo<OperationInst *> {
static unsigned getHashValue(const OperationInst *op) {
struct SimpleOperationInfo : public llvm::DenseMapInfo<Instruction *> {
static unsigned getHashValue(const Instruction *op) {
// Hash the operations based upon their:
// - OperationInst Name
// - Instruction Name
// - Attributes
// - Result Types
// - Operands
@ -51,7 +51,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<OperationInst *> {
hash_combine_range(op->result_type_begin(), op->result_type_end()),
hash_combine_range(op->operand_begin(), op->operand_end()));
}
static bool isEqual(const OperationInst *lhs, const OperationInst *rhs) {
static bool isEqual(const Instruction *lhs, const Instruction *rhs) {
if (lhs == rhs)
return true;
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
@ -89,8 +89,8 @@ struct CSE : public FunctionPass {
/// Shared implementation of operation elimination and scoped map definitions.
using AllocatorTy = llvm::RecyclingAllocator<
llvm::BumpPtrAllocator,
llvm::ScopedHashTableVal<OperationInst *, OperationInst *>>;
using ScopedMapTy = llvm::ScopedHashTable<OperationInst *, OperationInst *,
llvm::ScopedHashTableVal<Instruction *, Instruction *>>;
using ScopedMapTy = llvm::ScopedHashTable<Instruction *, Instruction *,
SimpleOperationInfo, AllocatorTy>;
/// Represents a single entry in the depth first traversal of a CFG.
@ -111,7 +111,7 @@ struct CSE : public FunctionPass {
/// Attempt to eliminate a redundant operation. Returns true if the operation
/// was marked for removal, false otherwise.
bool simplifyOperation(OperationInst *op);
bool simplifyOperation(Instruction *op);
void simplifyBlock(Block *bb);
@ -122,14 +122,14 @@ private:
ScopedMapTy knownValues;
/// Operations marked as dead and to be erased.
std::vector<OperationInst *> opsToErase;
std::vector<Instruction *> opsToErase;
};
} // end anonymous namespace
char CSE::passID = 0;
/// Attempt to eliminate a redundant operation.
bool CSE::simplifyOperation(OperationInst *op) {
bool CSE::simplifyOperation(Instruction *op) {
// TODO(riverriddle) We currently only eliminate non side-effecting
// operations.
if (!op->hasNoSideEffect())
@ -166,23 +166,16 @@ bool CSE::simplifyOperation(OperationInst *op) {
void CSE::simplifyBlock(Block *bb) {
for (auto &i : *bb) {
switch (i.getKind()) {
case Instruction::Kind::OperationInst: {
auto *opInst = cast<OperationInst>(&i);
// If the operation is simplified, we don't process any held block lists.
if (simplifyOperation(&i))
continue;
// If the operation is simplified, we don't process any held block lists.
if (simplifyOperation(opInst))
continue;
// Simplify any held blocks.
for (auto &blockList : opInst->getBlockLists()) {
for (auto &b : blockList) {
ScopedMapTy::ScopeTy scope(knownValues);
simplifyBlock(&b);
}
// Simplify any held blocks.
for (auto &blockList : i.getBlockLists()) {
for (auto &b : blockList) {
ScopedMapTy::ScopeTy scope(knownValues);
simplifyBlock(&b);
}
break;
}
}
}
}
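
An aside on the container the pass above keys on: knownValues is an llvm::ScopedHashTable, and simplifyBlock opens a ScopeTy before recursing into each held block so that values recorded inside a nested region stop being visible once the walk returns. A minimal, self-contained sketch of that scoping behavior, using plain int keys purely for illustration:

    #include "llvm/ADT/ScopedHashTable.h"

    void scopedHashTableSketch() {
      llvm::ScopedHashTable<int, int> table;
      // A scope must be active before inserting; the CSE pass keeps one per block.
      llvm::ScopedHashTable<int, int>::ScopeTy outer(table);
      table.insert(1, 10);
      {
        llvm::ScopedHashTable<int, int>::ScopeTy inner(table);
        table.insert(2, 20);
        // Both keys are visible while `inner` is alive:
        //   table.count(1) == 1 && table.count(2) == 1
      }
      // `inner` has been destroyed, so key 2 is gone and key 1 remains:
      //   table.count(1) == 1 && table.count(2) == 0
    }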

View File

@ -48,7 +48,7 @@ namespace {
struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
PassResult runOnFunction(Function *f) override;
void visitInstruction(OperationInst *opInst);
void visitInstruction(Instruction *opInst);
SmallVector<OpPointer<AffineApplyOp>, 8> affineApplyOps;
@ -64,14 +64,12 @@ FunctionPass *mlir::createComposeAffineMapsPass() {
}
static bool affineApplyOp(const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
return opInst.isa<AffineApplyOp>();
return inst.isa<AffineApplyOp>();
}
void ComposeAffineMaps::visitInstruction(OperationInst *opInst) {
if (auto afOp = opInst->dyn_cast<AffineApplyOp>()) {
void ComposeAffineMaps::visitInstruction(Instruction *opInst) {
if (auto afOp = opInst->dyn_cast<AffineApplyOp>())
affineApplyOps.push_back(afOp);
}
}
PassResult ComposeAffineMaps::runOnFunction(Function *f) {

View File

@ -33,11 +33,11 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
// All constants in the function post folding.
SmallVector<Value *, 8> existingConstants;
// Operations that were folded and that need to be erased.
std::vector<OperationInst *> opInstsToErase;
std::vector<Instruction *> opInstsToErase;
bool foldOperation(OperationInst *op,
bool foldOperation(Instruction *op,
SmallVectorImpl<Value *> &existingConstants);
void visitInstruction(OperationInst *op);
void visitInstruction(Instruction *op);
PassResult runOnFunction(Function *f) override;
static char passID;
@ -49,7 +49,7 @@ char ConstantFold::passID = 0;
/// Attempt to fold the specified operation, updating the IR to match. If
/// constants are found, we keep track of them in the existingConstants list.
///
void ConstantFold::visitInstruction(OperationInst *op) {
void ConstantFold::visitInstruction(Instruction *op) {
// If this operation is an AffineForOp, then fold the bounds.
if (auto forOp = op->dyn_cast<AffineForOp>()) {
constantFoldBounds(forOp);

View File

@ -50,7 +50,7 @@ private:
// Utility that looks up a list of values in the value remapping table. Returns
// an empty vector if one of the values is not mapped yet.
SmallVector<Value *, 4>
lookupValues(const llvm::iterator_range<OperationInst::const_operand_iterator>
lookupValues(const llvm::iterator_range<Instruction::const_operand_iterator>
&operands);
// Converts the given function to the dialect using hooks defined in
@ -61,13 +61,13 @@ private:
// from `valueRemapping` and the converted blocks from `blockRemapping`, and
// passes them to `converter->rewriteTerminator` function defined in the
// pattern, together with `builder`.
bool convertOpWithSuccessors(DialectOpConversion *converter,
OperationInst *op, FuncBuilder &builder);
bool convertOpWithSuccessors(DialectOpConversion *converter, Instruction *op,
FuncBuilder &builder);
// Converts an operation without successors. Extracts the converted operands
// from `valueRemapping` and passes them to the `converter->rewrite` function
// defined in the pattern, together with `builder`.
bool convertOp(DialectOpConversion *converter, OperationInst *op,
bool convertOp(DialectOpConversion *converter, Instruction *op,
FuncBuilder &builder);
// Converts a block by traversing its instructions sequentially, looking for
@ -104,8 +104,7 @@ private:
} // end namespace mlir
SmallVector<Value *, 4> impl::FunctionConversion::lookupValues(
const llvm::iterator_range<OperationInst::const_operand_iterator>
&operands) {
const llvm::iterator_range<Instruction::const_operand_iterator> &operands) {
SmallVector<Value *, 4> remapped;
remapped.reserve(llvm::size(operands));
for (const Value *operand : operands) {
@ -118,7 +117,7 @@ SmallVector<Value *, 4> impl::FunctionConversion::lookupValues(
}
bool impl::FunctionConversion::convertOpWithSuccessors(
DialectOpConversion *converter, OperationInst *op, FuncBuilder &builder) {
DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) {
SmallVector<Block *, 2> destinations;
destinations.reserve(op->getNumSuccessors());
SmallVector<Value *, 4> operands = lookupValues(op->getOperands());
@ -149,7 +148,7 @@ bool impl::FunctionConversion::convertOpWithSuccessors(
}
bool impl::FunctionConversion::convertOp(DialectOpConversion *converter,
OperationInst *op,
Instruction *op,
FuncBuilder &builder) {
auto operands = lookupValues(op->getOperands());
assert((!operands.empty() || op->getNumOperands() == 0) &&
@ -174,24 +173,22 @@ bool impl::FunctionConversion::convertBlock(
// Iterate over ops and convert them.
for (Instruction &inst : *block) {
auto op = dyn_cast<OperationInst>(&inst);
if (!op) {
inst.emitError("unsupported instruction (For/If)");
if (inst.getNumBlockLists() != 0) {
inst.emitError("unsupported region instruction");
return true;
}
// Find the first matching conversion and apply it.
bool converted = false;
for (auto *conversion : conversions) {
if (!conversion->match(op))
if (!conversion->match(&inst))
continue;
if (op->isTerminator() && op->getNumSuccessors() > 0) {
if (convertOpWithSuccessors(conversion, op, builder))
return true;
} else {
if (convertOp(conversion, op, builder))
if (inst.isTerminator() && inst.getNumSuccessors() > 0) {
if (convertOpWithSuccessors(conversion, &inst, builder))
return true;
} else if (convertOp(conversion, &inst, builder)) {
return true;
}
converted = true;
break;

View File

@ -157,8 +157,7 @@ static void getMultiLevelStrides(const MemRefRegion &region,
/// dynamic shaped memref's for now. `numParamLoopIVs` is the number of
/// enclosing loop IVs of opInst (starting from the outermost) that the region
/// is parametric on.
static bool getFullMemRefAsRegion(OperationInst *opInst,
unsigned numParamLoopIVs,
static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs,
MemRefRegion *region) {
unsigned rank;
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
@ -563,7 +562,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
fastBufferMap.clear();
// Walk this range of instructions to gather all memory regions.
block->walk(begin, end, [&](OperationInst *opInst) {
block->walk(begin, end, [&](Instruction *opInst) {
// Gather regions to allocate to buffers in faster memory space.
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace)

View File

@ -114,11 +114,11 @@ namespace {
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
SmallVector<OpPointer<AffineForOp>, 4> forOps;
SmallVector<OperationInst *, 4> loadOpInsts;
SmallVector<OperationInst *, 4> storeOpInsts;
SmallVector<Instruction *, 4> loadOpInsts;
SmallVector<Instruction *, 4> storeOpInsts;
bool hasNonForRegion = false;
void visitInstruction(OperationInst *opInst) {
void visitInstruction(Instruction *opInst) {
if (opInst->isa<AffineForOp>())
forOps.push_back(opInst->cast<AffineForOp>());
else if (opInst->getNumBlockLists() != 0)
@ -131,7 +131,7 @@ public:
};
// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(const OperationInst &op) {
static bool isMemRefDereferencingOp(const Instruction &op) {
if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
op.isa<DmaWaitOp>())
return true;
@ -153,9 +153,9 @@ public:
// The top-level statement which is (or contains) loads/stores.
Instruction *inst;
// List of load operations.
SmallVector<OperationInst *, 4> loads;
SmallVector<Instruction *, 4> loads;
// List of store op insts.
SmallVector<OperationInst *, 4> stores;
SmallVector<Instruction *, 4> stores;
Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}
// Returns the load op count for 'memref'.
@ -258,16 +258,13 @@ public:
for (auto *storeOpInst : node->stores) {
auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
auto *inst = memref->getDefiningInst();
auto *opInst = dyn_cast_or_null<OperationInst>(inst);
// Return false if 'memref' is a function argument.
if (opInst == nullptr)
// Return false if 'memref' is a block argument.
if (!inst)
return true;
// Return false if any use of 'memref' escapes the function.
for (auto &use : memref->getUses()) {
auto *user = dyn_cast<OperationInst>(use.getOwner());
if (!user || !isMemRefDereferencingOp(*user))
for (auto &use : memref->getUses())
if (!isMemRefDereferencingOp(*use.getOwner()))
return true;
}
}
return false;
}
@ -461,8 +458,8 @@ public:
}
// Adds ops in 'loads' and 'stores' to node at 'id'.
void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
const SmallVectorImpl<OperationInst *> &stores) {
void addToNode(unsigned id, const SmallVectorImpl<Instruction *> &loads,
const SmallVectorImpl<Instruction *> &stores) {
Node *node = getNode(id);
for (auto *loadOpInst : loads)
node->loads.push_back(loadOpInst);
@ -509,7 +506,7 @@ bool MemRefDependenceGraph::init(Function *f) {
DenseMap<Instruction *, unsigned> forToNodeMap;
for (auto &inst : f->front()) {
if (auto forOp = cast<OperationInst>(&inst)->dyn_cast<AffineForOp>()) {
if (auto forOp = inst.dyn_cast<AffineForOp>()) {
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
@ -530,30 +527,28 @@ bool MemRefDependenceGraph::init(Function *f) {
}
forToNodeMap[&inst] = node.id;
nodes.insert({node.id, node});
} else if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
// Create graph node for top-level load op.
Node node(nextNodeId++, &inst);
node.loads.push_back(opInst);
auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
// Create graph node for top-level store op.
Node node(nextNodeId++, &inst);
node.stores.push_back(opInst);
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (opInst->getNumBlockLists() != 0) {
// Return false if another region is found (not currently supported).
return false;
} else if (opInst->getNumResults() > 0 && !opInst->use_empty()) {
// Create graph node for top-level producer of SSA values, which
// could be used by loop nest nodes.
Node node(nextNodeId++, &inst);
nodes.insert({node.id, node});
}
} else if (auto loadOp = inst.dyn_cast<LoadOp>()) {
// Create graph node for top-level load op.
Node node(nextNodeId++, &inst);
node.loads.push_back(&inst);
auto *memref = inst.cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (auto storeOp = inst.dyn_cast<StoreOp>()) {
// Create graph node for top-level store op.
Node node(nextNodeId++, &inst);
node.stores.push_back(&inst);
auto *memref = inst.cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (inst.getNumBlockLists() != 0) {
// Return false if another region is found (not currently supported).
return false;
} else if (inst.getNumResults() > 0 && !inst.use_empty()) {
// Create graph node for top-level producer of SSA values, which
// could be used by loop nest nodes.
Node node(nextNodeId++, &inst);
nodes.insert({node.id, node});
}
}
@ -563,12 +558,11 @@ bool MemRefDependenceGraph::init(Function *f) {
const Node &node = idAndNode.second;
if (!node.loads.empty() || !node.stores.empty())
continue;
auto *opInst = cast<OperationInst>(node.inst);
auto *opInst = node.inst;
for (auto *value : opInst->getResults()) {
for (auto &use : value->getUses()) {
auto *userOpInst = cast<OperationInst>(use.getOwner());
SmallVector<OpPointer<AffineForOp>, 4> loops;
getLoopIVs(*userOpInst, &loops);
getLoopIVs(*use.getOwner(), &loops);
if (loops.empty())
continue;
assert(forToNodeMap.count(loops[0]->getInstruction()) > 0);
@ -619,7 +613,7 @@ public:
LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
void visitInstruction(OperationInst *opInst) {
void visitInstruction(Instruction *opInst) {
auto forOp = opInst->dyn_cast<AffineForOp>();
if (!forOp)
return;
@ -627,8 +621,7 @@ public:
auto *forInst = forOp->getInstruction();
auto *parentInst = forOp->getInstruction()->getParentInst();
if (parentInst != nullptr) {
assert(cast<OperationInst>(parentInst)->isa<AffineForOp>() &&
"Expected parent AffineForOp");
assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
// Add mapping to 'forOp' from its parent AffineForOp.
stats->loopMap[parentInst].push_back(forOp);
}
@ -637,8 +630,7 @@ public:
unsigned count = 0;
stats->opCountMap[forInst] = 0;
for (auto &inst : *forOp->getBody()) {
if (!(cast<OperationInst>(inst).isa<AffineForOp>() ||
cast<OperationInst>(inst).isa<AffineIfOp>()))
if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
++count;
}
stats->opCountMap[forInst] = count;
@ -723,7 +715,7 @@ static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
OperationInst *srcOpInst, ComputationSliceState *sliceState,
Instruction *srcOpInst, ComputationSliceState *sliceState,
llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) {
SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
@ -755,10 +747,10 @@ static bool buildSliceTripCountMap(
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
SmallVectorImpl<OperationInst *> *srcLoads,
SmallVectorImpl<OperationInst *> *dstLoads) {
SmallVectorImpl<Instruction *> *srcLoads,
SmallVectorImpl<Instruction *> *dstLoads) {
dstLoads->clear();
SmallVector<OperationInst *, 4> srcLoadsToKeep;
SmallVector<Instruction *, 4> srcLoadsToKeep;
for (auto *load : *srcLoads) {
if (load->cast<LoadOp>()->getMemRef() == memref)
dstLoads->push_back(load);
@ -769,7 +761,7 @@ moveLoadsAccessingMemrefTo(Value *memref,
}
// Returns the innermost common loop depth for the set of operations in 'ops'.
static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
static unsigned getInnermostCommonLoopDepth(ArrayRef<Instruction *> ops) {
unsigned numOps = ops.size();
assert(numOps > 0);
@ -797,10 +789,10 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
// and 'storeOpInsts' are satisfied.
static unsigned getMaxLoopDepth(ArrayRef<OperationInst *> loadOpInsts,
ArrayRef<OperationInst *> storeOpInsts) {
static unsigned getMaxLoopDepth(ArrayRef<Instruction *> loadOpInsts,
ArrayRef<Instruction *> storeOpInsts) {
// Merge loads and stores into the same array.
SmallVector<OperationInst *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
SmallVector<Instruction *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
ops.append(storeOpInsts.begin(), storeOpInsts.end());
// Compute the innermost common loop depth for loads and stores.
@ -913,7 +905,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
OperationInst *srcStoreOpInst,
Instruction *srcStoreOpInst,
unsigned dstLoopDepth,
Optional<unsigned> fastMemorySpace,
unsigned localBufSizeThreshold) {
@ -1061,9 +1053,9 @@ static uint64_t getSliceIterationCount(
// *) Compares the total cost of the unfused loop nests to the min cost fused
// loop nest computed in the previous step, and returns true if the latter
// is lower.
static bool isFusionProfitable(OperationInst *srcOpInst,
ArrayRef<OperationInst *> dstLoadOpInsts,
ArrayRef<OperationInst *> dstStoreOpInsts,
static bool isFusionProfitable(Instruction *srcOpInst,
ArrayRef<Instruction *> dstLoadOpInsts,
ArrayRef<Instruction *> dstStoreOpInsts,
ComputationSliceState *sliceState,
unsigned *dstLoopDepth) {
LLVM_DEBUG({
@ -1174,7 +1166,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
for (auto *loadOp : dstLoadOpInsts) {
auto *parentInst = loadOp->getParentInst();
if (parentInst && cast<OperationInst>(parentInst)->isa<AffineForOp>())
if (parentInst && parentInst->isa<AffineForOp>())
computeCostMap[parentInst] = -1;
}
}
@ -1393,11 +1385,11 @@ public:
// Get 'dstNode' into which to attempt fusion.
auto *dstNode = mdg->getNode(dstId);
// Skip if 'dstNode' is not a loop nest.
if (!cast<OperationInst>(dstNode->inst)->isa<AffineForOp>())
if (!dstNode->inst->isa<AffineForOp>())
continue;
SmallVector<OperationInst *, 4> loads = dstNode->loads;
SmallVector<OperationInst *, 4> dstLoadOpInsts;
SmallVector<Instruction *, 4> loads = dstNode->loads;
SmallVector<Instruction *, 4> dstLoadOpInsts;
DenseSet<Value *> visitedMemrefs;
while (!loads.empty()) {
// Get memref of load on top of the stack.
@ -1426,7 +1418,7 @@ public:
// Get 'srcNode' from which to attempt fusion into 'dstNode'.
auto *srcNode = mdg->getNode(srcId);
// Skip if 'srcNode' is not a loop nest.
if (!cast<OperationInst>(srcNode->inst)->isa<AffineForOp>())
if (!srcNode->inst->isa<AffineForOp>())
continue;
// Skip if 'srcNode' has more than one store to any memref.
// TODO(andydavis) Support fusing multi-output src loop nests.
@ -1454,7 +1446,7 @@ public:
// Get unique 'srcNode' store op.
auto *srcStoreOpInst = srcNode->stores.front();
// Gather 'dstNode' store ops to 'memref'.
SmallVector<OperationInst *, 2> dstStoreOpInsts;
SmallVector<Instruction *, 2> dstStoreOpInsts;
for (auto *storeOpInst : dstNode->stores)
if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
dstStoreOpInsts.push_back(storeOpInst);
@ -1472,8 +1464,7 @@ public:
srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
if (sliceLoopNest != nullptr) {
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
auto dstAffineForOp =
cast<OperationInst>(dstNode->inst)->cast<AffineForOp>();
auto dstAffineForOp = dstNode->inst->cast<AffineForOp>();
if (insertPointInst != dstAffineForOp->getInstruction()) {
dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
}
@ -1488,7 +1479,7 @@ public:
promoteIfSingleIteration(forOp);
}
// Create private memref for 'memref' in 'dstAffineForOp'.
SmallVector<OperationInst *, 4> storesForMemref;
SmallVector<Instruction *, 4> storesForMemref;
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
storesForMemref.push_back(storeOpInst);
@ -1541,9 +1532,8 @@ public:
continue;
// Use list expected to match the dep graph info.
auto *inst = memref->getDefiningInst();
auto *opInst = dyn_cast_or_null<OperationInst>(inst);
if (opInst && opInst->isa<AllocOp>())
opInst->erase();
if (inst && inst->isa<AllocOp>())
inst->erase();
}
}
};

View File

@ -237,14 +237,13 @@ getTileableBands(Function *f,
do {
band.push_back(currInst);
} while (currInst->getBody()->getInstructions().size() == 1 &&
(currInst = cast<OperationInst>(currInst->getBody()->front())
.dyn_cast<AffineForOp>()));
(currInst = currInst->getBody()->front().dyn_cast<AffineForOp>()));
bands->push_back(band);
};
for (auto &block : *f)
for (auto &inst : block)
if (auto forOp = cast<OperationInst>(inst).dyn_cast<AffineForOp>())
if (auto forOp = inst.dyn_cast<AffineForOp>())
getMaximalPerfectLoopNest(forOp);
}

View File

@ -113,7 +113,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
return hasInnerLoops;
}
bool walkPostOrder(OperationInst *opInst) {
bool walkPostOrder(Instruction *opInst) {
bool hasInnerLoops = false;
for (auto &blockList : opInst->getBlockLists())
for (auto &block : blockList)
@ -140,7 +140,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
const unsigned minTripCount;
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
void visitInstruction(OperationInst *opInst) {
void visitInstruction(Instruction *opInst) {
auto forOp = opInst->dyn_cast<AffineForOp>();
if (!forOp)
return;

View File

@ -100,8 +100,7 @@ PassResult LoopUnrollAndJam::runOnFunction(Function *f) {
// any for Inst.
auto &entryBlock = f->front();
if (!entryBlock.empty())
if (auto forOp =
cast<OperationInst>(entryBlock.front()).dyn_cast<AffineForOp>())
if (auto forOp = entryBlock.front().dyn_cast<AffineForOp>())
runOnAffineForOp(forOp);
return success();
@ -149,12 +148,12 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
void walk(InstListType::iterator Start, InstListType::iterator End) {
for (auto it = Start; it != End;) {
auto subBlockStart = it;
while (it != End && !cast<OperationInst>(it)->isa<AffineForOp>())
while (it != End && !it->isa<AffineForOp>())
++it;
if (it != subBlockStart)
subBlocks.push_back({subBlockStart, std::prev(it)});
// Process all for insts that appear next.
while (it != End && cast<OperationInst>(it)->isa<AffineForOp>())
while (it != End && it->isa<AffineForOp>())
walk(&*it++);
}
}
@ -206,8 +205,7 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
// Insert the cleanup loop right after 'forOp'.
FuncBuilder builder(forInst->getBlock(),
std::next(Block::iterator(forInst)));
auto cleanupAffineForOp =
cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>();
auto cleanupAffineForOp = builder.clone(*forInst)->cast<AffineForOp>();
cleanupAffineForOp->setLowerBoundMap(
getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder));

View File

@ -616,23 +616,21 @@ PassResult LowerAffinePass::runOnFunction(Function *function) {
// Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
// We do this as a prepass to avoid invalidating the walker with our rewrite.
function->walk([&](Instruction *inst) {
auto op = cast<OperationInst>(inst);
if (op->isa<AffineApplyOp>() || op->isa<AffineForOp>() ||
op->isa<AffineIfOp>())
if (inst->isa<AffineApplyOp>() || inst->isa<AffineForOp>() ||
inst->isa<AffineIfOp>())
instsToRewrite.push_back(inst);
});
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
// so we know that we will rewrite them in the same order.
for (auto *inst : instsToRewrite) {
auto op = cast<OperationInst>(inst);
if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
if (auto ifOp = inst->dyn_cast<AffineIfOp>()) {
if (lowerAffineIf(ifOp))
return failure();
} else if (auto forOp = op->dyn_cast<AffineForOp>()) {
} else if (auto forOp = inst->dyn_cast<AffineForOp>()) {
if (lowerAffineFor(forOp))
return failure();
} else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
} else if (lowerAffineApply(inst->cast<AffineApplyOp>())) {
return failure();
}
}

View File

@ -401,13 +401,12 @@ public:
explicit VectorTransferExpander(MLIRContext *context)
: MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {}
PatternMatchResult match(OperationInst *op) const override {
PatternMatchResult match(Instruction *op) const override {
if (m_Op<VectorTransferOpTy>().match(op))
return matchSuccess();
return matchFailure();
}
void rewriteOpInst(OperationInst *op,
MLFuncGlobalLoweringState *funcWiseState,
void rewriteOpInst(Instruction *op, MLFuncGlobalLoweringState *funcWiseState,
std::unique_ptr<PatternState> opState,
MLFuncLoweringRewriter *rewriter) const override {
VectorTransferRewriter<VectorTransferOpTy>(

View File

@ -246,8 +246,8 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
return res;
}
static OperationInst *
instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
static Instruction *
instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap);
/// Not all Values belong to a program slice scoped within the immediately
@ -391,7 +391,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
/// - constant splat is replaced by constant splat of `hwVectorType`.
/// TODO(ntv): add more substitutions on a per-need basis.
static SmallVector<NamedAttribute, 1>
materializeAttributes(OperationInst *opInst, VectorType hwVectorType) {
materializeAttributes(Instruction *opInst, VectorType hwVectorType) {
SmallVector<NamedAttribute, 1> res;
for (auto a : opInst->getAttrs()) {
if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
@ -411,8 +411,8 @@ materializeAttributes(OperationInst *opInst, VectorType hwVectorType) {
/// substitutionsMap.
///
/// If the underlying substitution fails, this fails too and returns nullptr.
static OperationInst *
instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
static Instruction *
instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
assert(!opInst->isa<VectorTransferReadOp>() &&
"Should call the function specialized for VectorTransferReadOp");
@ -488,7 +488,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
/// `hwVectorType` in the covering of the super-vector type. For a more
/// detailed description of the problem, see the description of
/// reindexAffineIndices.
static OperationInst *
static Instruction *
instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType,
ArrayRef<unsigned> hwVectorInstance,
DenseMap<const Value *, Value *> *substitutionsMap) {
@ -512,7 +512,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType,
/// `hwVectorType` in the covering of the super-vector type. For a more
/// detailed description of the problem, see the description of
/// reindexAffineIndices.
static OperationInst *
static Instruction *
instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
DenseMap<const Value *, Value *> *substitutionsMap) {
@ -555,21 +555,20 @@ static bool instantiateMaterialization(Instruction *inst,
// Create a builder here for unroll-and-jam effects.
FuncBuilder b(inst);
auto *opInst = cast<OperationInst>(inst);
// AffineApplyOp are ignored: instantiating the proper vector op will take
// care of AffineApplyOps by composing them properly.
if (opInst->isa<AffineApplyOp>()) {
if (inst->isa<AffineApplyOp>()) {
return false;
}
if (opInst->getNumBlockLists() != 0)
if (inst->getNumBlockLists() != 0)
return inst->emitError("NYI path Op with region");
if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) {
if (auto write = inst->dyn_cast<VectorTransferWriteOp>()) {
auto *clone = instantiate(&b, write, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);
return clone == nullptr;
}
if (auto read = opInst->dyn_cast<VectorTransferReadOp>()) {
if (auto read = inst->dyn_cast<VectorTransferReadOp>()) {
auto *clone = instantiate(&b, read, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);
if (!clone) {
@ -582,19 +581,19 @@ static bool instantiateMaterialization(Instruction *inst,
// The only op with 0 results reaching this point must, by construction, be
// VectorTransferWriteOps and have been caught above. Ops with >= 2 results
// are not yet supported. So just support 1 result.
if (opInst->getNumResults() != 1) {
if (inst->getNumResults() != 1) {
return inst->emitError("NYI: ops with != 1 results");
}
if (opInst->getResult(0)->getType() != state->superVectorType) {
if (inst->getResult(0)->getType() != state->superVectorType) {
return inst->emitError("Op does not return a supervector.");
}
auto *clone =
instantiate(&b, opInst, state->hwVectorType, state->substitutionsMap);
instantiate(&b, inst, state->hwVectorType, state->substitutionsMap);
if (!clone) {
return true;
}
state->substitutionsMap->insert(
std::make_pair(opInst->getResult(0), clone->getResult(0)));
std::make_pair(inst->getResult(0), clone->getResult(0)));
return false;
}
@ -645,7 +644,7 @@ static bool emitSlice(MaterializationState *state,
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
LLVM_DEBUG(cast<OperationInst>((*slice)[0])->getFunction()->print(dbgs()));
LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs()));
// slice are topologically sorted, we can just erase them in reverse
// order. Reverse iterator does not just work simply with an operator*
@ -677,7 +676,7 @@ static bool emitSlice(MaterializationState *state,
/// scope.
/// TODO(ntv): please document return value.
static bool materialize(Function *f,
const SetVector<OperationInst *> &terminators,
const SetVector<Instruction *> &terminators,
MaterializationState *state) {
DenseSet<Instruction *> seen;
DominanceInfo domInfo(f);
@ -757,18 +756,17 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
// Capture terminators; i.e. vector_transfer_write ops involving a strict
// super-vector of subVectorType.
auto filter = [subVectorType](const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
if (!opInst.isa<VectorTransferWriteOp>()) {
if (!inst.isa<VectorTransferWriteOp>()) {
return false;
}
return matcher::operatesOnSuperVectors(opInst, subVectorType);
return matcher::operatesOnSuperVectors(inst, subVectorType);
};
auto pat = Op(filter);
SmallVector<NestedMatch, 8> matches;
pat.match(f, &matches);
SetVector<OperationInst *> terminators;
SetVector<Instruction *> terminators;
for (auto m : matches) {
terminators.insert(cast<OperationInst>(m.getMatchedInstruction()));
terminators.insert(m.getMatchedInstruction());
}
auto fail = materialize(f, terminators, &state);

View File

@ -75,12 +75,12 @@ struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> {
PassResult runOnFunction(Function *f) override;
void visitInstruction(OperationInst *opInst);
void visitInstruction(Instruction *opInst);
// A list of memref's that are potentially dead / could be eliminated.
SmallPtrSet<Value *, 4> memrefsToErase;
// Load op's whose results were replaced by those forwarded from stores.
std::vector<OperationInst *> loadOpsToErase;
std::vector<Instruction *> loadOpsToErase;
DominanceInfo *domInfo = nullptr;
PostDominanceInfo *postDomInfo = nullptr;
@ -100,22 +100,22 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() {
// This is a straightforward implementation not optimized for speed. Optimize
// this in the future if needed.
void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) {
OperationInst *lastWriteStoreOp = nullptr;
void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) {
Instruction *lastWriteStoreOp = nullptr;
auto loadOp = opInst->dyn_cast<LoadOp>();
if (!loadOp)
return;
OperationInst *loadOpInst = opInst;
Instruction *loadOpInst = opInst;
// First pass over the use list to get minimum number of surrounding
// loops common between the load op and the store op, with min taken across
// all store ops.
SmallVector<OperationInst *, 8> storeOps;
SmallVector<Instruction *, 8> storeOps;
unsigned minSurroundingLoops = getNestingDepth(*loadOpInst);
for (InstOperand &use : loadOp->getMemRef()->getUses()) {
auto storeOp = cast<OperationInst>(use.getOwner())->dyn_cast<StoreOp>();
auto storeOp = use.getOwner()->dyn_cast<StoreOp>();
if (!storeOp)
continue;
auto *storeOpInst = storeOp->getInstruction();
@ -131,11 +131,11 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) {
// and loadOp.
// The list of store op candidates for forwarding - need to satisfy the
// conditions listed at the top.
SmallVector<OperationInst *, 8> fwdingCandidates;
SmallVector<Instruction *, 8> fwdingCandidates;
// Store ops that have a dependence into the load (even if they aren't
// forwarding candidates). Each forwarding candidate will be checked for a
// post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
SmallVector<OperationInst *, 8> depSrcStores;
SmallVector<Instruction *, 8> depSrcStores;
for (auto *storeOpInst : storeOps) {
MemRefAccess srcAccess(storeOpInst);
MemRefAccess destAccess(loadOpInst);
@ -197,7 +197,7 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) {
// that postdominates all 'depSrcStores' (if such a store exists) is the
// unique store providing the value to the load, i.e., provably the last
// writer to that memref loc.
if (llvm::all_of(depSrcStores, [&](OperationInst *depStore) {
if (llvm::all_of(depSrcStores, [&](Instruction *depStore) {
return postDomInfo->postDominates(storeOpInst, depStore);
})) {
lastWriteStoreOp = storeOpInst;
@ -246,24 +246,22 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) {
// to do this as well, but we'll do it here since we collected these anyway.
for (auto *memref : memrefsToErase) {
// If the memref hasn't been alloc'ed in this function, skip.
OperationInst *defInst = memref->getDefiningInst();
Instruction *defInst = memref->getDefiningInst();
if (!defInst || !defInst->isa<AllocOp>())
// TODO(mlir-team): if the memref was returned by a 'call' instruction, we
// could still erase it if the call had no side-effects.
continue;
if (std::any_of(memref->use_begin(), memref->use_end(),
[&](InstOperand &use) {
auto *ownerInst = cast<OperationInst>(use.getOwner());
auto *ownerInst = use.getOwner();
return (!ownerInst->isa<StoreOp>() &&
!ownerInst->isa<DeallocOp>());
}))
continue;
// Erase all stores, the dealloc, and the alloc on the memref.
for (auto it = memref->use_begin(), e = memref->use_end(); it != e;) {
auto &use = *(it++);
cast<OperationInst>(use.getOwner())->erase();
}
for (auto &use : llvm::make_early_inc_range(memref->getUses()))
use.getOwner()->erase();
defInst->erase();
}
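
The llvm::make_early_inc_range idiom introduced above (and reused in a later hunk) is what makes it legal to erase a use's owner while walking the use list: the adaptor steps the underlying iterator past the current element before the loop body runs. A minimal sketch with a hypothetical helper name, assuming each user holds a single use of the value (as with the store/dealloc users above):

    #include "llvm/ADT/STLExtras.h" // llvm::make_early_inc_range

    // Hypothetical helper: erase every instruction that uses `value`.
    static void eraseAllUsers(Value *value) {
      for (InstOperand &use : llvm::make_early_inc_range(value->getUses()))
        use.getOwner()->erase(); // safe: the iterator already moved past `use`
    }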

View File

@ -61,7 +61,7 @@ FunctionPass *mlir::createPipelineDataTransferPass() {
// Returns the position of the tag memref operand given a DMA instruction.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
static unsigned getTagMemRefPos(const OperationInst &dmaInst) {
static unsigned getTagMemRefPos(const Instruction &dmaInst) {
assert(dmaInst.isa<DmaStartOp>() || dmaInst.isa<DmaWaitOp>());
if (dmaInst.isa<DmaStartOp>()) {
// Second to last operand.
@ -142,7 +142,7 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) {
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
forOps.clear();
f->walkPostOrder([&](OperationInst *opInst) {
f->walkPostOrder([&](Instruction *opInst) {
if (auto forOp = opInst->dyn_cast<AffineForOp>())
forOps.push_back(forOp);
});
@ -180,33 +180,26 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
// Identify matching DMA start/finish instructions to overlap computation with.
static void findMatchingStartFinishInsts(
OpPointer<AffineForOp> forOp,
SmallVectorImpl<std::pair<OperationInst *, OperationInst *>>
&startWaitPairs) {
SmallVectorImpl<std::pair<Instruction *, Instruction *>> &startWaitPairs) {
// Collect outgoing DMA instructions - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
for (auto &inst : *forOp->getBody()) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst)
continue;
OpPointer<DmaStartOp> dmaStartOp;
if ((dmaStartOp = opInst->dyn_cast<DmaStartOp>()) &&
if ((dmaStartOp = inst.dyn_cast<DmaStartOp>()) &&
dmaStartOp->isSrcMemorySpaceFaster())
outgoingDmaOps.push_back(dmaStartOp);
}
SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts;
SmallVector<Instruction *, 4> dmaStartInsts, dmaFinishInsts;
for (auto &inst : *forOp->getBody()) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst)
continue;
// Collect DMA finish instructions.
if (opInst->isa<DmaWaitOp>()) {
dmaFinishInsts.push_back(opInst);
if (inst.isa<DmaWaitOp>()) {
dmaFinishInsts.push_back(&inst);
continue;
}
OpPointer<DmaStartOp> dmaStartOp;
if (!(dmaStartOp = opInst->dyn_cast<DmaStartOp>()))
if (!(dmaStartOp = inst.dyn_cast<DmaStartOp>()))
continue;
// Only DMAs incoming into higher memory spaces are pipelined for now.
// TODO(bondhugula): handle outgoing DMA pipelining.
@ -236,7 +229,7 @@ static void findMatchingStartFinishInsts(
}
}
if (!escapingUses)
dmaStartInsts.push_back(opInst);
dmaStartInsts.push_back(&inst);
}
// For each start instruction, we look for a matching finish instruction.
@ -262,7 +255,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
return success();
}
SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs;
SmallVector<std::pair<Instruction *, Instruction *>, 4> startWaitPairs;
findMatchingStartFinishInsts(forOp, startWaitPairs);
if (startWaitPairs.empty()) {
@ -335,7 +328,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
} else {
// If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it.
SmallVector<OperationInst *, 4> affineApplyInsts;
SmallVector<Instruction *, 4> affineApplyInsts;
SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
getReachableAffineApplyOps(operands, affineApplyInsts);
for (const auto *inst : affineApplyInsts) {
@ -356,13 +349,13 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
for (auto &inst : *forOp->getBody()) {
assert(instShiftMap.find(&inst) != instShiftMap.end());
shifts[s++] = instShiftMap[&inst];
LLVM_DEBUG(
// Tagging instructions with shifts for debugging purposes.
if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
FuncBuilder b(opInst);
opInst->setAttr(b.getIdentifier("shift"),
b.getI64IntegerAttr(shifts[s - 1]));
});
// Tagging instructions with shifts for debugging purposes.
LLVM_DEBUG({
FuncBuilder b(&inst);
inst.setAttr(b.getIdentifier("shift"),
b.getI64IntegerAttr(shifts[s - 1]));
});
}
if (!isInstwiseShiftValid(forOp, shifts)) {

View File

@ -64,7 +64,7 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
}
PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
f->walk([&](OperationInst *opInst) {
f->walk([&](Instruction *opInst) {
for (auto attr : opInst->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
MutableAffineMap mMap(mapAttr.getValue());

View File

@ -38,13 +38,13 @@ public:
worklist.reserve(64);
// Add all operations to the worklist.
fn->walk([&](OperationInst *inst) { addToWorklist(inst); });
fn->walk([&](Instruction *inst) { addToWorklist(inst); });
}
/// Perform the rewrites.
void simplifyFunction();
void addToWorklist(OperationInst *op) {
void addToWorklist(Instruction *op) {
// Check to see if the worklist already contains this op.
if (worklistMap.count(op))
return;
@ -53,7 +53,7 @@ public:
worklist.push_back(op);
}
OperationInst *popFromWorklist() {
Instruction *popFromWorklist() {
auto *op = worklist.back();
worklist.pop_back();
@ -65,7 +65,7 @@ public:
/// If the specified operation is in the worklist, remove it. If not, this is
/// a no-op.
void removeFromWorklist(OperationInst *op) {
void removeFromWorklist(Instruction *op) {
auto it = worklistMap.find(op);
if (it != worklistMap.end()) {
assert(worklist[it->second] == op && "malformed worklist data structure");
@ -77,7 +77,7 @@ public:
protected:
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
OperationInst *createOperation(const OperationState &state) override {
Instruction *createOperation(const OperationState &state) override {
auto *result = builder.createOperation(state);
addToWorklist(result);
return result;
@ -85,20 +85,18 @@ protected:
// If an operation is about to be removed, make sure it is not in our
// worklist anymore because we'd get dangling references to it.
void notifyOperationRemoved(OperationInst *op) override {
void notifyOperationRemoved(Instruction *op) override {
removeFromWorklist(op);
}
// When the root of a pattern is about to be replaced, it can trigger
// simplifications to its users - make sure to add them to the worklist
// before the root is changed.
void notifyRootReplaced(OperationInst *op) override {
void notifyRootReplaced(Instruction *op) override {
for (auto *result : op->getResults())
// TODO: Add a result->getUsers() iterator.
for (auto &user : result->getUses()) {
if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
addToWorklist(op);
}
for (auto &user : result->getUses())
addToWorklist(user.getOwner());
// TODO: Walk the operand list dropping them as we go. If any of them
// drop to zero uses, then add them to the worklist to allow them to be
@ -116,13 +114,13 @@ private:
/// need to be revisited, plus their index in the worklist. This allows us to
/// efficiently remove operations from the worklist when they are erased from
/// the function, even if they aren't the root of a pattern.
std::vector<OperationInst *> worklist;
DenseMap<OperationInst *, unsigned> worklistMap;
std::vector<Instruction *> worklist;
DenseMap<Instruction *, unsigned> worklistMap;
/// As part of canonicalization, we move constants to the top of the entry
/// block of the current function and de-duplicate them. This keeps track of
/// constants we have done this for.
DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants;
DenseMap<std::pair<Attribute, Type>, Instruction *> uniquedConstants;
};
}; // end anonymous namespace
@ -229,10 +227,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
// revisit them.
//
// TODO: Add a result->getUsers() iterator.
for (auto &operand : op->getResult(i)->getUses()) {
if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
addToWorklist(op);
}
for (auto &operand : op->getResult(i)->getUses())
addToWorklist(operand.getOwner());
res->replaceAllUsesWith(cstValue);
}
@ -267,10 +263,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
if (res->use_empty()) // ignore dead uses.
continue;
for (auto &operand : op->getResult(i)->getUses()) {
if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
addToWorklist(op);
}
for (auto &operand : op->getResult(i)->getUses())
addToWorklist(operand.getOwner());
res->replaceAllUsesWith(resultValues[i]);
}
}

View File

@ -101,7 +101,7 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
// Replaces all IV uses to its single iteration value.
auto *iv = forOp->getInductionVar();
OperationInst *forInst = forOp->getInstruction();
Instruction *forInst = forOp->getInstruction();
if (!iv->use_empty()) {
if (forOp->hasConstantLowerBound()) {
auto *mlFunc = forInst->getFunction();
@ -135,7 +135,7 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
/// their body into the containing Block.
void mlir::promoteSingleIterationLoops(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
f->walkPostOrder([](OperationInst *inst) {
f->walkPostOrder([](Instruction *inst) {
if (auto forOp = inst->dyn_cast<AffineForOp>())
promoteIfSingleIteration(forOp);
});
@ -394,11 +394,10 @@ bool mlir::loopUnrollByFactor(OpPointer<AffineForOp> forOp,
return false;
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
OperationInst *forInst = forOp->getInstruction();
Instruction *forInst = forOp->getInstruction();
if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst));
auto cleanupForInst =
cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>();
auto cleanupForInst = builder.clone(*forInst)->cast<AffineForOp>();
auto clLbMap = getCleanupLoopLowerBound(forOp, unrollFactor, &builder);
assert(clLbMap &&
"cleanup loop lower bound map for single result bound maps can "

View File

@ -37,7 +37,7 @@ using namespace mlir;
/// Return true if this operation dereferences one or more memref's.
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO(b/117228571)
static bool isMemRefDereferencingOp(const OperationInst &op) {
static bool isMemRefDereferencingOp(const Instruction &op) {
if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
op.isa<DmaWaitOp>())
return true;
@ -76,12 +76,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
std::make_unique<PostDominanceInfo>(postDomInstFilter->getFunction());
// The ops where memref replacement succeeds are replaced with new ones.
SmallVector<OperationInst *, 8> opsToErase;
SmallVector<Instruction *, 8> opsToErase;
// Walk all uses of old memref. Operation using the memref gets replaced.
for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
InstOperand &use = *(it++);
auto *opInst = cast<OperationInst>(use.getOwner());
for (auto &use : llvm::make_early_inc_range(oldMemRef->getUses())) {
auto *opInst = use.getOwner();
// Skip this use if it's not dominated by domInstFilter.
if (domInstFilter && !domInfo->dominates(domInstFilter, opInst))
@ -217,8 +216,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
/// uses besides this opInst; otherwise returns the list of affine_apply
/// operations created in output argument `sliceOps`.
void mlir::createAffineComputationSlice(
OperationInst *opInst,
SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps) {
Instruction *opInst, SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps) {
// Collect all operands that are results of affine apply ops.
SmallVector<Value *, 4> subOperands;
subOperands.reserve(opInst->getNumOperands());
@ -230,7 +228,7 @@ void mlir::createAffineComputationSlice(
}
// Gather sequence of AffineApplyOps reachable from 'subOperands'.
SmallVector<OperationInst *, 4> affineApplyOps;
SmallVector<Instruction *, 4> affineApplyOps;
getReachableAffineApplyOps(subOperands, affineApplyOps);
// Skip transforming if there are no affine maps to compose.
if (affineApplyOps.empty())
@ -341,8 +339,7 @@ bool mlir::constantFoldBounds(OpPointer<AffineForOp> forInst) {
}
void mlir::remapFunctionAttrs(
OperationInst &op,
const DenseMap<Attribute, FunctionAttr> &remappingTable) {
Instruction &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto attr : op.getAttrs()) {
// Do the remapping, if we got the same thing back, then it must contain
// functions that aren't getting remapped.

View File

@ -110,17 +110,13 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
// Only filter instructions that operate on a strict super-vector and have one
// return. This makes testing easier.
auto filter = [subVectorType](const Instruction &inst) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst) {
return false;
}
assert(subVectorType.getElementType() ==
Type::getF32(subVectorType.getContext()) &&
"Only f32 supported for now");
if (!matcher::operatesOnSuperVectors(*opInst, subVectorType)) {
if (!matcher::operatesOnSuperVectors(inst, subVectorType)) {
return false;
}
if (opInst->getNumResults() != 1) {
if (inst.getNumResults() != 1) {
return false;
}
return true;
@ -129,7 +125,7 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
SmallVector<NestedMatch, 8> matches;
pat.match(f, &matches);
for (auto m : matches) {
auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
auto *opInst = m.getMatchedInstruction();
// This is a unit test that only checks and prints shape ratio.
// As a consequence we write only Ops with a single return type for the
// purpose of this test. If we need to test more intricate behavior in the
@ -159,8 +155,7 @@ static NestedPattern patternTestSlicingOps() {
using matcher::Op;
// Match all OpInstructions with the kTestSlicingOpName name.
auto filter = [](const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
return opInst.getName().getStringRef() == kTestSlicingOpName;
return inst.getName().getStringRef() == kTestSlicingOpName;
};
return Op(filter);
}
@ -209,8 +204,7 @@ void VectorizerTestPass::testSlicing(Function *f) {
}
static bool customOpWithAffineMapAttribute(const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
return opInst.getName().getStringRef() ==
return inst.getName().getStringRef() ==
VectorizerTestPass::kTestAffineMapOpName;
}
@ -222,7 +216,7 @@ void VectorizerTestPass::testComposeMaps(Function *f) {
SmallVector<AffineMap, 4> maps;
maps.reserve(matches.size());
for (auto m : llvm::reverse(matches)) {
auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
auto *opInst = m.getMatchedInstruction();
auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
.cast<AffineMapAttr>()
.getValue();
@ -236,13 +230,11 @@ void VectorizerTestPass::testComposeMaps(Function *f) {
}
static bool affineApplyOp(const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
return opInst.isa<AffineApplyOp>();
return inst.isa<AffineApplyOp>();
}
static bool singleResultAffineApplyOpWithoutUses(const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
auto app = opInst.dyn_cast<AffineApplyOp>();
auto app = inst.dyn_cast<AffineApplyOp>();
return app && app->use_empty();
}
@ -259,8 +251,7 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) {
SmallVector<NestedMatch, 8> matches;
pattern.match(f, &matches);
for (auto m : matches) {
auto app =
cast<OperationInst>(m.getMatchedInstruction())->cast<AffineApplyOp>();
auto app = m.getMatchedInstruction()->cast<AffineApplyOp>();
FuncBuilder b(m.getMatchedInstruction());
SmallVector<Value *, 8> operands(app->getOperands());
makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands);

View File

@ -723,22 +723,22 @@ namespace {
struct VectorizationState {
/// Adds an entry of pre/post vectorization instructions in the state.
void registerReplacement(OperationInst *key, OperationInst *value);
void registerReplacement(Instruction *key, Instruction *value);
/// When the current vectorization pattern is successful, this erases the
/// instructions that were marked for erasure in the proper order and resets
/// the internal state for the next pattern.
void finishVectorizationPattern();
// In-order tracking of original OperationInst that have been vectorized.
// In-order tracking of original Instruction that have been vectorized.
// Erase in reverse order.
SmallVector<OperationInst *, 16> toErase;
// Set of OperationInst that have been vectorized (the values in the
SmallVector<Instruction *, 16> toErase;
// Set of Instruction that have been vectorized (the values in the
// vectorizationMap for hashed access). The vectorizedSet is used in
// particular to filter the instructions that have already been vectorized by
// this pattern, when iterating over nested loops in this pattern.
DenseSet<OperationInst *> vectorizedSet;
// Map of old scalar OperationInst to new vectorized OperationInst.
DenseMap<OperationInst *, OperationInst *> vectorizationMap;
DenseSet<Instruction *> vectorizedSet;
// Map of old scalar Instruction to new vectorized Instruction.
DenseMap<Instruction *, Instruction *> vectorizationMap;
// Map of old scalar Value to new vectorized Value.
DenseMap<const Value *, Value *> replacementMap;
// The strategy drives which loop to vectorize by which amount.
@ -747,17 +747,17 @@ struct VectorizationState {
// vectorizeOperations function. They consist of the subset of load operations
// that have been vectorized. They can be retrieved from `vectorizationMap`
// but it is convenient to keep track of them in a separate data structure.
DenseSet<OperationInst *> roots;
DenseSet<Instruction *> roots;
// Terminator instructions for the worklist in the vectorizeOperations
// function. They consist of the subset of store operations that have been
// vectorized. They can be retrieved from `vectorizationMap` but it is
// convenient to keep track of them in a separate data structure. Since they
// do not necessarily belong to use-def chains starting from loads (e.g
// storing a constant), we need to handle them in a post-pass.
DenseSet<OperationInst *> terminators;
DenseSet<Instruction *> terminators;
// Checks that the type of `inst` is StoreOp and adds it to the terminators
// set.
void registerTerminator(OperationInst *inst);
void registerTerminator(Instruction *inst);
private:
void registerReplacement(const Value *key, Value *value);
@ -765,8 +765,8 @@ private:
} // end namespace
void VectorizationState::registerReplacement(OperationInst *key,
OperationInst *value) {
void VectorizationState::registerReplacement(Instruction *key,
Instruction *value) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: ");
LLVM_DEBUG(key->print(dbgs()));
LLVM_DEBUG(dbgs() << " into ");
@ -785,7 +785,7 @@ void VectorizationState::registerReplacement(OperationInst *key,
}
}
void VectorizationState::registerTerminator(OperationInst *inst) {
void VectorizationState::registerTerminator(Instruction *inst) {
assert(inst->isa<StoreOp>() && "terminator must be a StoreOp");
assert(terminators.count(inst) == 0 &&
"terminator was already inserted previously");
@ -867,17 +867,16 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step,
if (!matcher::isLoadOrStore(inst)) {
return false;
}
auto *opInst = cast<OperationInst>(&inst);
return state->vectorizationMap.count(opInst) == 0 &&
state->vectorizedSet.count(opInst) == 0 &&
state->roots.count(opInst) == 0 &&
state->terminators.count(opInst) == 0;
return state->vectorizationMap.count(&inst) == 0 &&
state->vectorizedSet.count(&inst) == 0 &&
state->roots.count(&inst) == 0 &&
state->terminators.count(&inst) == 0;
};
auto loadAndStores = matcher::Op(notVectorizedThisPattern);
SmallVector<NestedMatch, 8> loadAndStoresMatches;
loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches);
for (auto ls : loadAndStoresMatches) {
auto *opInst = cast<OperationInst>(ls.getMatchedInstruction());
auto *opInst = ls.getMatchedInstruction();
auto load = opInst->dyn_cast<LoadOp>();
auto store = opInst->dyn_cast<StoreOp>();
LLVM_DEBUG(opInst->print(dbgs()));
@ -900,7 +899,7 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step,
static FilterFunctionType
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
return [fastestVaryingMemRefDimension](const Instruction &forInst) {
auto loop = cast<OperationInst>(forInst).cast<AffineForOp>();
auto loop = forInst.cast<AffineForOp>();
return isVectorizableLoopAlongFastestVaryingMemRefDim(
loop, fastestVaryingMemRefDimension);
};
@ -915,7 +914,7 @@ static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
/// recursively in DFS post-order.
static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
auto *loopInst = oneMatch.getMatchedInstruction();
auto loop = cast<OperationInst>(loopInst)->cast<AffineForOp>();
auto loop = loopInst->cast<AffineForOp>();
auto childrenMatches = oneMatch.getMatchedChildren();
// 1. DFS postorder recursion, if any of my children fails, I fail too.
@ -977,15 +976,14 @@ static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant,
Location loc = inst->getLoc();
auto vectorType = type.cast<VectorType>();
auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
auto *constantOpInst = cast<OperationInst>(constant.getInstruction());
auto *constantOpInst = constant.getInstruction();
OperationState state(
b.getContext(), loc, constantOpInst->getName().getStringRef(), {},
{vectorType},
{make_pair(Identifier::get("value", b.getContext()), attr)});
auto *splat = cast<OperationInst>(b.createOperation(state));
return splat->getResult(0);
return b.createOperation(state)->getResult(0);
}
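// A rough plain-C++ illustration (not the MLIR builder/SplatElementsAttr API)
// of the splat performed by vectorizeConstant above: the scalar constant is
// broadcast so that every lane of the new vector value holds the same number.
// The element type and width below are arbitrary examples.
#include <cstddef>
#include <vector>

std::vector<float> splatConstant(float scalar, std::size_t vectorWidth) {
  // Every lane carries the same scalar value, e.g. splatConstant(1.0f, 4)
  // yields {1.0, 1.0, 1.0, 1.0}.
  return std::vector<float>(vectorWidth, scalar);
}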
/// Returns a uniqu'ed VectorType.
@ -997,8 +995,7 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
if (!VectorType::isValidElementType(v->getType())) {
return Type();
}
auto *definingOpInst = cast<OperationInst>(v->getDefiningInst());
if (state.vectorizedSet.count(definingOpInst) > 0) {
if (state.vectorizedSet.count(v->getDefiningInst()) > 0) {
return v->getType().cast<VectorType>();
}
return VectorType::get(state.strategy->vectorSizes, v->getType());
@ -1029,9 +1026,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
VectorizationState *state) {
LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
LLVM_DEBUG(operand->print(dbgs()));
auto *definingInstruction = cast<OperationInst>(operand->getDefiningInst());
// 1. If this value has already been vectorized this round, we are done.
if (state->vectorizedSet.count(definingInstruction) > 0) {
if (state->vectorizedSet.count(operand->getDefiningInst()) > 0) {
LLVM_DEBUG(dbgs() << " -> already vector operand");
return operand;
}
@ -1062,7 +1058,7 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
return nullptr;
};
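// A condensed standalone sketch of the lookup order that vectorizeOperand
// follows above, with hypothetical names and std:: containers (not MLIR
// types): a value already produced by this vectorization round is used as is,
// a scalar with a registered replacement is substituted, and anything else
// makes the pattern fail by returning nullptr.
#include <unordered_map>
#include <unordered_set>

namespace sketch {
struct Value {};

Value *resolveOperand(
    Value *operand, const std::unordered_set<const Value *> &vectorizedThisRound,
    const std::unordered_map<const Value *, Value *> &replacementMap) {
  // 1. Already a vector value created by this pattern.
  if (vectorizedThisRound.count(operand))
    return operand;
  // 2. A scalar value whose vector replacement was registered earlier.
  auto it = replacementMap.find(operand);
  if (it != replacementMap.end())
    return it->second;
  // 3. Not vectorizable here; the caller treats nullptr as failure.
  return nullptr;
}
} // namespace sketch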
/// Encodes OperationInst-specific behavior for vectorization. In general we
/// Encodes Instruction-specific behavior for vectorization. In general we
/// assume that all operands of an op must be vectorized but this is not always
/// true. In the future, it would be nice to have a trait that describes how a
/// particular operation vectorizes. For now we implement the case distinction
@ -1071,9 +1067,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
/// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized.
/// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
/// do one-off logic here; ideally it would be TableGen'd.
static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
OperationInst *opInst,
VectorizationState *state) {
static Instruction *vectorizeOneInstruction(FuncBuilder *b, Instruction *opInst,
VectorizationState *state) {
// Sanity checks.
assert(!opInst->isa<LoadOp>() &&
"all loads must have already been fully vectorized independently");
@ -1094,7 +1089,7 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<VectorTransferWriteOp>(
opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
auto *res = cast<OperationInst>(transfer->getInstruction());
auto *res = transfer->getInstruction();
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
// "Terminators" (i.e. StoreOps) are erased on the spot.
opInst->erase();
@ -1119,8 +1114,8 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
// Create a clone of the op with the proper operands and return types.
// TODO(ntv): The following assumes there is always an op with a fixed
// name that works both in scalar mode and vector mode.
// TODO(ntv): Is it worth considering an OperationInst.clone operation
// which changes the type so we can promote an OperationInst with less
// TODO(ntv): Is it worth considering an Instruction.clone operation
// which changes the type so we can promote an Instruction with less
// boilerplate?
OperationState newOp(b->getContext(), opInst->getLoc(),
opInst->getName().getStringRef(), operands, types,
@ -1129,22 +1124,22 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
return b->createOperation(newOp);
}
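// A hedged sketch of the "clone with promoted result types" step built above
// via OperationState, with hypothetical plain-data types standing in for the
// real builder API: the cloned description keeps the original name and
// attributes but takes the vectorized operands and the widened result types.
#include <string>
#include <utility>
#include <vector>

namespace sketch {
struct OpDescription {
  std::string name;
  std::vector<std::string> operands;    // names of operand values
  std::vector<std::string> resultTypes; // e.g. "f32" -> "vector<256xf32>"
  std::vector<std::pair<std::string, std::string>> attributes;
};

OpDescription cloneWithVectorTypes(const OpDescription &scalarOp,
                                   std::vector<std::string> vectorOperands,
                                   std::vector<std::string> vectorResultTypes) {
  // Same name and attributes as the scalar op, new operands and result types.
  return OpDescription{scalarOp.name, std::move(vectorOperands),
                       std::move(vectorResultTypes), scalarOp.attributes};
}
} // namespace sketch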
/// Iterates over the OperationInst in the loop and rewrites them using their
/// Iterates over the Instructions in the loop and rewrites them using their
/// vectorized counterparts by:
/// 1. iteratively building a worklist of uses of the OperationInst vectorized
/// 1. iteratively building a worklist of uses of the Instructions vectorized
/// so far by this pattern;
/// 2. for each OperationInst in the worklist, create the vector form of this
/// 2. for each Instruction in the worklist, create the vector form of this
/// operation and replace all its uses by the vectorized form. For this step,
/// the worklist must be traversed in order;
/// 3. verify that all operands of the newly vectorized operation have been
/// vectorized by this pattern.
static bool vectorizeOperations(VectorizationState *state) {
// 1. create initial worklist with the uses of the roots.
SetVector<OperationInst *> worklist;
auto insertUsesOf = [&worklist, state](OperationInst *vectorized) {
SetVector<Instruction *> worklist;
auto insertUsesOf = [&worklist, state](Instruction *vectorized) {
for (auto *r : vectorized->getResults())
for (auto &u : r->getUses()) {
auto *inst = cast<OperationInst>(u.getOwner());
auto *inst = u.getOwner();
// Don't propagate to terminals; a separate pass is needed for those.
// TODO(ntv)[b/119759136]: use isa<> once Op is implemented.
if (state->terminators.count(inst) > 0) {
@ -1166,7 +1161,7 @@ static bool vectorizeOperations(VectorizationState *state) {
// 2. Create vectorized form of the instruction.
// Insert it just before inst, on success register inst as replaced.
FuncBuilder b(inst);
auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state);
auto *vectorizedInst = vectorizeOneInstruction(&b, inst, state);
if (!vectorizedInst) {
return true;
}
@ -1179,7 +1174,7 @@ static bool vectorizeOperations(VectorizationState *state) {
// 4. Augment the worklist with uses of the instruction we just vectorized.
// This preserves the proper order in the worklist.
apply(insertUsesOf, ArrayRef<OperationInst *>{inst});
apply(insertUsesOf, ArrayRef<Instruction *>{inst});
}
return false;
}
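// A self-contained sketch of the worklist scheme documented and implemented
// above, using a hypothetical Node type and std:: containers in place of
// llvm::SetVector (not MLIR code). The worklist is seeded with the users of
// the already-vectorized roots, processed in insertion order so operands are
// rewritten before their users, terminators are skipped for the later
// post-pass, and the users of each newly vectorized node are appended so the
// ordering invariant is preserved. Unlike the pass above, which returns true
// on failure, this sketch returns true on success.
#include <cstddef>
#include <unordered_set>
#include <vector>

namespace sketch {
struct Node {
  std::vector<Node *> users; // use-def successors of this node's results
};

template <typename VectorizeFn>
bool processWorklist(const std::vector<Node *> &roots,
                     const std::unordered_set<Node *> &terminators,
                     VectorizeFn vectorizeOne) {
  std::vector<Node *> worklist;
  std::unordered_set<Node *> enqueued; // emulates SetVector's uniqueness
  auto insertUsersOf = [&](Node *n) {
    for (Node *user : n->users)
      if (!terminators.count(user) && enqueued.insert(user).second)
        worklist.push_back(user);
  };
  // 1. Seed with the uses of the roots (the roots themselves are already
  //    vectorized).
  for (Node *root : roots)
    insertUsersOf(root);
  // 2./3. Process in insertion order; the list may grow while we iterate.
  for (std::size_t i = 0; i < worklist.size(); ++i) {
    Node *inst = worklist[i];
    if (!vectorizeOne(inst))
      return false; // any failure aborts the whole pattern
    // 4. Append the users of the node we just vectorized, preserving order.
    insertUsersOf(inst);
  }
  return true;
}
} // namespace sketch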
@ -1189,8 +1184,7 @@ static bool vectorizeOperations(VectorizationState *state) {
/// Each root may succeed independently but will otherwise clean up after itself if
/// anything below it fails.
static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
auto loop =
cast<OperationInst>(m.getMatchedInstruction())->cast<AffineForOp>();
auto loop = m.getMatchedInstruction()->cast<AffineForOp>();
VectorizationState state;
state.strategy = strategy;
@ -1207,8 +1201,7 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
}
auto *loopInst = loop->getInstruction();
FuncBuilder builder(loopInst);
auto clonedLoop =
cast<OperationInst>(builder.clone(*loopInst))->cast<AffineForOp>();
auto clonedLoop = builder.clone(*loopInst)->cast<AffineForOp>();
auto fail = doVectorize(m, &state);
/// Sets up error handling for this root loop. This is how the root match
@ -1248,12 +1241,12 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
}
// Finally, vectorize the terminators. If anything fails to vectorize, skip.
auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
auto vectorizeOrFail = [&fail, &state](Instruction *inst) {
if (fail) {
return;
}
FuncBuilder b(inst);
auto *res = vectorizeOneOperationInst(&b, inst, &state);
auto *res = vectorizeOneInstruction(&b, inst, &state);
if (res == nullptr) {
fail = true;
}