mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-22 08:32:23 +00:00
Remove remaining usages of OperationInst in lib/Transforms.
PiperOrigin-RevId: 232323671
This commit is contained in:
parent
44e040dd63
commit
b499277fb6
@ -39,10 +39,10 @@ using namespace mlir;
|
||||
|
||||
namespace {
|
||||
// TODO(riverriddle) Handle commutative operations.
|
||||
struct SimpleOperationInfo : public llvm::DenseMapInfo<OperationInst *> {
|
||||
static unsigned getHashValue(const OperationInst *op) {
|
||||
struct SimpleOperationInfo : public llvm::DenseMapInfo<Instruction *> {
|
||||
static unsigned getHashValue(const Instruction *op) {
|
||||
// Hash the operations based upon their:
|
||||
// - OperationInst Name
|
||||
// - Instruction Name
|
||||
// - Attributes
|
||||
// - Result Types
|
||||
// - Operands
|
||||
@ -51,7 +51,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<OperationInst *> {
|
||||
hash_combine_range(op->result_type_begin(), op->result_type_end()),
|
||||
hash_combine_range(op->operand_begin(), op->operand_end()));
|
||||
}
|
||||
static bool isEqual(const OperationInst *lhs, const OperationInst *rhs) {
|
||||
static bool isEqual(const Instruction *lhs, const Instruction *rhs) {
|
||||
if (lhs == rhs)
|
||||
return true;
|
||||
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
|
||||
@ -89,8 +89,8 @@ struct CSE : public FunctionPass {
|
||||
/// Shared implementation of operation elimination and scoped map definitions.
|
||||
using AllocatorTy = llvm::RecyclingAllocator<
|
||||
llvm::BumpPtrAllocator,
|
||||
llvm::ScopedHashTableVal<OperationInst *, OperationInst *>>;
|
||||
using ScopedMapTy = llvm::ScopedHashTable<OperationInst *, OperationInst *,
|
||||
llvm::ScopedHashTableVal<Instruction *, Instruction *>>;
|
||||
using ScopedMapTy = llvm::ScopedHashTable<Instruction *, Instruction *,
|
||||
SimpleOperationInfo, AllocatorTy>;
|
||||
|
||||
/// Represents a single entry in the depth first traversal of a CFG.
|
||||
@ -111,7 +111,7 @@ struct CSE : public FunctionPass {
|
||||
|
||||
/// Attempt to eliminate a redundant operation. Returns true if the operation
|
||||
/// was marked for removal, false otherwise.
|
||||
bool simplifyOperation(OperationInst *op);
|
||||
bool simplifyOperation(Instruction *op);
|
||||
|
||||
void simplifyBlock(Block *bb);
|
||||
|
||||
@ -122,14 +122,14 @@ private:
|
||||
ScopedMapTy knownValues;
|
||||
|
||||
/// Operations marked as dead and to be erased.
|
||||
std::vector<OperationInst *> opsToErase;
|
||||
std::vector<Instruction *> opsToErase;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
char CSE::passID = 0;
|
||||
|
||||
/// Attempt to eliminate a redundant operation.
|
||||
bool CSE::simplifyOperation(OperationInst *op) {
|
||||
bool CSE::simplifyOperation(Instruction *op) {
|
||||
// TODO(riverriddle) We currently only eliminate non side-effecting
|
||||
// operations.
|
||||
if (!op->hasNoSideEffect())
|
||||
@ -166,23 +166,16 @@ bool CSE::simplifyOperation(OperationInst *op) {
|
||||
|
||||
void CSE::simplifyBlock(Block *bb) {
|
||||
for (auto &i : *bb) {
|
||||
switch (i.getKind()) {
|
||||
case Instruction::Kind::OperationInst: {
|
||||
auto *opInst = cast<OperationInst>(&i);
|
||||
// If the operation is simplified, we don't process any held block lists.
|
||||
if (simplifyOperation(&i))
|
||||
continue;
|
||||
|
||||
// If the operation is simplified, we don't process any held block lists.
|
||||
if (simplifyOperation(opInst))
|
||||
continue;
|
||||
|
||||
// Simplify any held blocks.
|
||||
for (auto &blockList : opInst->getBlockLists()) {
|
||||
for (auto &b : blockList) {
|
||||
ScopedMapTy::ScopeTy scope(knownValues);
|
||||
simplifyBlock(&b);
|
||||
}
|
||||
// Simplify any held blocks.
|
||||
for (auto &blockList : i.getBlockLists()) {
|
||||
for (auto &b : blockList) {
|
||||
ScopedMapTy::ScopeTy scope(knownValues);
|
||||
simplifyBlock(&b);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ namespace {
|
||||
struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
|
||||
explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
|
||||
PassResult runOnFunction(Function *f) override;
|
||||
void visitInstruction(OperationInst *opInst);
|
||||
void visitInstruction(Instruction *opInst);
|
||||
|
||||
SmallVector<OpPointer<AffineApplyOp>, 8> affineApplyOps;
|
||||
|
||||
@ -64,14 +64,12 @@ FunctionPass *mlir::createComposeAffineMapsPass() {
|
||||
}
|
||||
|
||||
static bool affineApplyOp(const Instruction &inst) {
|
||||
const auto &opInst = cast<OperationInst>(inst);
|
||||
return opInst.isa<AffineApplyOp>();
|
||||
return inst.isa<AffineApplyOp>();
|
||||
}
|
||||
|
||||
void ComposeAffineMaps::visitInstruction(OperationInst *opInst) {
|
||||
if (auto afOp = opInst->dyn_cast<AffineApplyOp>()) {
|
||||
void ComposeAffineMaps::visitInstruction(Instruction *opInst) {
|
||||
if (auto afOp = opInst->dyn_cast<AffineApplyOp>())
|
||||
affineApplyOps.push_back(afOp);
|
||||
}
|
||||
}
|
||||
|
||||
PassResult ComposeAffineMaps::runOnFunction(Function *f) {
|
||||
|
@ -33,11 +33,11 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
|
||||
// All constants in the function post folding.
|
||||
SmallVector<Value *, 8> existingConstants;
|
||||
// Operations that were folded and that need to be erased.
|
||||
std::vector<OperationInst *> opInstsToErase;
|
||||
std::vector<Instruction *> opInstsToErase;
|
||||
|
||||
bool foldOperation(OperationInst *op,
|
||||
bool foldOperation(Instruction *op,
|
||||
SmallVectorImpl<Value *> &existingConstants);
|
||||
void visitInstruction(OperationInst *op);
|
||||
void visitInstruction(Instruction *op);
|
||||
PassResult runOnFunction(Function *f) override;
|
||||
|
||||
static char passID;
|
||||
@ -49,7 +49,7 @@ char ConstantFold::passID = 0;
|
||||
/// Attempt to fold the specified operation, updating the IR to match. If
|
||||
/// constants are found, we keep track of them in the existingConstants list.
|
||||
///
|
||||
void ConstantFold::visitInstruction(OperationInst *op) {
|
||||
void ConstantFold::visitInstruction(Instruction *op) {
|
||||
// If this operation is an AffineForOp, then fold the bounds.
|
||||
if (auto forOp = op->dyn_cast<AffineForOp>()) {
|
||||
constantFoldBounds(forOp);
|
||||
|
@ -50,7 +50,7 @@ private:
|
||||
// Utility that looks up a list of value in the value remapping table. Returns
|
||||
// an empty vector if one of the values is not mapped yet.
|
||||
SmallVector<Value *, 4>
|
||||
lookupValues(const llvm::iterator_range<OperationInst::const_operand_iterator>
|
||||
lookupValues(const llvm::iterator_range<Instruction::const_operand_iterator>
|
||||
&operands);
|
||||
|
||||
// Converts the given function to the dialect using hooks defined in
|
||||
@ -61,13 +61,13 @@ private:
|
||||
// from `valueRemapping` and the converted blocks from `blockRemapping`, and
|
||||
// passes them to `converter->rewriteTerminator` function defined in the
|
||||
// pattern, together with `builder`.
|
||||
bool convertOpWithSuccessors(DialectOpConversion *converter,
|
||||
OperationInst *op, FuncBuilder &builder);
|
||||
bool convertOpWithSuccessors(DialectOpConversion *converter, Instruction *op,
|
||||
FuncBuilder &builder);
|
||||
|
||||
// Converts an operation without successors. Extracts the converted operands
|
||||
// from `valueRemapping` and passes them to the `converter->rewrite` function
|
||||
// defined in the pattern, together with `builder`.
|
||||
bool convertOp(DialectOpConversion *converter, OperationInst *op,
|
||||
bool convertOp(DialectOpConversion *converter, Instruction *op,
|
||||
FuncBuilder &builder);
|
||||
|
||||
// Converts a block by traversing its instructions sequentially, looking for
|
||||
@ -104,8 +104,7 @@ private:
|
||||
} // end namespace mlir
|
||||
|
||||
SmallVector<Value *, 4> impl::FunctionConversion::lookupValues(
|
||||
const llvm::iterator_range<OperationInst::const_operand_iterator>
|
||||
&operands) {
|
||||
const llvm::iterator_range<Instruction::const_operand_iterator> &operands) {
|
||||
SmallVector<Value *, 4> remapped;
|
||||
remapped.reserve(llvm::size(operands));
|
||||
for (const Value *operand : operands) {
|
||||
@ -118,7 +117,7 @@ SmallVector<Value *, 4> impl::FunctionConversion::lookupValues(
|
||||
}
|
||||
|
||||
bool impl::FunctionConversion::convertOpWithSuccessors(
|
||||
DialectOpConversion *converter, OperationInst *op, FuncBuilder &builder) {
|
||||
DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) {
|
||||
SmallVector<Block *, 2> destinations;
|
||||
destinations.reserve(op->getNumSuccessors());
|
||||
SmallVector<Value *, 4> operands = lookupValues(op->getOperands());
|
||||
@ -149,7 +148,7 @@ bool impl::FunctionConversion::convertOpWithSuccessors(
|
||||
}
|
||||
|
||||
bool impl::FunctionConversion::convertOp(DialectOpConversion *converter,
|
||||
OperationInst *op,
|
||||
Instruction *op,
|
||||
FuncBuilder &builder) {
|
||||
auto operands = lookupValues(op->getOperands());
|
||||
assert((!operands.empty() || op->getNumOperands() == 0) &&
|
||||
@ -174,24 +173,22 @@ bool impl::FunctionConversion::convertBlock(
|
||||
|
||||
// Iterate over ops and convert them.
|
||||
for (Instruction &inst : *block) {
|
||||
auto op = dyn_cast<OperationInst>(&inst);
|
||||
if (!op) {
|
||||
inst.emitError("unsupported instruction (For/If)");
|
||||
if (inst.getNumBlockLists() != 0) {
|
||||
inst.emitError("unsupported region instruction");
|
||||
return true;
|
||||
}
|
||||
|
||||
// Find the first matching conversion and apply it.
|
||||
bool converted = false;
|
||||
for (auto *conversion : conversions) {
|
||||
if (!conversion->match(op))
|
||||
if (!conversion->match(&inst))
|
||||
continue;
|
||||
|
||||
if (op->isTerminator() && op->getNumSuccessors() > 0) {
|
||||
if (convertOpWithSuccessors(conversion, op, builder))
|
||||
return true;
|
||||
} else {
|
||||
if (convertOp(conversion, op, builder))
|
||||
if (inst.isTerminator() && inst.getNumSuccessors() > 0) {
|
||||
if (convertOpWithSuccessors(conversion, &inst, builder))
|
||||
return true;
|
||||
} else if (convertOp(conversion, &inst, builder)) {
|
||||
return true;
|
||||
}
|
||||
converted = true;
|
||||
break;
|
||||
|
@ -157,8 +157,7 @@ static void getMultiLevelStrides(const MemRefRegion ®ion,
|
||||
/// dynamic shaped memref's for now. `numParamLoopIVs` is the number of
|
||||
/// enclosing loop IVs of opInst (starting from the outermost) that the region
|
||||
/// is parametric on.
|
||||
static bool getFullMemRefAsRegion(OperationInst *opInst,
|
||||
unsigned numParamLoopIVs,
|
||||
static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs,
|
||||
MemRefRegion *region) {
|
||||
unsigned rank;
|
||||
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
|
||||
@ -563,7 +562,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
|
||||
fastBufferMap.clear();
|
||||
|
||||
// Walk this range of instructions to gather all memory regions.
|
||||
block->walk(begin, end, [&](OperationInst *opInst) {
|
||||
block->walk(begin, end, [&](Instruction *opInst) {
|
||||
// Gather regions to allocate to buffers in faster memory space.
|
||||
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
|
||||
if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace)
|
||||
|
@ -114,11 +114,11 @@ namespace {
|
||||
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
|
||||
public:
|
||||
SmallVector<OpPointer<AffineForOp>, 4> forOps;
|
||||
SmallVector<OperationInst *, 4> loadOpInsts;
|
||||
SmallVector<OperationInst *, 4> storeOpInsts;
|
||||
SmallVector<Instruction *, 4> loadOpInsts;
|
||||
SmallVector<Instruction *, 4> storeOpInsts;
|
||||
bool hasNonForRegion = false;
|
||||
|
||||
void visitInstruction(OperationInst *opInst) {
|
||||
void visitInstruction(Instruction *opInst) {
|
||||
if (opInst->isa<AffineForOp>())
|
||||
forOps.push_back(opInst->cast<AffineForOp>());
|
||||
else if (opInst->getNumBlockLists() != 0)
|
||||
@ -131,7 +131,7 @@ public:
|
||||
};
|
||||
|
||||
// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
|
||||
static bool isMemRefDereferencingOp(const OperationInst &op) {
|
||||
static bool isMemRefDereferencingOp(const Instruction &op) {
|
||||
if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
|
||||
op.isa<DmaWaitOp>())
|
||||
return true;
|
||||
@ -153,9 +153,9 @@ public:
|
||||
// The top-level statment which is (or contains) loads/stores.
|
||||
Instruction *inst;
|
||||
// List of load operations.
|
||||
SmallVector<OperationInst *, 4> loads;
|
||||
SmallVector<Instruction *, 4> loads;
|
||||
// List of store op insts.
|
||||
SmallVector<OperationInst *, 4> stores;
|
||||
SmallVector<Instruction *, 4> stores;
|
||||
Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}
|
||||
|
||||
// Returns the load op count for 'memref'.
|
||||
@ -258,16 +258,13 @@ public:
|
||||
for (auto *storeOpInst : node->stores) {
|
||||
auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
|
||||
auto *inst = memref->getDefiningInst();
|
||||
auto *opInst = dyn_cast_or_null<OperationInst>(inst);
|
||||
// Return false if 'memref' is a function argument.
|
||||
if (opInst == nullptr)
|
||||
// Return false if 'memref' is a block argument.
|
||||
if (!inst)
|
||||
return true;
|
||||
// Return false if any use of 'memref' escapes the function.
|
||||
for (auto &use : memref->getUses()) {
|
||||
auto *user = dyn_cast<OperationInst>(use.getOwner());
|
||||
if (!user || !isMemRefDereferencingOp(*user))
|
||||
for (auto &use : memref->getUses())
|
||||
if (!isMemRefDereferencingOp(*use.getOwner()))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -461,8 +458,8 @@ public:
|
||||
}
|
||||
|
||||
// Adds ops in 'loads' and 'stores' to node at 'id'.
|
||||
void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
|
||||
const SmallVectorImpl<OperationInst *> &stores) {
|
||||
void addToNode(unsigned id, const SmallVectorImpl<Instruction *> &loads,
|
||||
const SmallVectorImpl<Instruction *> &stores) {
|
||||
Node *node = getNode(id);
|
||||
for (auto *loadOpInst : loads)
|
||||
node->loads.push_back(loadOpInst);
|
||||
@ -509,7 +506,7 @@ bool MemRefDependenceGraph::init(Function *f) {
|
||||
|
||||
DenseMap<Instruction *, unsigned> forToNodeMap;
|
||||
for (auto &inst : f->front()) {
|
||||
if (auto forOp = cast<OperationInst>(&inst)->dyn_cast<AffineForOp>()) {
|
||||
if (auto forOp = inst.dyn_cast<AffineForOp>()) {
|
||||
// Create graph node 'id' to represent top-level 'forOp' and record
|
||||
// all loads and store accesses it contains.
|
||||
LoopNestStateCollector collector;
|
||||
@ -530,30 +527,28 @@ bool MemRefDependenceGraph::init(Function *f) {
|
||||
}
|
||||
forToNodeMap[&inst] = node.id;
|
||||
nodes.insert({node.id, node});
|
||||
} else if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
|
||||
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
|
||||
// Create graph node for top-level load op.
|
||||
Node node(nextNodeId++, &inst);
|
||||
node.loads.push_back(opInst);
|
||||
auto *memref = opInst->cast<LoadOp>()->getMemRef();
|
||||
memrefAccesses[memref].insert(node.id);
|
||||
nodes.insert({node.id, node});
|
||||
} else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
|
||||
// Create graph node for top-level store op.
|
||||
Node node(nextNodeId++, &inst);
|
||||
node.stores.push_back(opInst);
|
||||
auto *memref = opInst->cast<StoreOp>()->getMemRef();
|
||||
memrefAccesses[memref].insert(node.id);
|
||||
nodes.insert({node.id, node});
|
||||
} else if (opInst->getNumBlockLists() != 0) {
|
||||
// Return false if another region is found (not currently supported).
|
||||
return false;
|
||||
} else if (opInst->getNumResults() > 0 && !opInst->use_empty()) {
|
||||
// Create graph node for top-level producer of SSA values, which
|
||||
// could be used by loop nest nodes.
|
||||
Node node(nextNodeId++, &inst);
|
||||
nodes.insert({node.id, node});
|
||||
}
|
||||
} else if (auto loadOp = inst.dyn_cast<LoadOp>()) {
|
||||
// Create graph node for top-level load op.
|
||||
Node node(nextNodeId++, &inst);
|
||||
node.loads.push_back(&inst);
|
||||
auto *memref = inst.cast<LoadOp>()->getMemRef();
|
||||
memrefAccesses[memref].insert(node.id);
|
||||
nodes.insert({node.id, node});
|
||||
} else if (auto storeOp = inst.dyn_cast<StoreOp>()) {
|
||||
// Create graph node for top-level store op.
|
||||
Node node(nextNodeId++, &inst);
|
||||
node.stores.push_back(&inst);
|
||||
auto *memref = inst.cast<StoreOp>()->getMemRef();
|
||||
memrefAccesses[memref].insert(node.id);
|
||||
nodes.insert({node.id, node});
|
||||
} else if (inst.getNumBlockLists() != 0) {
|
||||
// Return false if another region is found (not currently supported).
|
||||
return false;
|
||||
} else if (inst.getNumResults() > 0 && !inst.use_empty()) {
|
||||
// Create graph node for top-level producer of SSA values, which
|
||||
// could be used by loop nest nodes.
|
||||
Node node(nextNodeId++, &inst);
|
||||
nodes.insert({node.id, node});
|
||||
}
|
||||
}
|
||||
|
||||
@ -563,12 +558,11 @@ bool MemRefDependenceGraph::init(Function *f) {
|
||||
const Node &node = idAndNode.second;
|
||||
if (!node.loads.empty() || !node.stores.empty())
|
||||
continue;
|
||||
auto *opInst = cast<OperationInst>(node.inst);
|
||||
auto *opInst = node.inst;
|
||||
for (auto *value : opInst->getResults()) {
|
||||
for (auto &use : value->getUses()) {
|
||||
auto *userOpInst = cast<OperationInst>(use.getOwner());
|
||||
SmallVector<OpPointer<AffineForOp>, 4> loops;
|
||||
getLoopIVs(*userOpInst, &loops);
|
||||
getLoopIVs(*use.getOwner(), &loops);
|
||||
if (loops.empty())
|
||||
continue;
|
||||
assert(forToNodeMap.count(loops[0]->getInstruction()) > 0);
|
||||
@ -619,7 +613,7 @@ public:
|
||||
|
||||
LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
|
||||
|
||||
void visitInstruction(OperationInst *opInst) {
|
||||
void visitInstruction(Instruction *opInst) {
|
||||
auto forOp = opInst->dyn_cast<AffineForOp>();
|
||||
if (!forOp)
|
||||
return;
|
||||
@ -627,8 +621,7 @@ public:
|
||||
auto *forInst = forOp->getInstruction();
|
||||
auto *parentInst = forOp->getInstruction()->getParentInst();
|
||||
if (parentInst != nullptr) {
|
||||
assert(cast<OperationInst>(parentInst)->isa<AffineForOp>() &&
|
||||
"Expected parent AffineForOp");
|
||||
assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
|
||||
// Add mapping to 'forOp' from its parent AffineForOp.
|
||||
stats->loopMap[parentInst].push_back(forOp);
|
||||
}
|
||||
@ -637,8 +630,7 @@ public:
|
||||
unsigned count = 0;
|
||||
stats->opCountMap[forInst] = 0;
|
||||
for (auto &inst : *forOp->getBody()) {
|
||||
if (!(cast<OperationInst>(inst).isa<AffineForOp>() ||
|
||||
cast<OperationInst>(inst).isa<AffineIfOp>()))
|
||||
if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
|
||||
++count;
|
||||
}
|
||||
stats->opCountMap[forInst] = count;
|
||||
@ -723,7 +715,7 @@ static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
|
||||
// was encountered).
|
||||
// TODO(andydavis) Make this work with non-unit step loops.
|
||||
static bool buildSliceTripCountMap(
|
||||
OperationInst *srcOpInst, ComputationSliceState *sliceState,
|
||||
Instruction *srcOpInst, ComputationSliceState *sliceState,
|
||||
llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) {
|
||||
SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
|
||||
getLoopIVs(*srcOpInst, &srcLoopIVs);
|
||||
@ -755,10 +747,10 @@ static bool buildSliceTripCountMap(
|
||||
// adds them to 'dstLoads'.
|
||||
static void
|
||||
moveLoadsAccessingMemrefTo(Value *memref,
|
||||
SmallVectorImpl<OperationInst *> *srcLoads,
|
||||
SmallVectorImpl<OperationInst *> *dstLoads) {
|
||||
SmallVectorImpl<Instruction *> *srcLoads,
|
||||
SmallVectorImpl<Instruction *> *dstLoads) {
|
||||
dstLoads->clear();
|
||||
SmallVector<OperationInst *, 4> srcLoadsToKeep;
|
||||
SmallVector<Instruction *, 4> srcLoadsToKeep;
|
||||
for (auto *load : *srcLoads) {
|
||||
if (load->cast<LoadOp>()->getMemRef() == memref)
|
||||
dstLoads->push_back(load);
|
||||
@ -769,7 +761,7 @@ moveLoadsAccessingMemrefTo(Value *memref,
|
||||
}
|
||||
|
||||
// Returns the innermost common loop depth for the set of operations in 'ops'.
|
||||
static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
|
||||
static unsigned getInnermostCommonLoopDepth(ArrayRef<Instruction *> ops) {
|
||||
unsigned numOps = ops.size();
|
||||
assert(numOps > 0);
|
||||
|
||||
@ -797,10 +789,10 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
|
||||
|
||||
// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
|
||||
// and 'storeOpInsts' are satisfied.
|
||||
static unsigned getMaxLoopDepth(ArrayRef<OperationInst *> loadOpInsts,
|
||||
ArrayRef<OperationInst *> storeOpInsts) {
|
||||
static unsigned getMaxLoopDepth(ArrayRef<Instruction *> loadOpInsts,
|
||||
ArrayRef<Instruction *> storeOpInsts) {
|
||||
// Merge loads and stores into the same array.
|
||||
SmallVector<OperationInst *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
|
||||
SmallVector<Instruction *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
|
||||
ops.append(storeOpInsts.begin(), storeOpInsts.end());
|
||||
|
||||
// Compute the innermost common loop depth for loads and stores.
|
||||
@ -913,7 +905,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
|
||||
// TODO(bondhugula): consider refactoring the common code from generateDma and
|
||||
// this one.
|
||||
static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
|
||||
OperationInst *srcStoreOpInst,
|
||||
Instruction *srcStoreOpInst,
|
||||
unsigned dstLoopDepth,
|
||||
Optional<unsigned> fastMemorySpace,
|
||||
unsigned localBufSizeThreshold) {
|
||||
@ -1061,9 +1053,9 @@ static uint64_t getSliceIterationCount(
|
||||
// *) Compares the total cost of the unfused loop nests to the min cost fused
|
||||
// loop nest computed in the previous step, and returns true if the latter
|
||||
// is lower.
|
||||
static bool isFusionProfitable(OperationInst *srcOpInst,
|
||||
ArrayRef<OperationInst *> dstLoadOpInsts,
|
||||
ArrayRef<OperationInst *> dstStoreOpInsts,
|
||||
static bool isFusionProfitable(Instruction *srcOpInst,
|
||||
ArrayRef<Instruction *> dstLoadOpInsts,
|
||||
ArrayRef<Instruction *> dstStoreOpInsts,
|
||||
ComputationSliceState *sliceState,
|
||||
unsigned *dstLoopDepth) {
|
||||
LLVM_DEBUG({
|
||||
@ -1174,7 +1166,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
|
||||
computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
|
||||
for (auto *loadOp : dstLoadOpInsts) {
|
||||
auto *parentInst = loadOp->getParentInst();
|
||||
if (parentInst && cast<OperationInst>(parentInst)->isa<AffineForOp>())
|
||||
if (parentInst && parentInst->isa<AffineForOp>())
|
||||
computeCostMap[parentInst] = -1;
|
||||
}
|
||||
}
|
||||
@ -1393,11 +1385,11 @@ public:
|
||||
// Get 'dstNode' into which to attempt fusion.
|
||||
auto *dstNode = mdg->getNode(dstId);
|
||||
// Skip if 'dstNode' is not a loop nest.
|
||||
if (!cast<OperationInst>(dstNode->inst)->isa<AffineForOp>())
|
||||
if (!dstNode->inst->isa<AffineForOp>())
|
||||
continue;
|
||||
|
||||
SmallVector<OperationInst *, 4> loads = dstNode->loads;
|
||||
SmallVector<OperationInst *, 4> dstLoadOpInsts;
|
||||
SmallVector<Instruction *, 4> loads = dstNode->loads;
|
||||
SmallVector<Instruction *, 4> dstLoadOpInsts;
|
||||
DenseSet<Value *> visitedMemrefs;
|
||||
while (!loads.empty()) {
|
||||
// Get memref of load on top of the stack.
|
||||
@ -1426,7 +1418,7 @@ public:
|
||||
// Get 'srcNode' from which to attempt fusion into 'dstNode'.
|
||||
auto *srcNode = mdg->getNode(srcId);
|
||||
// Skip if 'srcNode' is not a loop nest.
|
||||
if (!cast<OperationInst>(srcNode->inst)->isa<AffineForOp>())
|
||||
if (!srcNode->inst->isa<AffineForOp>())
|
||||
continue;
|
||||
// Skip if 'srcNode' has more than one store to any memref.
|
||||
// TODO(andydavis) Support fusing multi-output src loop nests.
|
||||
@ -1454,7 +1446,7 @@ public:
|
||||
// Get unique 'srcNode' store op.
|
||||
auto *srcStoreOpInst = srcNode->stores.front();
|
||||
// Gather 'dstNode' store ops to 'memref'.
|
||||
SmallVector<OperationInst *, 2> dstStoreOpInsts;
|
||||
SmallVector<Instruction *, 2> dstStoreOpInsts;
|
||||
for (auto *storeOpInst : dstNode->stores)
|
||||
if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
|
||||
dstStoreOpInsts.push_back(storeOpInst);
|
||||
@ -1472,8 +1464,7 @@ public:
|
||||
srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
|
||||
if (sliceLoopNest != nullptr) {
|
||||
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
|
||||
auto dstAffineForOp =
|
||||
cast<OperationInst>(dstNode->inst)->cast<AffineForOp>();
|
||||
auto dstAffineForOp = dstNode->inst->cast<AffineForOp>();
|
||||
if (insertPointInst != dstAffineForOp->getInstruction()) {
|
||||
dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
|
||||
}
|
||||
@ -1488,7 +1479,7 @@ public:
|
||||
promoteIfSingleIteration(forOp);
|
||||
}
|
||||
// Create private memref for 'memref' in 'dstAffineForOp'.
|
||||
SmallVector<OperationInst *, 4> storesForMemref;
|
||||
SmallVector<Instruction *, 4> storesForMemref;
|
||||
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
|
||||
if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
|
||||
storesForMemref.push_back(storeOpInst);
|
||||
@ -1541,9 +1532,8 @@ public:
|
||||
continue;
|
||||
// Use list expected to match the dep graph info.
|
||||
auto *inst = memref->getDefiningInst();
|
||||
auto *opInst = dyn_cast_or_null<OperationInst>(inst);
|
||||
if (opInst && opInst->isa<AllocOp>())
|
||||
opInst->erase();
|
||||
if (inst && inst->isa<AllocOp>())
|
||||
inst->erase();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -237,14 +237,13 @@ getTileableBands(Function *f,
|
||||
do {
|
||||
band.push_back(currInst);
|
||||
} while (currInst->getBody()->getInstructions().size() == 1 &&
|
||||
(currInst = cast<OperationInst>(currInst->getBody()->front())
|
||||
.dyn_cast<AffineForOp>()));
|
||||
(currInst = currInst->getBody()->front().dyn_cast<AffineForOp>()));
|
||||
bands->push_back(band);
|
||||
};
|
||||
|
||||
for (auto &block : *f)
|
||||
for (auto &inst : block)
|
||||
if (auto forOp = cast<OperationInst>(inst).dyn_cast<AffineForOp>())
|
||||
if (auto forOp = inst.dyn_cast<AffineForOp>())
|
||||
getMaximalPerfectLoopNest(forOp);
|
||||
}
|
||||
|
||||
|
@ -113,7 +113,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
|
||||
return hasInnerLoops;
|
||||
}
|
||||
|
||||
bool walkPostOrder(OperationInst *opInst) {
|
||||
bool walkPostOrder(Instruction *opInst) {
|
||||
bool hasInnerLoops = false;
|
||||
for (auto &blockList : opInst->getBlockLists())
|
||||
for (auto &block : blockList)
|
||||
@ -140,7 +140,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
|
||||
const unsigned minTripCount;
|
||||
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
|
||||
|
||||
void visitInstruction(OperationInst *opInst) {
|
||||
void visitInstruction(Instruction *opInst) {
|
||||
auto forOp = opInst->dyn_cast<AffineForOp>();
|
||||
if (!forOp)
|
||||
return;
|
||||
|
@ -100,8 +100,7 @@ PassResult LoopUnrollAndJam::runOnFunction(Function *f) {
|
||||
// any for Inst.
|
||||
auto &entryBlock = f->front();
|
||||
if (!entryBlock.empty())
|
||||
if (auto forOp =
|
||||
cast<OperationInst>(entryBlock.front()).dyn_cast<AffineForOp>())
|
||||
if (auto forOp = entryBlock.front().dyn_cast<AffineForOp>())
|
||||
runOnAffineForOp(forOp);
|
||||
|
||||
return success();
|
||||
@ -149,12 +148,12 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
|
||||
void walk(InstListType::iterator Start, InstListType::iterator End) {
|
||||
for (auto it = Start; it != End;) {
|
||||
auto subBlockStart = it;
|
||||
while (it != End && !cast<OperationInst>(it)->isa<AffineForOp>())
|
||||
while (it != End && !it->isa<AffineForOp>())
|
||||
++it;
|
||||
if (it != subBlockStart)
|
||||
subBlocks.push_back({subBlockStart, std::prev(it)});
|
||||
// Process all for insts that appear next.
|
||||
while (it != End && cast<OperationInst>(it)->isa<AffineForOp>())
|
||||
while (it != End && it->isa<AffineForOp>())
|
||||
walk(&*it++);
|
||||
}
|
||||
}
|
||||
@ -206,8 +205,7 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
|
||||
// Insert the cleanup loop right after 'forOp'.
|
||||
FuncBuilder builder(forInst->getBlock(),
|
||||
std::next(Block::iterator(forInst)));
|
||||
auto cleanupAffineForOp =
|
||||
cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>();
|
||||
auto cleanupAffineForOp = builder.clone(*forInst)->cast<AffineForOp>();
|
||||
cleanupAffineForOp->setLowerBoundMap(
|
||||
getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder));
|
||||
|
||||
|
@ -616,23 +616,21 @@ PassResult LowerAffinePass::runOnFunction(Function *function) {
|
||||
// Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
|
||||
// We do this as a prepass to avoid invalidating the walker with our rewrite.
|
||||
function->walk([&](Instruction *inst) {
|
||||
auto op = cast<OperationInst>(inst);
|
||||
if (op->isa<AffineApplyOp>() || op->isa<AffineForOp>() ||
|
||||
op->isa<AffineIfOp>())
|
||||
if (inst->isa<AffineApplyOp>() || inst->isa<AffineForOp>() ||
|
||||
inst->isa<AffineIfOp>())
|
||||
instsToRewrite.push_back(inst);
|
||||
});
|
||||
|
||||
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
|
||||
// so we know that we will rewrite them in the same order.
|
||||
for (auto *inst : instsToRewrite) {
|
||||
auto op = cast<OperationInst>(inst);
|
||||
if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
|
||||
if (auto ifOp = inst->dyn_cast<AffineIfOp>()) {
|
||||
if (lowerAffineIf(ifOp))
|
||||
return failure();
|
||||
} else if (auto forOp = op->dyn_cast<AffineForOp>()) {
|
||||
} else if (auto forOp = inst->dyn_cast<AffineForOp>()) {
|
||||
if (lowerAffineFor(forOp))
|
||||
return failure();
|
||||
} else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
|
||||
} else if (lowerAffineApply(inst->cast<AffineApplyOp>())) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
@ -401,13 +401,12 @@ public:
|
||||
explicit VectorTransferExpander(MLIRContext *context)
|
||||
: MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(OperationInst *op) const override {
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
if (m_Op<VectorTransferOpTy>().match(op))
|
||||
return matchSuccess();
|
||||
return matchFailure();
|
||||
}
|
||||
void rewriteOpInst(OperationInst *op,
|
||||
MLFuncGlobalLoweringState *funcWiseState,
|
||||
void rewriteOpInst(Instruction *op, MLFuncGlobalLoweringState *funcWiseState,
|
||||
std::unique_ptr<PatternState> opState,
|
||||
MLFuncLoweringRewriter *rewriter) const override {
|
||||
VectorTransferRewriter<VectorTransferOpTy>(
|
||||
|
@ -246,8 +246,8 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
|
||||
return res;
|
||||
}
|
||||
|
||||
static OperationInst *
|
||||
instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
|
||||
static Instruction *
|
||||
instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType,
|
||||
DenseMap<const Value *, Value *> *substitutionsMap);
|
||||
|
||||
/// Not all Values belong to a program slice scoped within the immediately
|
||||
@ -391,7 +391,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
|
||||
/// - constant splat is replaced by constant splat of `hwVectorType`.
|
||||
/// TODO(ntv): add more substitutions on a per-need basis.
|
||||
static SmallVector<NamedAttribute, 1>
|
||||
materializeAttributes(OperationInst *opInst, VectorType hwVectorType) {
|
||||
materializeAttributes(Instruction *opInst, VectorType hwVectorType) {
|
||||
SmallVector<NamedAttribute, 1> res;
|
||||
for (auto a : opInst->getAttrs()) {
|
||||
if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
|
||||
@ -411,8 +411,8 @@ materializeAttributes(OperationInst *opInst, VectorType hwVectorType) {
|
||||
/// substitutionsMap.
|
||||
///
|
||||
/// If the underlying substitution fails, this fails too and returns nullptr.
|
||||
static OperationInst *
|
||||
instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
|
||||
static Instruction *
|
||||
instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType,
|
||||
DenseMap<const Value *, Value *> *substitutionsMap) {
|
||||
assert(!opInst->isa<VectorTransferReadOp>() &&
|
||||
"Should call the function specialized for VectorTransferReadOp");
|
||||
@ -488,7 +488,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
|
||||
/// `hwVectorType` int the covering of the super-vector type. For a more
|
||||
/// detailed description of the problem, see the description of
|
||||
/// reindexAffineIndices.
|
||||
static OperationInst *
|
||||
static Instruction *
|
||||
instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType,
|
||||
ArrayRef<unsigned> hwVectorInstance,
|
||||
DenseMap<const Value *, Value *> *substitutionsMap) {
|
||||
@ -512,7 +512,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType,
|
||||
/// `hwVectorType` int the covering of th3e super-vector type. For a more
|
||||
/// detailed description of the problem, see the description of
|
||||
/// reindexAffineIndices.
|
||||
static OperationInst *
|
||||
static Instruction *
|
||||
instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
|
||||
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
|
||||
DenseMap<const Value *, Value *> *substitutionsMap) {
|
||||
@ -555,21 +555,20 @@ static bool instantiateMaterialization(Instruction *inst,
|
||||
|
||||
// Create a builder here for unroll-and-jam effects.
|
||||
FuncBuilder b(inst);
|
||||
auto *opInst = cast<OperationInst>(inst);
|
||||
// AffineApplyOp are ignored: instantiating the proper vector op will take
|
||||
// care of AffineApplyOps by composing them properly.
|
||||
if (opInst->isa<AffineApplyOp>()) {
|
||||
if (inst->isa<AffineApplyOp>()) {
|
||||
return false;
|
||||
}
|
||||
if (opInst->getNumBlockLists() != 0)
|
||||
if (inst->getNumBlockLists() != 0)
|
||||
return inst->emitError("NYI path Op with region");
|
||||
|
||||
if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) {
|
||||
if (auto write = inst->dyn_cast<VectorTransferWriteOp>()) {
|
||||
auto *clone = instantiate(&b, write, state->hwVectorType,
|
||||
state->hwVectorInstance, state->substitutionsMap);
|
||||
return clone == nullptr;
|
||||
}
|
||||
if (auto read = opInst->dyn_cast<VectorTransferReadOp>()) {
|
||||
if (auto read = inst->dyn_cast<VectorTransferReadOp>()) {
|
||||
auto *clone = instantiate(&b, read, state->hwVectorType,
|
||||
state->hwVectorInstance, state->substitutionsMap);
|
||||
if (!clone) {
|
||||
@ -582,19 +581,19 @@ static bool instantiateMaterialization(Instruction *inst,
|
||||
// The only op with 0 results reaching this point must, by construction, be
|
||||
// VectorTransferWriteOps and have been caught above. Ops with >= 2 results
|
||||
// are not yet supported. So just support 1 result.
|
||||
if (opInst->getNumResults() != 1) {
|
||||
if (inst->getNumResults() != 1) {
|
||||
return inst->emitError("NYI: ops with != 1 results");
|
||||
}
|
||||
if (opInst->getResult(0)->getType() != state->superVectorType) {
|
||||
if (inst->getResult(0)->getType() != state->superVectorType) {
|
||||
return inst->emitError("Op does not return a supervector.");
|
||||
}
|
||||
auto *clone =
|
||||
instantiate(&b, opInst, state->hwVectorType, state->substitutionsMap);
|
||||
instantiate(&b, inst, state->hwVectorType, state->substitutionsMap);
|
||||
if (!clone) {
|
||||
return true;
|
||||
}
|
||||
state->substitutionsMap->insert(
|
||||
std::make_pair(opInst->getResult(0), clone->getResult(0)));
|
||||
std::make_pair(inst->getResult(0), clone->getResult(0)));
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -645,7 +644,7 @@ static bool emitSlice(MaterializationState *state,
|
||||
}
|
||||
|
||||
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
|
||||
LLVM_DEBUG(cast<OperationInst>((*slice)[0])->getFunction()->print(dbgs()));
|
||||
LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs()));
|
||||
|
||||
// slice are topologically sorted, we can just erase them in reverse
|
||||
// order. Reverse iterator does not just work simply with an operator*
|
||||
@ -677,7 +676,7 @@ static bool emitSlice(MaterializationState *state,
|
||||
/// scope.
|
||||
/// TODO(ntv): please document return value.
|
||||
static bool materialize(Function *f,
|
||||
const SetVector<OperationInst *> &terminators,
|
||||
const SetVector<Instruction *> &terminators,
|
||||
MaterializationState *state) {
|
||||
DenseSet<Instruction *> seen;
|
||||
DominanceInfo domInfo(f);
|
||||
@ -757,18 +756,17 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
|
||||
// Capture terminators; i.e. vector_transfer_write ops involving a strict
|
||||
// super-vector of subVectorType.
|
||||
auto filter = [subVectorType](const Instruction &inst) {
|
||||
const auto &opInst = cast<OperationInst>(inst);
|
||||
if (!opInst.isa<VectorTransferWriteOp>()) {
|
||||
if (!inst.isa<VectorTransferWriteOp>()) {
|
||||
return false;
|
||||
}
|
||||
return matcher::operatesOnSuperVectors(opInst, subVectorType);
|
||||
return matcher::operatesOnSuperVectors(inst, subVectorType);
|
||||
};
|
||||
auto pat = Op(filter);
|
||||
SmallVector<NestedMatch, 8> matches;
|
||||
pat.match(f, &matches);
|
||||
SetVector<OperationInst *> terminators;
|
||||
SetVector<Instruction *> terminators;
|
||||
for (auto m : matches) {
|
||||
terminators.insert(cast<OperationInst>(m.getMatchedInstruction()));
|
||||
terminators.insert(m.getMatchedInstruction());
|
||||
}
|
||||
|
||||
auto fail = materialize(f, terminators, &state);
|
||||
|
@ -75,12 +75,12 @@ struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> {
|
||||
|
||||
PassResult runOnFunction(Function *f) override;
|
||||
|
||||
void visitInstruction(OperationInst *opInst);
|
||||
void visitInstruction(Instruction *opInst);
|
||||
|
||||
// A list of memref's that are potentially dead / could be eliminated.
|
||||
SmallPtrSet<Value *, 4> memrefsToErase;
|
||||
// Load op's whose results were replaced by those forwarded from stores.
|
||||
std::vector<OperationInst *> loadOpsToErase;
|
||||
std::vector<Instruction *> loadOpsToErase;
|
||||
|
||||
DominanceInfo *domInfo = nullptr;
|
||||
PostDominanceInfo *postDomInfo = nullptr;
|
||||
@ -100,22 +100,22 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() {
|
||||
|
||||
// This is a straightforward implementation not optimized for speed. Optimize
|
||||
// this in the future if needed.
|
||||
void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) {
|
||||
OperationInst *lastWriteStoreOp = nullptr;
|
||||
void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) {
|
||||
Instruction *lastWriteStoreOp = nullptr;
|
||||
|
||||
auto loadOp = opInst->dyn_cast<LoadOp>();
|
||||
if (!loadOp)
|
||||
return;
|
||||
|
||||
OperationInst *loadOpInst = opInst;
|
||||
Instruction *loadOpInst = opInst;
|
||||
|
||||
// First pass over the use list to get minimum number of surrounding
|
||||
// loops common between the load op and the store op, with min taken across
|
||||
// all store ops.
|
||||
SmallVector<OperationInst *, 8> storeOps;
|
||||
SmallVector<Instruction *, 8> storeOps;
|
||||
unsigned minSurroundingLoops = getNestingDepth(*loadOpInst);
|
||||
for (InstOperand &use : loadOp->getMemRef()->getUses()) {
|
||||
auto storeOp = cast<OperationInst>(use.getOwner())->dyn_cast<StoreOp>();
|
||||
auto storeOp = use.getOwner()->dyn_cast<StoreOp>();
|
||||
if (!storeOp)
|
||||
continue;
|
||||
auto *storeOpInst = storeOp->getInstruction();
|
||||
@ -131,11 +131,11 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) {
|
||||
// and loadOp.
|
||||
// The list of store op candidates for forwarding - need to satisfy the
|
||||
// conditions listed at the top.
|
||||
SmallVector<OperationInst *, 8> fwdingCandidates;
|
||||
SmallVector<Instruction *, 8> fwdingCandidates;
|
||||
// Store ops that have a dependence into the load (even if they aren't
|
||||
// forwarding candidates). Each forwarding candidate will be checked for a
|
||||
// post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
|
||||
SmallVector<OperationInst *, 8> depSrcStores;
|
||||
SmallVector<Instruction *, 8> depSrcStores;
|
||||
for (auto *storeOpInst : storeOps) {
|
||||
MemRefAccess srcAccess(storeOpInst);
|
||||
MemRefAccess destAccess(loadOpInst);
|
||||
@ -197,7 +197,7 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) {
|
||||
// that postdominates all 'depSrcStores' (if such a store exists) is the
|
||||
// unique store providing the value to the load, i.e., provably the last
|
||||
// writer to that memref loc.
|
||||
if (llvm::all_of(depSrcStores, [&](OperationInst *depStore) {
|
||||
if (llvm::all_of(depSrcStores, [&](Instruction *depStore) {
|
||||
return postDomInfo->postDominates(storeOpInst, depStore);
|
||||
})) {
|
||||
lastWriteStoreOp = storeOpInst;
|
||||
@ -246,24 +246,22 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) {
|
||||
// to do this as well, but we'll do it here since we collected these anyway.
|
||||
for (auto *memref : memrefsToErase) {
|
||||
// If the memref hasn't been alloc'ed in this function, skip.
|
||||
OperationInst *defInst = memref->getDefiningInst();
|
||||
Instruction *defInst = memref->getDefiningInst();
|
||||
if (!defInst || !defInst->isa<AllocOp>())
|
||||
// TODO(mlir-team): if the memref was returned by a 'call' instruction, we
|
||||
// could still erase it if the call had no side-effects.
|
||||
continue;
|
||||
if (std::any_of(memref->use_begin(), memref->use_end(),
|
||||
[&](InstOperand &use) {
|
||||
auto *ownerInst = cast<OperationInst>(use.getOwner());
|
||||
auto *ownerInst = use.getOwner();
|
||||
return (!ownerInst->isa<StoreOp>() &&
|
||||
!ownerInst->isa<DeallocOp>());
|
||||
}))
|
||||
continue;
|
||||
|
||||
// Erase all stores, the dealloc, and the alloc on the memref.
|
||||
for (auto it = memref->use_begin(), e = memref->use_end(); it != e;) {
|
||||
auto &use = *(it++);
|
||||
cast<OperationInst>(use.getOwner())->erase();
|
||||
}
|
||||
for (auto &use : llvm::make_early_inc_range(memref->getUses()))
|
||||
use.getOwner()->erase();
|
||||
defInst->erase();
|
||||
}
|
||||
|
||||
|
@ -61,7 +61,7 @@ FunctionPass *mlir::createPipelineDataTransferPass() {
|
||||
// Returns the position of the tag memref operand given a DMA instruction.
|
||||
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
|
||||
// added. TODO(b/117228571)
|
||||
static unsigned getTagMemRefPos(const OperationInst &dmaInst) {
|
||||
static unsigned getTagMemRefPos(const Instruction &dmaInst) {
|
||||
assert(dmaInst.isa<DmaStartOp>() || dmaInst.isa<DmaWaitOp>());
|
||||
if (dmaInst.isa<DmaStartOp>()) {
|
||||
// Second to last operand.
|
||||
@ -142,7 +142,7 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) {
|
||||
// deleted and replaced by a prologue, a new steady-state loop and an
|
||||
// epilogue).
|
||||
forOps.clear();
|
||||
f->walkPostOrder([&](OperationInst *opInst) {
|
||||
f->walkPostOrder([&](Instruction *opInst) {
|
||||
if (auto forOp = opInst->dyn_cast<AffineForOp>())
|
||||
forOps.push_back(forOp);
|
||||
});
|
||||
@ -180,33 +180,26 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
|
||||
// Identify matching DMA start/finish instructions to overlap computation with.
|
||||
static void findMatchingStartFinishInsts(
|
||||
OpPointer<AffineForOp> forOp,
|
||||
SmallVectorImpl<std::pair<OperationInst *, OperationInst *>>
|
||||
&startWaitPairs) {
|
||||
SmallVectorImpl<std::pair<Instruction *, Instruction *>> &startWaitPairs) {
|
||||
|
||||
// Collect outgoing DMA instructions - needed to check for dependences below.
|
||||
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
|
||||
for (auto &inst : *forOp->getBody()) {
|
||||
auto *opInst = dyn_cast<OperationInst>(&inst);
|
||||
if (!opInst)
|
||||
continue;
|
||||
OpPointer<DmaStartOp> dmaStartOp;
|
||||
if ((dmaStartOp = opInst->dyn_cast<DmaStartOp>()) &&
|
||||
if ((dmaStartOp = inst.dyn_cast<DmaStartOp>()) &&
|
||||
dmaStartOp->isSrcMemorySpaceFaster())
|
||||
outgoingDmaOps.push_back(dmaStartOp);
|
||||
}
|
||||
|
||||
SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts;
|
||||
SmallVector<Instruction *, 4> dmaStartInsts, dmaFinishInsts;
|
||||
for (auto &inst : *forOp->getBody()) {
|
||||
auto *opInst = dyn_cast<OperationInst>(&inst);
|
||||
if (!opInst)
|
||||
continue;
|
||||
// Collect DMA finish instructions.
|
||||
if (opInst->isa<DmaWaitOp>()) {
|
||||
dmaFinishInsts.push_back(opInst);
|
||||
if (inst.isa<DmaWaitOp>()) {
|
||||
dmaFinishInsts.push_back(&inst);
|
||||
continue;
|
||||
}
|
||||
OpPointer<DmaStartOp> dmaStartOp;
|
||||
if (!(dmaStartOp = opInst->dyn_cast<DmaStartOp>()))
|
||||
if (!(dmaStartOp = inst.dyn_cast<DmaStartOp>()))
|
||||
continue;
|
||||
// Only DMAs incoming into higher memory spaces are pipelined for now.
|
||||
// TODO(bondhugula): handle outgoing DMA pipelining.
|
||||
@ -236,7 +229,7 @@ static void findMatchingStartFinishInsts(
|
||||
}
|
||||
}
|
||||
if (!escapingUses)
|
||||
dmaStartInsts.push_back(opInst);
|
||||
dmaStartInsts.push_back(&inst);
|
||||
}
|
||||
|
||||
// For each start instruction, we look for a matching finish instruction.
|
||||
@ -262,7 +255,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs;
|
||||
SmallVector<std::pair<Instruction *, Instruction *>, 4> startWaitPairs;
|
||||
findMatchingStartFinishInsts(forOp, startWaitPairs);
|
||||
|
||||
if (startWaitPairs.empty()) {
|
||||
@ -335,7 +328,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
|
||||
} else {
|
||||
// If a slice wasn't created, the reachable affine_apply op's from its
|
||||
// operands are the ones that go with it.
|
||||
SmallVector<OperationInst *, 4> affineApplyInsts;
|
||||
SmallVector<Instruction *, 4> affineApplyInsts;
|
||||
SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
|
||||
getReachableAffineApplyOps(operands, affineApplyInsts);
|
||||
for (const auto *inst : affineApplyInsts) {
|
||||
@ -356,13 +349,13 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
|
||||
for (auto &inst : *forOp->getBody()) {
|
||||
assert(instShiftMap.find(&inst) != instShiftMap.end());
|
||||
shifts[s++] = instShiftMap[&inst];
|
||||
LLVM_DEBUG(
|
||||
// Tagging instructions with shifts for debugging purposes.
|
||||
if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
|
||||
FuncBuilder b(opInst);
|
||||
opInst->setAttr(b.getIdentifier("shift"),
|
||||
b.getI64IntegerAttr(shifts[s - 1]));
|
||||
});
|
||||
|
||||
// Tagging instructions with shifts for debugging purposes.
|
||||
LLVM_DEBUG({
|
||||
FuncBuilder b(&inst);
|
||||
inst.setAttr(b.getIdentifier("shift"),
|
||||
b.getI64IntegerAttr(shifts[s - 1]));
|
||||
});
|
||||
}
|
||||
|
||||
if (!isInstwiseShiftValid(forOp, shifts)) {
|
||||
|
@ -64,7 +64,7 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
|
||||
}
|
||||
|
||||
PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
|
||||
f->walk([&](OperationInst *opInst) {
|
||||
f->walk([&](Instruction *opInst) {
|
||||
for (auto attr : opInst->getAttrs()) {
|
||||
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
|
||||
MutableAffineMap mMap(mapAttr.getValue());
|
||||
|
@ -38,13 +38,13 @@ public:
|
||||
worklist.reserve(64);
|
||||
|
||||
// Add all operations to the worklist.
|
||||
fn->walk([&](OperationInst *inst) { addToWorklist(inst); });
|
||||
fn->walk([&](Instruction *inst) { addToWorklist(inst); });
|
||||
}
|
||||
|
||||
/// Perform the rewrites.
|
||||
void simplifyFunction();
|
||||
|
||||
void addToWorklist(OperationInst *op) {
|
||||
void addToWorklist(Instruction *op) {
|
||||
// Check to see if the worklist already contains this op.
|
||||
if (worklistMap.count(op))
|
||||
return;
|
||||
@ -53,7 +53,7 @@ public:
|
||||
worklist.push_back(op);
|
||||
}
|
||||
|
||||
OperationInst *popFromWorklist() {
|
||||
Instruction *popFromWorklist() {
|
||||
auto *op = worklist.back();
|
||||
worklist.pop_back();
|
||||
|
||||
@ -65,7 +65,7 @@ public:
|
||||
|
||||
/// If the specified operation is in the worklist, remove it. If not, this is
|
||||
/// a no-op.
|
||||
void removeFromWorklist(OperationInst *op) {
|
||||
void removeFromWorklist(Instruction *op) {
|
||||
auto it = worklistMap.find(op);
|
||||
if (it != worklistMap.end()) {
|
||||
assert(worklist[it->second] == op && "malformed worklist data structure");
|
||||
@ -77,7 +77,7 @@ public:
|
||||
protected:
|
||||
// Implement the hook for creating operations, and make sure that newly
|
||||
// created ops are added to the worklist for processing.
|
||||
OperationInst *createOperation(const OperationState &state) override {
|
||||
Instruction *createOperation(const OperationState &state) override {
|
||||
auto *result = builder.createOperation(state);
|
||||
addToWorklist(result);
|
||||
return result;
|
||||
@ -85,20 +85,18 @@ protected:
|
||||
|
||||
// If an operation is about to be removed, make sure it is not in our
|
||||
// worklist anymore because we'd get dangling references to it.
|
||||
void notifyOperationRemoved(OperationInst *op) override {
|
||||
void notifyOperationRemoved(Instruction *op) override {
|
||||
removeFromWorklist(op);
|
||||
}
|
||||
|
||||
// When the root of a pattern is about to be replaced, it can trigger
|
||||
// simplifications to its users - make sure to add them to the worklist
|
||||
// before the root is changed.
|
||||
void notifyRootReplaced(OperationInst *op) override {
|
||||
void notifyRootReplaced(Instruction *op) override {
|
||||
for (auto *result : op->getResults())
|
||||
// TODO: Add a result->getUsers() iterator.
|
||||
for (auto &user : result->getUses()) {
|
||||
if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
|
||||
addToWorklist(op);
|
||||
}
|
||||
for (auto &user : result->getUses())
|
||||
addToWorklist(user.getOwner());
|
||||
|
||||
// TODO: Walk the operand list dropping them as we go. If any of them
|
||||
// drop to zero uses, then add them to the worklist to allow them to be
|
||||
@ -116,13 +114,13 @@ private:
|
||||
/// need to be revisited, plus their index in the worklist. This allows us to
|
||||
/// efficiently remove operations from the worklist when they are erased from
|
||||
/// the function, even if they aren't the root of a pattern.
|
||||
std::vector<OperationInst *> worklist;
|
||||
DenseMap<OperationInst *, unsigned> worklistMap;
|
||||
std::vector<Instruction *> worklist;
|
||||
DenseMap<Instruction *, unsigned> worklistMap;
|
||||
|
||||
/// As part of canonicalization, we move constants to the top of the entry
|
||||
/// block of the current function and de-duplicate them. This keeps track of
|
||||
/// constants we have done this for.
|
||||
DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants;
|
||||
DenseMap<std::pair<Attribute, Type>, Instruction *> uniquedConstants;
|
||||
};
|
||||
}; // end anonymous namespace
|
||||
|
||||
@ -229,10 +227,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
|
||||
// revisit them.
|
||||
//
|
||||
// TODO: Add a result->getUsers() iterator.
|
||||
for (auto &operand : op->getResult(i)->getUses()) {
|
||||
if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
|
||||
addToWorklist(op);
|
||||
}
|
||||
for (auto &operand : op->getResult(i)->getUses())
|
||||
addToWorklist(operand.getOwner());
|
||||
|
||||
res->replaceAllUsesWith(cstValue);
|
||||
}
|
||||
@ -267,10 +263,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
|
||||
if (res->use_empty()) // ignore dead uses.
|
||||
continue;
|
||||
|
||||
for (auto &operand : op->getResult(i)->getUses()) {
|
||||
if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
|
||||
addToWorklist(op);
|
||||
}
|
||||
for (auto &operand : op->getResult(i)->getUses())
|
||||
addToWorklist(operand.getOwner());
|
||||
res->replaceAllUsesWith(resultValues[i]);
|
||||
}
|
||||
}
|
||||
|
@ -101,7 +101,7 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
|
||||
|
||||
// Replaces all IV uses to its single iteration value.
|
||||
auto *iv = forOp->getInductionVar();
|
||||
OperationInst *forInst = forOp->getInstruction();
|
||||
Instruction *forInst = forOp->getInstruction();
|
||||
if (!iv->use_empty()) {
|
||||
if (forOp->hasConstantLowerBound()) {
|
||||
auto *mlFunc = forInst->getFunction();
|
||||
@ -135,7 +135,7 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
|
||||
/// their body into the containing Block.
|
||||
void mlir::promoteSingleIterationLoops(Function *f) {
|
||||
// Gathers all innermost loops through a post order pruned walk.
|
||||
f->walkPostOrder([](OperationInst *inst) {
|
||||
f->walkPostOrder([](Instruction *inst) {
|
||||
if (auto forOp = inst->dyn_cast<AffineForOp>())
|
||||
promoteIfSingleIteration(forOp);
|
||||
});
|
||||
@ -394,11 +394,10 @@ bool mlir::loopUnrollByFactor(OpPointer<AffineForOp> forOp,
|
||||
return false;
|
||||
|
||||
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
|
||||
OperationInst *forInst = forOp->getInstruction();
|
||||
Instruction *forInst = forOp->getInstruction();
|
||||
if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
|
||||
FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst));
|
||||
auto cleanupForInst =
|
||||
cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>();
|
||||
auto cleanupForInst = builder.clone(*forInst)->cast<AffineForOp>();
|
||||
auto clLbMap = getCleanupLoopLowerBound(forOp, unrollFactor, &builder);
|
||||
assert(clLbMap &&
|
||||
"cleanup loop lower bound map for single result bound maps can "
|
||||
|
@ -37,7 +37,7 @@ using namespace mlir;
|
||||
/// Return true if this operation dereferences one or more memref's.
|
||||
// Temporary utility: will be replaced when this is modeled through
|
||||
// side-effects/op traits. TODO(b/117228571)
|
||||
static bool isMemRefDereferencingOp(const OperationInst &op) {
|
||||
static bool isMemRefDereferencingOp(const Instruction &op) {
|
||||
if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
|
||||
op.isa<DmaWaitOp>())
|
||||
return true;
|
||||
@ -76,12 +76,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
|
||||
std::make_unique<PostDominanceInfo>(postDomInstFilter->getFunction());
|
||||
|
||||
// The ops where memref replacement succeeds are replaced with new ones.
|
||||
SmallVector<OperationInst *, 8> opsToErase;
|
||||
SmallVector<Instruction *, 8> opsToErase;
|
||||
|
||||
// Walk all uses of old memref. Operation using the memref gets replaced.
|
||||
for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
|
||||
InstOperand &use = *(it++);
|
||||
auto *opInst = cast<OperationInst>(use.getOwner());
|
||||
for (auto &use : llvm::make_early_inc_range(oldMemRef->getUses())) {
|
||||
auto *opInst = use.getOwner();
|
||||
|
||||
// Skip this use if it's not dominated by domInstFilter.
|
||||
if (domInstFilter && !domInfo->dominates(domInstFilter, opInst))
|
||||
@ -217,8 +216,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
/// uses besides this opInst; otherwise returns the list of affine_apply
/// operations created in output argument `sliceOps`.
void mlir::createAffineComputationSlice(
OperationInst *opInst,
SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps) {
Instruction *opInst, SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps) {
// Collect all operands that are results of affine apply ops.
SmallVector<Value *, 4> subOperands;
subOperands.reserve(opInst->getNumOperands());
@ -230,7 +228,7 @@ void mlir::createAffineComputationSlice(
}

// Gather sequence of AffineApplyOps reachable from 'subOperands'.
SmallVector<OperationInst *, 4> affineApplyOps;
SmallVector<Instruction *, 4> affineApplyOps;
getReachableAffineApplyOps(subOperands, affineApplyOps);
// Skip transforming if there are no affine maps to compose.
if (affineApplyOps.empty())
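For context, the slice is built by first chasing the affine_apply producers of the instruction's operands; the container holding those producers is what changes type here. A condensed sketch of that step, using only the calls visible in this hunk (subOperands is the operand subset collected just above):

// Sketch: collect the chain of affine_apply ops feeding `subOperands`.
SmallVector<Instruction *, 4> affineApplyOps;
getReachableAffineApplyOps(subOperands, affineApplyOps);
if (affineApplyOps.empty())
  return; // no affine maps to compose into a slice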
@ -341,8 +339,7 @@ bool mlir::constantFoldBounds(OpPointer<AffineForOp> forInst) {
}

void mlir::remapFunctionAttrs(
OperationInst &op,
const DenseMap<Attribute, FunctionAttr> &remappingTable) {
Instruction &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto attr : op.getAttrs()) {
// Do the remapping, if we got the same thing back, then it must contain
// functions that aren't getting remapped.

@ -110,17 +110,13 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
// Only filter instructions that operate on a strict super-vector and have one
// return. This makes testing easier.
auto filter = [subVectorType](const Instruction &inst) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst) {
return false;
}
assert(subVectorType.getElementType() ==
Type::getF32(subVectorType.getContext()) &&
"Only f32 supported for now");
if (!matcher::operatesOnSuperVectors(*opInst, subVectorType)) {
if (!matcher::operatesOnSuperVectors(inst, subVectorType)) {
return false;
}
if (opInst->getNumResults() != 1) {
if (inst.getNumResults() != 1) {
return false;
}
return true;
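With the dyn_cast<OperationInst> guard gone, the filter queries the Instruction directly. A condensed sketch of the simplified filter with the f32 assertion elided, assuming subVectorType is captured as in the pass:

// Sketch: keep only single-result instructions that operate on a strict
// super-vector of the requested shape.
auto filter = [subVectorType](const Instruction &inst) {
  if (!matcher::operatesOnSuperVectors(inst, subVectorType))
    return false;
  return inst.getNumResults() == 1;
};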
@ -129,7 +125,7 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
SmallVector<NestedMatch, 8> matches;
pat.match(f, &matches);
for (auto m : matches) {
auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
auto *opInst = m.getMatchedInstruction();
// This is a unit test that only checks and prints shape ratio.
// As a consequence we write only Ops with a single return type for the
// purpose of this test. If we need to test more intricate behavior in the
@ -159,8 +155,7 @@ static NestedPattern patternTestSlicingOps() {
using matcher::Op;
// Match all OpInstructions with the kTestSlicingOpName name.
auto filter = [](const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
return opInst.getName().getStringRef() == kTestSlicingOpName;
return inst.getName().getStringRef() == kTestSlicingOpName;
};
return Op(filter);
}
@ -209,8 +204,7 @@ void VectorizerTestPass::testSlicing(Function *f) {
}

static bool customOpWithAffineMapAttribute(const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
return opInst.getName().getStringRef() ==
return inst.getName().getStringRef() ==
VectorizerTestPass::kTestAffineMapOpName;
}

@ -222,7 +216,7 @@ void VectorizerTestPass::testComposeMaps(Function *f) {
SmallVector<AffineMap, 4> maps;
maps.reserve(matches.size());
for (auto m : llvm::reverse(matches)) {
auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
auto *opInst = m.getMatchedInstruction();
auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
.cast<AffineMapAttr>()
.getValue();
@ -236,13 +230,11 @@ void VectorizerTestPass::testComposeMaps(Function *f) {
}

static bool affineApplyOp(const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
return opInst.isa<AffineApplyOp>();
return inst.isa<AffineApplyOp>();
}

static bool singleResultAffineApplyOpWithoutUses(const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
auto app = opInst.dyn_cast<AffineApplyOp>();
auto app = inst.dyn_cast<AffineApplyOp>();
return app && app->use_empty();
}

@ -259,8 +251,7 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) {
SmallVector<NestedMatch, 8> matches;
pattern.match(f, &matches);
for (auto m : matches) {
auto app =
cast<OperationInst>(m.getMatchedInstruction())->cast<AffineApplyOp>();
auto app = m.getMatchedInstruction()->cast<AffineApplyOp>();
FuncBuilder b(m.getMatchedInstruction());
SmallVector<Value *, 8> operands(app->getOperands());
makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands);
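The same pattern repeats throughout the test pass: the matched Instruction is cast straight to the concrete op and rewritten in place. A condensed sketch of the normalize-maps rewrite loop in its post-rename form, restating the hunk above (no new API is introduced):

// Sketch: replace each matched affine_apply with a single composed apply.
for (auto m : matches) {
  auto app = m.getMatchedInstruction()->cast<AffineApplyOp>();
  FuncBuilder b(m.getMatchedInstruction());
  SmallVector<Value *, 8> operands(app->getOperands());
  makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands);
}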
@ -723,22 +723,22 @@ namespace {

struct VectorizationState {
/// Adds an entry of pre/post vectorization instructions in the state.
void registerReplacement(OperationInst *key, OperationInst *value);
void registerReplacement(Instruction *key, Instruction *value);
/// When the current vectorization pattern is successful, this erases the
/// instructions that were marked for erasure in the proper order and resets
/// the internal state for the next pattern.
void finishVectorizationPattern();

// In-order tracking of original OperationInst that have been vectorized.
// In-order tracking of original Instruction that have been vectorized.
// Erase in reverse order.
SmallVector<OperationInst *, 16> toErase;
// Set of OperationInst that have been vectorized (the values in the
SmallVector<Instruction *, 16> toErase;
// Set of Instruction that have been vectorized (the values in the
// vectorizationMap for hashed access). The vectorizedSet is used in
// particular to filter the instructions that have already been vectorized by
// this pattern, when iterating over nested loops in this pattern.
DenseSet<OperationInst *> vectorizedSet;
// Map of old scalar OperationInst to new vectorized OperationInst.
DenseMap<OperationInst *, OperationInst *> vectorizationMap;
DenseSet<Instruction *> vectorizedSet;
// Map of old scalar Instruction to new vectorized Instruction.
DenseMap<Instruction *, Instruction *> vectorizationMap;
// Map of old scalar Value to new vectorized Value.
DenseMap<const Value *, Value *> replacementMap;
// The strategy drives which loop to vectorize by which amount.
@ -747,17 +747,17 @@ struct VectorizationState {
// vectorizeOperations function. They consist of the subset of load operations
// that have been vectorized. They can be retrieved from `vectorizationMap`
// but it is convenient to keep track of them in a separate data structure.
DenseSet<OperationInst *> roots;
DenseSet<Instruction *> roots;
// Terminator instructions for the worklist in the vectorizeOperations
// function. They consist of the subset of store operations that have been
// vectorized. They can be retrieved from `vectorizationMap` but it is
// convenient to keep track of them in a separate data structure. Since they
// do not necessarily belong to use-def chains starting from loads (e.g
// storing a constant), we need to handle them in a post-pass.
DenseSet<OperationInst *> terminators;
DenseSet<Instruction *> terminators;
// Checks that the type of `inst` is StoreOp and adds it to the terminators
// set.
void registerTerminator(OperationInst *inst);
void registerTerminator(Instruction *inst);

private:
void registerReplacement(const Value *key, Value *value);
@ -765,8 +765,8 @@ private:

} // end namespace

void VectorizationState::registerReplacement(OperationInst *key,
OperationInst *value) {
void VectorizationState::registerReplacement(Instruction *key,
Instruction *value) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: ");
LLVM_DEBUG(key->print(dbgs()));
LLVM_DEBUG(dbgs() << " into ");
@ -785,7 +785,7 @@ void VectorizationState::registerReplacement(OperationInst *key,
}
}

void VectorizationState::registerTerminator(OperationInst *inst) {
void VectorizationState::registerTerminator(Instruction *inst) {
assert(inst->isa<StoreOp>() && "terminator must be a StoreOp");
assert(terminators.count(inst) == 0 &&
"terminator was already inserted previously");
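All of the pass's bookkeeping containers are now keyed on Instruction *. The bodies of registerReplacement and registerTerminator are only partially visible in these hunks, so the following is a hedged illustration of the documented semantics (old scalar op maps to new vector op, the new op joins vectorizedSet, the old op is queued for erasure), not the actual implementation:

// Sketch: what registering a replacement amounts to, per the comments above.
state.vectorizationMap[scalarInst] = vectorInst; // old scalar -> new vector op
state.vectorizedSet.insert(vectorInst);          // values of vectorizationMap
state.toErase.push_back(scalarInst);             // erased in reverse order later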
@ -867,17 +867,16 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step,
if (!matcher::isLoadOrStore(inst)) {
return false;
}
auto *opInst = cast<OperationInst>(&inst);
return state->vectorizationMap.count(opInst) == 0 &&
state->vectorizedSet.count(opInst) == 0 &&
state->roots.count(opInst) == 0 &&
state->terminators.count(opInst) == 0;
return state->vectorizationMap.count(&inst) == 0 &&
state->vectorizedSet.count(&inst) == 0 &&
state->roots.count(&inst) == 0 &&
state->terminators.count(&inst) == 0;
};
auto loadAndStores = matcher::Op(notVectorizedThisPattern);
SmallVector<NestedMatch, 8> loadAndStoresMatches;
loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches);
for (auto ls : loadAndStoresMatches) {
auto *opInst = cast<OperationInst>(ls.getMatchedInstruction());
auto *opInst = ls.getMatchedInstruction();
auto load = opInst->dyn_cast<LoadOp>();
auto store = opInst->dyn_cast<StoreOp>();
LLVM_DEBUG(opInst->print(dbgs()));
@ -900,7 +899,7 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step,
static FilterFunctionType
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
return [fastestVaryingMemRefDimension](const Instruction &forInst) {
auto loop = cast<OperationInst>(forInst).cast<AffineForOp>();
auto loop = forInst.cast<AffineForOp>();
return isVectorizableLoopAlongFastestVaryingMemRefDim(
loop, fastestVaryingMemRefDimension);
};
@ -915,7 +914,7 @@ static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
/// recursively in DFS post-order.
static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
auto *loopInst = oneMatch.getMatchedInstruction();
auto loop = cast<OperationInst>(loopInst)->cast<AffineForOp>();
auto loop = loopInst->cast<AffineForOp>();
auto childrenMatches = oneMatch.getMatchedChildren();

// 1. DFS postorder recursion, if any of my children fails, I fail too.
@ -977,15 +976,14 @@ static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant,
Location loc = inst->getLoc();
auto vectorType = type.cast<VectorType>();
auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
auto *constantOpInst = cast<OperationInst>(constant.getInstruction());
auto *constantOpInst = constant.getInstruction();

OperationState state(
b.getContext(), loc, constantOpInst->getName().getStringRef(), {},
{vectorType},
{make_pair(Identifier::get("value", b.getContext()), attr)});

auto *splat = cast<OperationInst>(b.createOperation(state));
return splat->getResult(0);
return b.createOperation(state)->getResult(0);
}
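vectorizeConstant shows the general recipe for materializing a new operation after the rename: build an OperationState with the original operation name, the vector result type, and a SplatElementsAttr holding the broadcast value, then let the builder create it. A condensed sketch using only the calls visible in this hunk (loc, b, vectorType, and constant are the surrounding locals):

// Sketch: create a vector splat carrying the scalar constant's value.
auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
OperationState state(b.getContext(), loc,
                     constant.getInstruction()->getName().getStringRef(),
                     /*operands=*/{}, /*resultTypes=*/{vectorType},
                     {make_pair(Identifier::get("value", b.getContext()), attr)});
return b.createOperation(state)->getResult(0);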
/// Returns a uniqu'ed VectorType.
@ -997,8 +995,7 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
if (!VectorType::isValidElementType(v->getType())) {
return Type();
}
auto *definingOpInst = cast<OperationInst>(v->getDefiningInst());
if (state.vectorizedSet.count(definingOpInst) > 0) {
if (state.vectorizedSet.count(v->getDefiningInst()) > 0) {
return v->getType().cast<VectorType>();
}
return VectorType::get(state.strategy->vectorSizes, v->getType());
@ -1029,9 +1026,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
VectorizationState *state) {
LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
LLVM_DEBUG(operand->print(dbgs()));
auto *definingInstruction = cast<OperationInst>(operand->getDefiningInst());
// 1. If this value has already been vectorized this round, we are done.
if (state->vectorizedSet.count(definingInstruction) > 0) {
if (state->vectorizedSet.count(operand->getDefiningInst()) > 0) {
LLVM_DEBUG(dbgs() << " -> already vector operand");
return operand;
}
@ -1062,7 +1058,7 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
return nullptr;
};

/// Encodes OperationInst-specific behavior for vectorization. In general we
/// Encodes Instruction-specific behavior for vectorization. In general we
/// assume that all operands of an op must be vectorized but this is not always
/// true. In the future, it would be nice to have a trait that describes how a
/// particular operation vectorizes. For now we implement the case distinction
@ -1071,9 +1067,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
/// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized.
/// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
/// do one-off logic here; ideally it would be TableGen'd.
static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
OperationInst *opInst,
VectorizationState *state) {
static Instruction *vectorizeOneInstruction(FuncBuilder *b, Instruction *opInst,
VectorizationState *state) {
// Sanity checks.
assert(!opInst->isa<LoadOp>() &&
"all loads must have already been fully vectorized independently");
@ -1094,7 +1089,7 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<VectorTransferWriteOp>(
opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
auto *res = cast<OperationInst>(transfer->getInstruction());
auto *res = transfer->getInstruction();
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
// "Terminators" (i.e. StoreOps) are erased on the spot.
opInst->erase();
@ -1119,8 +1114,8 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
// Create a clone of the op with the proper operands and return types.
// TODO(ntv): The following assumes there is always an op with a fixed
// name that works both in scalar mode and vector mode.
// TODO(ntv): Is it worth considering an OperationInst.clone operation
// which changes the type so we can promote an OperationInst with less
// TODO(ntv): Is it worth considering an Instruction.clone operation
// which changes the type so we can promote an Instruction with less
// boilerplate?
OperationState newOp(b->getContext(), opInst->getLoc(),
opInst->getName().getStringRef(), operands, types,
@ -1129,22 +1124,22 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
return b->createOperation(newOp);
}

/// Iterates over the OperationInst in the loop and rewrites them using their
/// Iterates over the Instruction in the loop and rewrites them using their
/// vectorized counterpart by:
/// 1. iteratively building a worklist of uses of the OperationInst vectorized
/// 1. iteratively building a worklist of uses of the Instruction vectorized
/// so far by this pattern;
/// 2. for each OperationInst in the worklist, create the vector form of this
/// 2. for each Instruction in the worklist, create the vector form of this
/// operation and replace all its uses by the vectorized form. For this step,
/// the worklist must be traversed in order;
/// 3. verify that all operands of the newly vectorized operation have been
/// vectorized by this pattern.
static bool vectorizeOperations(VectorizationState *state) {
// 1. create initial worklist with the uses of the roots.
SetVector<OperationInst *> worklist;
auto insertUsesOf = [&worklist, state](OperationInst *vectorized) {
SetVector<Instruction *> worklist;
auto insertUsesOf = [&worklist, state](Instruction *vectorized) {
for (auto *r : vectorized->getResults())
for (auto &u : r->getUses()) {
auto *inst = cast<OperationInst>(u.getOwner());
auto *inst = u.getOwner();
// Don't propagate to terminals, a separate pass is needed for those.
// TODO(ntv)[b/119759136]: use isa<> once Op is implemented.
if (state->terminators.count(inst) > 0) {
@ -1166,7 +1161,7 @@ static bool vectorizeOperations(VectorizationState *state) {
// 2. Create vectorized form of the instruction.
// Insert it just before inst, on success register inst as replaced.
FuncBuilder b(inst);
auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state);
auto *vectorizedInst = vectorizeOneInstruction(&b, inst, state);
if (!vectorizedInst) {
return true;
}
@ -1179,7 +1174,7 @@ static bool vectorizeOperations(VectorizationState *state) {

// 4. Augment the worklist with uses of the instruction we just vectorized.
// This preserves the proper order in the worklist.
apply(insertUsesOf, ArrayRef<OperationInst *>{inst});
apply(insertUsesOf, ArrayRef<Instruction *>{inst});
}
return false;
}
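The worklist propagation is the core of the rewrite: results of already-vectorized instructions seed further work, and terminators (stores) are deliberately held back for the post-pass. A condensed sketch of the seeding step, using the helper names from the hunk above; the final worklist.insert is an assumption implied by the surrounding comments, since the hunk is truncated at that point:

// Sketch: enqueue every user of a vectorized instruction's results,
// except the terminators, which are handled separately afterwards.
SetVector<Instruction *> worklist;
auto insertUsesOf = [&worklist, state](Instruction *vectorized) {
  for (auto *r : vectorized->getResults())
    for (auto &u : r->getUses()) {
      auto *inst = u.getOwner();
      if (state->terminators.count(inst) > 0)
        continue;
      worklist.insert(inst);
    }
};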
@ -1189,8 +1184,7 @@ static bool vectorizeOperations(VectorizationState *state) {
/// Each root may succeed independently but will otherwise clean after itself if
/// anything below it fails.
static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
auto loop =
cast<OperationInst>(m.getMatchedInstruction())->cast<AffineForOp>();
auto loop = m.getMatchedInstruction()->cast<AffineForOp>();
VectorizationState state;
state.strategy = strategy;

@ -1207,8 +1201,7 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
}
auto *loopInst = loop->getInstruction();
FuncBuilder builder(loopInst);
auto clonedLoop =
cast<OperationInst>(builder.clone(*loopInst))->cast<AffineForOp>();
auto clonedLoop = builder.clone(*loopInst)->cast<AffineForOp>();

auto fail = doVectorize(m, &state);
/// Sets up error handling for this root loop. This is how the root match
@ -1248,12 +1241,12 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
}

// Finally, vectorize the terminators. If anything fails to vectorize, skip.
auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
auto vectorizeOrFail = [&fail, &state](Instruction *inst) {
if (fail) {
return;
}
FuncBuilder b(inst);
auto *res = vectorizeOneOperationInst(&b, inst, &state);
auto *res = vectorizeOneInstruction(&b, inst, &state);
if (res == nullptr) {
fail = true;
}