From b499277fb648c44907443ce44ec6bcc6b7596039 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 4 Feb 2019 10:38:47 -0800 Subject: [PATCH] Remove remaining usages of OperationInst in lib/Transforms. PiperOrigin-RevId: 232323671 --- mlir/lib/Transforms/CSE.cpp | 41 +++--- mlir/lib/Transforms/ComposeAffineMaps.cpp | 10 +- mlir/lib/Transforms/ConstantFold.cpp | 8 +- mlir/lib/Transforms/DialectConversion.cpp | 31 ++-- mlir/lib/Transforms/DmaGeneration.cpp | 5 +- mlir/lib/Transforms/LoopFusion.cpp | 134 ++++++++---------- mlir/lib/Transforms/LoopTiling.cpp | 5 +- mlir/lib/Transforms/LoopUnroll.cpp | 4 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 10 +- mlir/lib/Transforms/LowerAffine.cpp | 12 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 5 +- mlir/lib/Transforms/MaterializeVectors.cpp | 44 +++--- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 30 ++-- mlir/lib/Transforms/PipelineDataTransfer.cpp | 43 +++--- .../Transforms/SimplifyAffineStructures.cpp | 2 +- .../Utils/GreedyPatternRewriteDriver.cpp | 38 +++-- mlir/lib/Transforms/Utils/LoopUtils.cpp | 9 +- mlir/lib/Transforms/Utils/Utils.cpp | 17 +-- .../Vectorization/VectorizerTestPass.cpp | 27 ++-- mlir/lib/Transforms/Vectorize.cpp | 91 ++++++------ 20 files changed, 250 insertions(+), 316 deletions(-) diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index e471b6792c59..63a676d7b52a 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -39,10 +39,10 @@ using namespace mlir; namespace { // TODO(riverriddle) Handle commutative operations. -struct SimpleOperationInfo : public llvm::DenseMapInfo { - static unsigned getHashValue(const OperationInst *op) { +struct SimpleOperationInfo : public llvm::DenseMapInfo { + static unsigned getHashValue(const Instruction *op) { // Hash the operations based upon their: - // - OperationInst Name + // - Instruction Name // - Attributes // - Result Types // - Operands @@ -51,7 +51,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo { hash_combine_range(op->result_type_begin(), op->result_type_end()), hash_combine_range(op->operand_begin(), op->operand_end())); } - static bool isEqual(const OperationInst *lhs, const OperationInst *rhs) { + static bool isEqual(const Instruction *lhs, const Instruction *rhs) { if (lhs == rhs) return true; if (lhs == getTombstoneKey() || lhs == getEmptyKey() || @@ -89,8 +89,8 @@ struct CSE : public FunctionPass { /// Shared implementation of operation elimination and scoped map definitions. using AllocatorTy = llvm::RecyclingAllocator< llvm::BumpPtrAllocator, - llvm::ScopedHashTableVal>; - using ScopedMapTy = llvm::ScopedHashTable>; + using ScopedMapTy = llvm::ScopedHashTable; /// Represents a single entry in the depth first traversal of a CFG. @@ -111,7 +111,7 @@ struct CSE : public FunctionPass { /// Attempt to eliminate a redundant operation. Returns true if the operation /// was marked for removal, false otherwise. - bool simplifyOperation(OperationInst *op); + bool simplifyOperation(Instruction *op); void simplifyBlock(Block *bb); @@ -122,14 +122,14 @@ private: ScopedMapTy knownValues; /// Operations marked as dead and to be erased. - std::vector opsToErase; + std::vector opsToErase; }; } // end anonymous namespace char CSE::passID = 0; /// Attempt to eliminate a redundant operation. -bool CSE::simplifyOperation(OperationInst *op) { +bool CSE::simplifyOperation(Instruction *op) { // TODO(riverriddle) We currently only eliminate non side-effecting // operations. 
if (!op->hasNoSideEffect()) @@ -166,23 +166,16 @@ bool CSE::simplifyOperation(OperationInst *op) { void CSE::simplifyBlock(Block *bb) { for (auto &i : *bb) { - switch (i.getKind()) { - case Instruction::Kind::OperationInst: { - auto *opInst = cast(&i); + // If the operation is simplified, we don't process any held block lists. + if (simplifyOperation(&i)) + continue; - // If the operation is simplified, we don't process any held block lists. - if (simplifyOperation(opInst)) - continue; - - // Simplify any held blocks. - for (auto &blockList : opInst->getBlockLists()) { - for (auto &b : blockList) { - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(&b); - } + // Simplify any held blocks. + for (auto &blockList : i.getBlockLists()) { + for (auto &b : blockList) { + ScopedMapTy::ScopeTy scope(knownValues); + simplifyBlock(&b); } - break; - } } } } diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index 4f960ea73afb..4a6430dc9be9 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -48,7 +48,7 @@ namespace { struct ComposeAffineMaps : public FunctionPass, InstWalker { explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(OperationInst *opInst); + void visitInstruction(Instruction *opInst); SmallVector, 8> affineApplyOps; @@ -64,14 +64,12 @@ FunctionPass *mlir::createComposeAffineMapsPass() { } static bool affineApplyOp(const Instruction &inst) { - const auto &opInst = cast(inst); - return opInst.isa(); + return inst.isa(); } -void ComposeAffineMaps::visitInstruction(OperationInst *opInst) { - if (auto afOp = opInst->dyn_cast()) { +void ComposeAffineMaps::visitInstruction(Instruction *opInst) { + if (auto afOp = opInst->dyn_cast()) affineApplyOps.push_back(afOp); - } } PassResult ComposeAffineMaps::runOnFunction(Function *f) { diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 859d0012fac0..54486cdb293c 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -33,11 +33,11 @@ struct ConstantFold : public FunctionPass, InstWalker { // All constants in the function post folding. SmallVector existingConstants; // Operations that were folded and that need to be erased. - std::vector opInstsToErase; + std::vector opInstsToErase; - bool foldOperation(OperationInst *op, + bool foldOperation(Instruction *op, SmallVectorImpl &existingConstants); - void visitInstruction(OperationInst *op); + void visitInstruction(Instruction *op); PassResult runOnFunction(Function *f) override; static char passID; @@ -49,7 +49,7 @@ char ConstantFold::passID = 0; /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// -void ConstantFold::visitInstruction(OperationInst *op) { +void ConstantFold::visitInstruction(Instruction *op) { // If this operation is an AffineForOp, then fold the bounds. if (auto forOp = op->dyn_cast()) { constantFoldBounds(forOp); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 443e77509477..996416d92711 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -50,7 +50,7 @@ private: // Utility that looks up a list of value in the value remapping table. 
Returns // an empty vector if one of the values is not mapped yet. SmallVector - lookupValues(const llvm::iterator_range + lookupValues(const llvm::iterator_range &operands); // Converts the given function to the dialect using hooks defined in @@ -61,13 +61,13 @@ private: // from `valueRemapping` and the converted blocks from `blockRemapping`, and // passes them to `converter->rewriteTerminator` function defined in the // pattern, together with `builder`. - bool convertOpWithSuccessors(DialectOpConversion *converter, - OperationInst *op, FuncBuilder &builder); + bool convertOpWithSuccessors(DialectOpConversion *converter, Instruction *op, + FuncBuilder &builder); // Converts an operation without successors. Extracts the converted operands // from `valueRemapping` and passes them to the `converter->rewrite` function // defined in the pattern, together with `builder`. - bool convertOp(DialectOpConversion *converter, OperationInst *op, + bool convertOp(DialectOpConversion *converter, Instruction *op, FuncBuilder &builder); // Converts a block by traversing its instructions sequentially, looking for @@ -104,8 +104,7 @@ private: } // end namespace mlir SmallVector impl::FunctionConversion::lookupValues( - const llvm::iterator_range - &operands) { + const llvm::iterator_range &operands) { SmallVector remapped; remapped.reserve(llvm::size(operands)); for (const Value *operand : operands) { @@ -118,7 +117,7 @@ SmallVector impl::FunctionConversion::lookupValues( } bool impl::FunctionConversion::convertOpWithSuccessors( - DialectOpConversion *converter, OperationInst *op, FuncBuilder &builder) { + DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) { SmallVector destinations; destinations.reserve(op->getNumSuccessors()); SmallVector operands = lookupValues(op->getOperands()); @@ -149,7 +148,7 @@ bool impl::FunctionConversion::convertOpWithSuccessors( } bool impl::FunctionConversion::convertOp(DialectOpConversion *converter, - OperationInst *op, + Instruction *op, FuncBuilder &builder) { auto operands = lookupValues(op->getOperands()); assert((!operands.empty() || op->getNumOperands() == 0) && @@ -174,24 +173,22 @@ bool impl::FunctionConversion::convertBlock( // Iterate over ops and convert them. for (Instruction &inst : *block) { - auto op = dyn_cast(&inst); - if (!op) { - inst.emitError("unsupported instruction (For/If)"); + if (inst.getNumBlockLists() != 0) { + inst.emitError("unsupported region instruction"); return true; } // Find the first matching conversion and apply it. bool converted = false; for (auto *conversion : conversions) { - if (!conversion->match(op)) + if (!conversion->match(&inst)) continue; - if (op->isTerminator() && op->getNumSuccessors() > 0) { - if (convertOpWithSuccessors(conversion, op, builder)) - return true; - } else { - if (convertOp(conversion, op, builder)) + if (inst.isTerminator() && inst.getNumSuccessors() > 0) { + if (convertOpWithSuccessors(conversion, &inst, builder)) return true; + } else if (convertOp(conversion, &inst, builder)) { + return true; } converted = true; break; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 2bbb32036c29..92ae37670986 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -157,8 +157,7 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, /// dynamic shaped memref's for now. `numParamLoopIVs` is the number of /// enclosing loop IVs of opInst (starting from the outermost) that the region /// is parametric on. 
-static bool getFullMemRefAsRegion(OperationInst *opInst, - unsigned numParamLoopIVs, +static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs, MemRefRegion *region) { unsigned rank; if (auto loadOp = opInst->dyn_cast()) { @@ -563,7 +562,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { fastBufferMap.clear(); // Walk this range of instructions to gather all memory regions. - block->walk(begin, end, [&](OperationInst *opInst) { + block->walk(begin, end, [&](Instruction *opInst) { // Gather regions to allocate to buffers in faster memory space. if (auto loadOp = opInst->dyn_cast()) { if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace) diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 304331320ac9..d7d69e569e5f 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -114,11 +114,11 @@ namespace { class LoopNestStateCollector : public InstWalker { public: SmallVector, 4> forOps; - SmallVector loadOpInsts; - SmallVector storeOpInsts; + SmallVector loadOpInsts; + SmallVector storeOpInsts; bool hasNonForRegion = false; - void visitInstruction(OperationInst *opInst) { + void visitInstruction(Instruction *opInst) { if (opInst->isa()) forOps.push_back(opInst->cast()); else if (opInst->getNumBlockLists() != 0) @@ -131,7 +131,7 @@ public: }; // TODO(b/117228571) Replace when this is modeled through side-effects/op traits -static bool isMemRefDereferencingOp(const OperationInst &op) { +static bool isMemRefDereferencingOp(const Instruction &op) { if (op.isa() || op.isa() || op.isa() || op.isa()) return true; @@ -153,9 +153,9 @@ public: // The top-level statment which is (or contains) loads/stores. Instruction *inst; // List of load operations. - SmallVector loads; + SmallVector loads; // List of store op insts. - SmallVector stores; + SmallVector stores; Node(unsigned id, Instruction *inst) : id(id), inst(inst) {} // Returns the load op count for 'memref'. @@ -258,16 +258,13 @@ public: for (auto *storeOpInst : node->stores) { auto *memref = storeOpInst->cast()->getMemRef(); auto *inst = memref->getDefiningInst(); - auto *opInst = dyn_cast_or_null(inst); - // Return false if 'memref' is a function argument. - if (opInst == nullptr) + // Return false if 'memref' is a block argument. + if (!inst) return true; // Return false if any use of 'memref' escapes the function. - for (auto &use : memref->getUses()) { - auto *user = dyn_cast(use.getOwner()); - if (!user || !isMemRefDereferencingOp(*user)) + for (auto &use : memref->getUses()) + if (!isMemRefDereferencingOp(*use.getOwner())) return true; - } } return false; } @@ -461,8 +458,8 @@ public: } // Adds ops in 'loads' and 'stores' to node at 'id'. - void addToNode(unsigned id, const SmallVectorImpl &loads, - const SmallVectorImpl &stores) { + void addToNode(unsigned id, const SmallVectorImpl &loads, + const SmallVectorImpl &stores) { Node *node = getNode(id); for (auto *loadOpInst : loads) node->loads.push_back(loadOpInst); @@ -509,7 +506,7 @@ bool MemRefDependenceGraph::init(Function *f) { DenseMap forToNodeMap; for (auto &inst : f->front()) { - if (auto forOp = cast(&inst)->dyn_cast()) { + if (auto forOp = inst.dyn_cast()) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. 
LoopNestStateCollector collector; @@ -530,30 +527,28 @@ bool MemRefDependenceGraph::init(Function *f) { } forToNodeMap[&inst] = node.id; nodes.insert({node.id, node}); - } else if (auto *opInst = dyn_cast(&inst)) { - if (auto loadOp = opInst->dyn_cast()) { - // Create graph node for top-level load op. - Node node(nextNodeId++, &inst); - node.loads.push_back(opInst); - auto *memref = opInst->cast()->getMemRef(); - memrefAccesses[memref].insert(node.id); - nodes.insert({node.id, node}); - } else if (auto storeOp = opInst->dyn_cast()) { - // Create graph node for top-level store op. - Node node(nextNodeId++, &inst); - node.stores.push_back(opInst); - auto *memref = opInst->cast()->getMemRef(); - memrefAccesses[memref].insert(node.id); - nodes.insert({node.id, node}); - } else if (opInst->getNumBlockLists() != 0) { - // Return false if another region is found (not currently supported). - return false; - } else if (opInst->getNumResults() > 0 && !opInst->use_empty()) { - // Create graph node for top-level producer of SSA values, which - // could be used by loop nest nodes. - Node node(nextNodeId++, &inst); - nodes.insert({node.id, node}); - } + } else if (auto loadOp = inst.dyn_cast()) { + // Create graph node for top-level load op. + Node node(nextNodeId++, &inst); + node.loads.push_back(&inst); + auto *memref = inst.cast()->getMemRef(); + memrefAccesses[memref].insert(node.id); + nodes.insert({node.id, node}); + } else if (auto storeOp = inst.dyn_cast()) { + // Create graph node for top-level store op. + Node node(nextNodeId++, &inst); + node.stores.push_back(&inst); + auto *memref = inst.cast()->getMemRef(); + memrefAccesses[memref].insert(node.id); + nodes.insert({node.id, node}); + } else if (inst.getNumBlockLists() != 0) { + // Return false if another region is found (not currently supported). + return false; + } else if (inst.getNumResults() > 0 && !inst.use_empty()) { + // Create graph node for top-level producer of SSA values, which + // could be used by loop nest nodes. + Node node(nextNodeId++, &inst); + nodes.insert({node.id, node}); } } @@ -563,12 +558,11 @@ bool MemRefDependenceGraph::init(Function *f) { const Node &node = idAndNode.second; if (!node.loads.empty() || !node.stores.empty()) continue; - auto *opInst = cast(node.inst); + auto *opInst = node.inst; for (auto *value : opInst->getResults()) { for (auto &use : value->getUses()) { - auto *userOpInst = cast(use.getOwner()); SmallVector, 4> loops; - getLoopIVs(*userOpInst, &loops); + getLoopIVs(*use.getOwner(), &loops); if (loops.empty()) continue; assert(forToNodeMap.count(loops[0]->getInstruction()) > 0); @@ -619,7 +613,7 @@ public: LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} - void visitInstruction(OperationInst *opInst) { + void visitInstruction(Instruction *opInst) { auto forOp = opInst->dyn_cast(); if (!forOp) return; @@ -627,8 +621,7 @@ public: auto *forInst = forOp->getInstruction(); auto *parentInst = forOp->getInstruction()->getParentInst(); if (parentInst != nullptr) { - assert(cast(parentInst)->isa() && - "Expected parent AffineForOp"); + assert(parentInst->isa() && "Expected parent AffineForOp"); // Add mapping to 'forOp' from its parent AffineForOp. 
stats->loopMap[parentInst].push_back(forOp); } @@ -637,8 +630,7 @@ public: unsigned count = 0; stats->opCountMap[forInst] = 0; for (auto &inst : *forOp->getBody()) { - if (!(cast(inst).isa() || - cast(inst).isa())) + if (!(inst.isa() || inst.isa())) ++count; } stats->opCountMap[forInst] = count; @@ -723,7 +715,7 @@ static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) { // was encountered). // TODO(andydavis) Make this work with non-unit step loops. static bool buildSliceTripCountMap( - OperationInst *srcOpInst, ComputationSliceState *sliceState, + Instruction *srcOpInst, ComputationSliceState *sliceState, llvm::SmallDenseMap *tripCountMap) { SmallVector, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); @@ -755,10 +747,10 @@ static bool buildSliceTripCountMap( // adds them to 'dstLoads'. static void moveLoadsAccessingMemrefTo(Value *memref, - SmallVectorImpl *srcLoads, - SmallVectorImpl *dstLoads) { + SmallVectorImpl *srcLoads, + SmallVectorImpl *dstLoads) { dstLoads->clear(); - SmallVector srcLoadsToKeep; + SmallVector srcLoadsToKeep; for (auto *load : *srcLoads) { if (load->cast()->getMemRef() == memref) dstLoads->push_back(load); @@ -769,7 +761,7 @@ moveLoadsAccessingMemrefTo(Value *memref, } // Returns the innermost common loop depth for the set of operations in 'ops'. -static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { +static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { unsigned numOps = ops.size(); assert(numOps > 0); @@ -797,10 +789,10 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { // Returns the maximum loop depth at which no dependences between 'loadOpInsts' // and 'storeOpInsts' are satisfied. -static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, - ArrayRef storeOpInsts) { +static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, + ArrayRef storeOpInsts) { // Merge loads and stores into the same array. - SmallVector ops(loadOpInsts.begin(), loadOpInsts.end()); + SmallVector ops(loadOpInsts.begin(), loadOpInsts.end()); ops.append(storeOpInsts.begin(), storeOpInsts.end()); // Compute the innermost common loop depth for loads and stores. @@ -913,7 +905,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { // TODO(bondhugula): consider refactoring the common code from generateDma and // this one. static Value *createPrivateMemRef(OpPointer forOp, - OperationInst *srcStoreOpInst, + Instruction *srcStoreOpInst, unsigned dstLoopDepth, Optional fastMemorySpace, unsigned localBufSizeThreshold) { @@ -1061,9 +1053,9 @@ static uint64_t getSliceIterationCount( // *) Compares the total cost of the unfused loop nests to the min cost fused // loop nest computed in the previous step, and returns true if the latter // is lower. -static bool isFusionProfitable(OperationInst *srcOpInst, - ArrayRef dstLoadOpInsts, - ArrayRef dstStoreOpInsts, +static bool isFusionProfitable(Instruction *srcOpInst, + ArrayRef dstLoadOpInsts, + ArrayRef dstStoreOpInsts, ComputationSliceState *sliceState, unsigned *dstLoopDepth) { LLVM_DEBUG({ @@ -1174,7 +1166,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst, computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1; for (auto *loadOp : dstLoadOpInsts) { auto *parentInst = loadOp->getParentInst(); - if (parentInst && cast(parentInst)->isa()) + if (parentInst && parentInst->isa()) computeCostMap[parentInst] = -1; } } @@ -1393,11 +1385,11 @@ public: // Get 'dstNode' into which to attempt fusion. auto *dstNode = mdg->getNode(dstId); // Skip if 'dstNode' is not a loop nest. 
- if (!cast(dstNode->inst)->isa()) + if (!dstNode->inst->isa()) continue; - SmallVector loads = dstNode->loads; - SmallVector dstLoadOpInsts; + SmallVector loads = dstNode->loads; + SmallVector dstLoadOpInsts; DenseSet visitedMemrefs; while (!loads.empty()) { // Get memref of load on top of the stack. @@ -1426,7 +1418,7 @@ public: // Get 'srcNode' from which to attempt fusion into 'dstNode'. auto *srcNode = mdg->getNode(srcId); // Skip if 'srcNode' is not a loop nest. - if (!cast(srcNode->inst)->isa()) + if (!srcNode->inst->isa()) continue; // Skip if 'srcNode' has more than one store to any memref. // TODO(andydavis) Support fusing multi-output src loop nests. @@ -1454,7 +1446,7 @@ public: // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); // Gather 'dstNode' store ops to 'memref'. - SmallVector dstStoreOpInsts; + SmallVector dstStoreOpInsts; for (auto *storeOpInst : dstNode->stores) if (storeOpInst->cast()->getMemRef() == memref) dstStoreOpInsts.push_back(storeOpInst); @@ -1472,8 +1464,7 @@ public: srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { // Move 'dstAffineForOp' before 'insertPointInst' if needed. - auto dstAffineForOp = - cast(dstNode->inst)->cast(); + auto dstAffineForOp = dstNode->inst->cast(); if (insertPointInst != dstAffineForOp->getInstruction()) { dstAffineForOp->getInstruction()->moveBefore(insertPointInst); } @@ -1488,7 +1479,7 @@ public: promoteIfSingleIteration(forOp); } // Create private memref for 'memref' in 'dstAffineForOp'. - SmallVector storesForMemref; + SmallVector storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { if (storeOpInst->cast()->getMemRef() == memref) storesForMemref.push_back(storeOpInst); @@ -1541,9 +1532,8 @@ public: continue; // Use list expected to match the dep graph info. 
auto *inst = memref->getDefiningInst(); - auto *opInst = dyn_cast_or_null(inst); - if (opInst && opInst->isa()) - opInst->erase(); + if (inst && inst->isa()) + inst->erase(); } } }; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index f1ee7fd18533..8b368e5f182a 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -237,14 +237,13 @@ getTileableBands(Function *f, do { band.push_back(currInst); } while (currInst->getBody()->getInstructions().size() == 1 && - (currInst = cast(currInst->getBody()->front()) - .dyn_cast())); + (currInst = currInst->getBody()->front().dyn_cast())); bands->push_back(band); }; for (auto &block : *f) for (auto &inst : block) - if (auto forOp = cast(inst).dyn_cast()) + if (auto forOp = inst.dyn_cast()) getMaximalPerfectLoopNest(forOp); } diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 9c9952d31ca3..b1e15ccb07b6 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -113,7 +113,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return hasInnerLoops; } - bool walkPostOrder(OperationInst *opInst) { + bool walkPostOrder(Instruction *opInst) { bool hasInnerLoops = false; for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) @@ -140,7 +140,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { const unsigned minTripCount; ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitInstruction(OperationInst *opInst) { + void visitInstruction(Instruction *opInst) { auto forOp = opInst->dyn_cast(); if (!forOp) return; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index d87f9d5dc149..74c54fde0475 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -100,8 +100,7 @@ PassResult LoopUnrollAndJam::runOnFunction(Function *f) { // any for Inst. auto &entryBlock = f->front(); if (!entryBlock.empty()) - if (auto forOp = - cast(entryBlock.front()).dyn_cast()) + if (auto forOp = entryBlock.front().dyn_cast()) runOnAffineForOp(forOp); return success(); @@ -149,12 +148,12 @@ bool mlir::loopUnrollJamByFactor(OpPointer forOp, void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; - while (it != End && !cast(it)->isa()) + while (it != End && !it->isa()) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); // Process all for insts that appear next. - while (it != End && cast(it)->isa()) + while (it != End && it->isa()) walk(&*it++); } } @@ -206,8 +205,7 @@ bool mlir::loopUnrollJamByFactor(OpPointer forOp, // Insert the cleanup loop right after 'forOp'. FuncBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); - auto cleanupAffineForOp = - cast(builder.clone(*forInst))->cast(); + auto cleanupAffineForOp = builder.clone(*forInst)->cast(); cleanupAffineForOp->setLowerBoundMap( getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder)); diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 08c8188fada9..88ccc90c18b7 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -616,23 +616,21 @@ PassResult LowerAffinePass::runOnFunction(Function *function) { // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. 
// We do this as a prepass to avoid invalidating the walker with our rewrite. function->walk([&](Instruction *inst) { - auto op = cast(inst); - if (op->isa() || op->isa() || - op->isa()) + if (inst->isa() || inst->isa() || + inst->isa()) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. for (auto *inst : instsToRewrite) { - auto op = cast(inst); - if (auto ifOp = op->dyn_cast()) { + if (auto ifOp = inst->dyn_cast()) { if (lowerAffineIf(ifOp)) return failure(); - } else if (auto forOp = op->dyn_cast()) { + } else if (auto forOp = inst->dyn_cast()) { if (lowerAffineFor(forOp)) return failure(); - } else if (lowerAffineApply(op->cast())) { + } else if (lowerAffineApply(inst->cast())) { return failure(); } } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 7f1e9b157d80..63fb45db9c59 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -401,13 +401,12 @@ public: explicit VectorTransferExpander(MLIRContext *context) : MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {} - PatternMatchResult match(OperationInst *op) const override { + PatternMatchResult match(Instruction *op) const override { if (m_Op().match(op)) return matchSuccess(); return matchFailure(); } - void rewriteOpInst(OperationInst *op, - MLFuncGlobalLoweringState *funcWiseState, + void rewriteOpInst(Instruction *op, MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr opState, MLFuncLoweringRewriter *rewriter) const override { VectorTransferRewriter( diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index f2dae11112b5..f55c2154f08f 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -246,8 +246,8 @@ static SmallVector delinearize(unsigned linearIndex, return res; } -static OperationInst * -instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType, +static Instruction * +instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, DenseMap *substitutionsMap); /// Not all Values belong to a program slice scoped within the immediately @@ -391,7 +391,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, /// - constant splat is replaced by constant splat of `hwVectorType`. /// TODO(ntv): add more substitutions on a per-need basis. static SmallVector -materializeAttributes(OperationInst *opInst, VectorType hwVectorType) { +materializeAttributes(Instruction *opInst, VectorType hwVectorType) { SmallVector res; for (auto a : opInst->getAttrs()) { if (auto splat = a.second.dyn_cast()) { @@ -411,8 +411,8 @@ materializeAttributes(OperationInst *opInst, VectorType hwVectorType) { /// substitutionsMap. /// /// If the underlying substitution fails, this fails too and returns nullptr. -static OperationInst * -instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType, +static Instruction * +instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, DenseMap *substitutionsMap) { assert(!opInst->isa() && "Should call the function specialized for VectorTransferReadOp"); @@ -488,7 +488,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, /// `hwVectorType` int the covering of the super-vector type. 
For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static OperationInst * +static Instruction * instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { @@ -512,7 +512,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, /// `hwVectorType` int the covering of th3e super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static OperationInst * +static Instruction * instantiate(FuncBuilder *b, VectorTransferWriteOp *write, VectorType hwVectorType, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { @@ -555,21 +555,20 @@ static bool instantiateMaterialization(Instruction *inst, // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); - auto *opInst = cast(inst); // AffineApplyOp are ignored: instantiating the proper vector op will take // care of AffineApplyOps by composing them properly. - if (opInst->isa()) { + if (inst->isa()) { return false; } - if (opInst->getNumBlockLists() != 0) + if (inst->getNumBlockLists() != 0) return inst->emitError("NYI path Op with region"); - if (auto write = opInst->dyn_cast()) { + if (auto write = inst->dyn_cast()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); return clone == nullptr; } - if (auto read = opInst->dyn_cast()) { + if (auto read = inst->dyn_cast()) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); if (!clone) { @@ -582,19 +581,19 @@ static bool instantiateMaterialization(Instruction *inst, // The only op with 0 results reaching this point must, by construction, be // VectorTransferWriteOps and have been caught above. Ops with >= 2 results // are not yet supported. So just support 1 result. - if (opInst->getNumResults() != 1) { + if (inst->getNumResults() != 1) { return inst->emitError("NYI: ops with != 1 results"); } - if (opInst->getResult(0)->getType() != state->superVectorType) { + if (inst->getResult(0)->getType() != state->superVectorType) { return inst->emitError("Op does not return a supervector."); } auto *clone = - instantiate(&b, opInst, state->hwVectorType, state->substitutionsMap); + instantiate(&b, inst, state->hwVectorType, state->substitutionsMap); if (!clone) { return true; } state->substitutionsMap->insert( - std::make_pair(opInst->getResult(0), clone->getResult(0))); + std::make_pair(inst->getResult(0), clone->getResult(0))); return false; } @@ -645,7 +644,7 @@ static bool emitSlice(MaterializationState *state, } LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); - LLVM_DEBUG(cast((*slice)[0])->getFunction()->print(dbgs())); + LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs())); // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* @@ -677,7 +676,7 @@ static bool emitSlice(MaterializationState *state, /// scope. /// TODO(ntv): please document return value. static bool materialize(Function *f, - const SetVector &terminators, + const SetVector &terminators, MaterializationState *state) { DenseSet seen; DominanceInfo domInfo(f); @@ -757,18 +756,17 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) { // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. 
auto filter = [subVectorType](const Instruction &inst) { - const auto &opInst = cast(inst); - if (!opInst.isa()) { + if (!inst.isa()) { return false; } - return matcher::operatesOnSuperVectors(opInst, subVectorType); + return matcher::operatesOnSuperVectors(inst, subVectorType); }; auto pat = Op(filter); SmallVector matches; pat.match(f, &matches); - SetVector terminators; + SetVector terminators; for (auto m : matches) { - terminators.insert(cast(m.getMatchedInstruction())); + terminators.insert(m.getMatchedInstruction()); } auto fail = materialize(f, terminators, &state); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index b9386c384dd8..b2b69dc7b6d7 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -75,12 +75,12 @@ struct MemRefDataFlowOpt : public FunctionPass, InstWalker { PassResult runOnFunction(Function *f) override; - void visitInstruction(OperationInst *opInst); + void visitInstruction(Instruction *opInst); // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet memrefsToErase; // Load op's whose results were replaced by those forwarded from stores. - std::vector loadOpsToErase; + std::vector loadOpsToErase; DominanceInfo *domInfo = nullptr; PostDominanceInfo *postDomInfo = nullptr; @@ -100,22 +100,22 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() { // This is a straightforward implementation not optimized for speed. Optimize // this in the future if needed. -void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) { - OperationInst *lastWriteStoreOp = nullptr; +void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) { + Instruction *lastWriteStoreOp = nullptr; auto loadOp = opInst->dyn_cast(); if (!loadOp) return; - OperationInst *loadOpInst = opInst; + Instruction *loadOpInst = opInst; // First pass over the use list to get minimum number of surrounding // loops common between the load op and the store op, with min taken across // all store ops. - SmallVector storeOps; + SmallVector storeOps; unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); for (InstOperand &use : loadOp->getMemRef()->getUses()) { - auto storeOp = cast(use.getOwner())->dyn_cast(); + auto storeOp = use.getOwner()->dyn_cast(); if (!storeOp) continue; auto *storeOpInst = storeOp->getInstruction(); @@ -131,11 +131,11 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) { // and loadOp. // The list of store op candidates for forwarding - need to satisfy the // conditions listed at the top. - SmallVector fwdingCandidates; + SmallVector fwdingCandidates; // Store ops that have a dependence into the load (even if they aren't // forwarding candidates). Each forwarding candidate will be checked for a // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores. - SmallVector depSrcStores; + SmallVector depSrcStores; for (auto *storeOpInst : storeOps) { MemRefAccess srcAccess(storeOpInst); MemRefAccess destAccess(loadOpInst); @@ -197,7 +197,7 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) { // that postdominates all 'depSrcStores' (if such a store exists) is the // unique store providing the value to the load, i.e., provably the last // writer to that memref loc. 
- if (llvm::all_of(depSrcStores, [&](OperationInst *depStore) { + if (llvm::all_of(depSrcStores, [&](Instruction *depStore) { return postDomInfo->postDominates(storeOpInst, depStore); })) { lastWriteStoreOp = storeOpInst; @@ -246,24 +246,22 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) { // to do this as well, but we'll do it here since we collected these anyway. for (auto *memref : memrefsToErase) { // If the memref hasn't been alloc'ed in this function, skip. - OperationInst *defInst = memref->getDefiningInst(); + Instruction *defInst = memref->getDefiningInst(); if (!defInst || !defInst->isa()) // TODO(mlir-team): if the memref was returned by a 'call' instruction, we // could still erase it if the call had no side-effects. continue; if (std::any_of(memref->use_begin(), memref->use_end(), [&](InstOperand &use) { - auto *ownerInst = cast(use.getOwner()); + auto *ownerInst = use.getOwner(); return (!ownerInst->isa() && !ownerInst->isa()); })) continue; // Erase all stores, the dealloc, and the alloc on the memref. - for (auto it = memref->use_begin(), e = memref->use_end(); it != e;) { - auto &use = *(it++); - cast(use.getOwner())->erase(); - } + for (auto &use : llvm::make_early_inc_range(memref->getUses())) + use.getOwner()->erase(); defInst->erase(); } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 8d13800160d8..ba3be5e95f4c 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -61,7 +61,7 @@ FunctionPass *mlir::createPipelineDataTransferPass() { // Returns the position of the tag memref operand given a DMA instruction. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getTagMemRefPos(const OperationInst &dmaInst) { +static unsigned getTagMemRefPos(const Instruction &dmaInst) { assert(dmaInst.isa() || dmaInst.isa()); if (dmaInst.isa()) { // Second to last operand. @@ -142,7 +142,7 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) { // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); - f->walkPostOrder([&](OperationInst *opInst) { + f->walkPostOrder([&](Instruction *opInst) { if (auto forOp = opInst->dyn_cast()) forOps.push_back(forOp); }); @@ -180,33 +180,26 @@ static bool checkTagMatch(OpPointer startOp, // Identify matching DMA start/finish instructions to overlap computation with. static void findMatchingStartFinishInsts( OpPointer forOp, - SmallVectorImpl> - &startWaitPairs) { + SmallVectorImpl> &startWaitPairs) { // Collect outgoing DMA instructions - needed to check for dependences below. SmallVector, 4> outgoingDmaOps; for (auto &inst : *forOp->getBody()) { - auto *opInst = dyn_cast(&inst); - if (!opInst) - continue; OpPointer dmaStartOp; - if ((dmaStartOp = opInst->dyn_cast()) && + if ((dmaStartOp = inst.dyn_cast()) && dmaStartOp->isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } - SmallVector dmaStartInsts, dmaFinishInsts; + SmallVector dmaStartInsts, dmaFinishInsts; for (auto &inst : *forOp->getBody()) { - auto *opInst = dyn_cast(&inst); - if (!opInst) - continue; // Collect DMA finish instructions. 
- if (opInst->isa()) { - dmaFinishInsts.push_back(opInst); + if (inst.isa()) { + dmaFinishInsts.push_back(&inst); continue; } OpPointer dmaStartOp; - if (!(dmaStartOp = opInst->dyn_cast())) + if (!(dmaStartOp = inst.dyn_cast())) continue; // Only DMAs incoming into higher memory spaces are pipelined for now. // TODO(bondhugula): handle outgoing DMA pipelining. @@ -236,7 +229,7 @@ static void findMatchingStartFinishInsts( } } if (!escapingUses) - dmaStartInsts.push_back(opInst); + dmaStartInsts.push_back(&inst); } // For each start instruction, we look for a matching finish instruction. @@ -262,7 +255,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { return success(); } - SmallVector, 4> startWaitPairs; + SmallVector, 4> startWaitPairs; findMatchingStartFinishInsts(forOp, startWaitPairs); if (startWaitPairs.empty()) { @@ -335,7 +328,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { } else { // If a slice wasn't created, the reachable affine_apply op's from its // operands are the ones that go with it. - SmallVector affineApplyInsts; + SmallVector affineApplyInsts; SmallVector operands(dmaStartInst->getOperands()); getReachableAffineApplyOps(operands, affineApplyInsts); for (const auto *inst : affineApplyInsts) { @@ -356,13 +349,13 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { for (auto &inst : *forOp->getBody()) { assert(instShiftMap.find(&inst) != instShiftMap.end()); shifts[s++] = instShiftMap[&inst]; - LLVM_DEBUG( - // Tagging instructions with shifts for debugging purposes. - if (auto *opInst = dyn_cast(&inst)) { - FuncBuilder b(opInst); - opInst->setAttr(b.getIdentifier("shift"), - b.getI64IntegerAttr(shifts[s - 1])); - }); + + // Tagging instructions with shifts for debugging purposes. + LLVM_DEBUG({ + FuncBuilder b(&inst); + inst.setAttr(b.getIdentifier("shift"), + b.getI64IntegerAttr(shifts[s - 1])); + }); } if (!isInstwiseShiftValid(forOp, shifts)) { diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index a9fcfc5bd116..29509911e31a 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -64,7 +64,7 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { } PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walk([&](OperationInst *opInst) { + f->walk([&](Instruction *opInst) { for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) { MutableAffineMap mMap(mapAttr.getValue()); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 790f971bb58a..45c57e2f3070 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -38,13 +38,13 @@ public: worklist.reserve(64); // Add all operations to the worklist. - fn->walk([&](OperationInst *inst) { addToWorklist(inst); }); + fn->walk([&](Instruction *inst) { addToWorklist(inst); }); } /// Perform the rewrites. void simplifyFunction(); - void addToWorklist(OperationInst *op) { + void addToWorklist(Instruction *op) { // Check to see if the worklist already contains this op. if (worklistMap.count(op)) return; @@ -53,7 +53,7 @@ public: worklist.push_back(op); } - OperationInst *popFromWorklist() { + Instruction *popFromWorklist() { auto *op = worklist.back(); worklist.pop_back(); @@ -65,7 +65,7 @@ public: /// If the specified operation is in the worklist, remove it. 
If not, this is /// a no-op. - void removeFromWorklist(OperationInst *op) { + void removeFromWorklist(Instruction *op) { auto it = worklistMap.find(op); if (it != worklistMap.end()) { assert(worklist[it->second] == op && "malformed worklist data structure"); @@ -77,7 +77,7 @@ public: protected: // Implement the hook for creating operations, and make sure that newly // created ops are added to the worklist for processing. - OperationInst *createOperation(const OperationState &state) override { + Instruction *createOperation(const OperationState &state) override { auto *result = builder.createOperation(state); addToWorklist(result); return result; @@ -85,20 +85,18 @@ protected: // If an operation is about to be removed, make sure it is not in our // worklist anymore because we'd get dangling references to it. - void notifyOperationRemoved(OperationInst *op) override { + void notifyOperationRemoved(Instruction *op) override { removeFromWorklist(op); } // When the root of a pattern is about to be replaced, it can trigger // simplifications to its users - make sure to add them to the worklist // before the root is changed. - void notifyRootReplaced(OperationInst *op) override { + void notifyRootReplaced(Instruction *op) override { for (auto *result : op->getResults()) // TODO: Add a result->getUsers() iterator. - for (auto &user : result->getUses()) { - if (auto *op = dyn_cast(user.getOwner())) - addToWorklist(op); - } + for (auto &user : result->getUses()) + addToWorklist(user.getOwner()); // TODO: Walk the operand list dropping them as we go. If any of them // drop to zero uses, then add them to the worklist to allow them to be @@ -116,13 +114,13 @@ private: /// need to be revisited, plus their index in the worklist. This allows us to /// efficiently remove operations from the worklist when they are erased from /// the function, even if they aren't the root of a pattern. - std::vector worklist; - DenseMap worklistMap; + std::vector worklist; + DenseMap worklistMap; /// As part of canonicalization, we move constants to the top of the entry /// block of the current function and de-duplicate them. This keeps track of /// constants we have done this for. - DenseMap, OperationInst *> uniquedConstants; + DenseMap, Instruction *> uniquedConstants; }; }; // end anonymous namespace @@ -229,10 +227,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() { // revisit them. // // TODO: Add a result->getUsers() iterator. - for (auto &operand : op->getResult(i)->getUses()) { - if (auto *op = dyn_cast(operand.getOwner())) - addToWorklist(op); - } + for (auto &operand : op->getResult(i)->getUses()) + addToWorklist(operand.getOwner()); res->replaceAllUsesWith(cstValue); } @@ -267,10 +263,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() { if (res->use_empty()) // ignore dead uses. continue; - for (auto &operand : op->getResult(i)->getUses()) { - if (auto *op = dyn_cast(operand.getOwner())) - addToWorklist(op); - } + for (auto &operand : op->getResult(i)->getUses()) + addToWorklist(operand.getOwner()); res->replaceAllUsesWith(resultValues[i]); } } diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 153557de04a4..5bf17989befc 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -101,7 +101,7 @@ bool mlir::promoteIfSingleIteration(OpPointer forOp) { // Replaces all IV uses to its single iteration value. 
auto *iv = forOp->getInductionVar(); - OperationInst *forInst = forOp->getInstruction(); + Instruction *forInst = forOp->getInstruction(); if (!iv->use_empty()) { if (forOp->hasConstantLowerBound()) { auto *mlFunc = forInst->getFunction(); @@ -135,7 +135,7 @@ bool mlir::promoteIfSingleIteration(OpPointer forOp) { /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. - f->walkPostOrder([](OperationInst *inst) { + f->walkPostOrder([](Instruction *inst) { if (auto forOp = inst->dyn_cast()) promoteIfSingleIteration(forOp); }); @@ -394,11 +394,10 @@ bool mlir::loopUnrollByFactor(OpPointer forOp, return false; // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. - OperationInst *forInst = forOp->getInstruction(); + Instruction *forInst = forOp->getInstruction(); if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst)); - auto cleanupForInst = - cast(builder.clone(*forInst))->cast(); + auto cleanupForInst = builder.clone(*forInst)->cast(); auto clLbMap = getCleanupLoopLowerBound(forOp, unrollFactor, &builder); assert(clLbMap && "cleanup loop lower bound map for single result bound maps can " diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 879a4f4b585f..524e8d542f5b 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -37,7 +37,7 @@ using namespace mlir; /// Return true if this operation dereferences one or more memref's. // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) -static bool isMemRefDereferencingOp(const OperationInst &op) { +static bool isMemRefDereferencingOp(const Instruction &op) { if (op.isa() || op.isa() || op.isa() || op.isa()) return true; @@ -76,12 +76,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, std::make_unique(postDomInstFilter->getFunction()); // The ops where memref replacement succeeds are replaced with new ones. - SmallVector opsToErase; + SmallVector opsToErase; // Walk all uses of old memref. Operation using the memref gets replaced. - for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) { - InstOperand &use = *(it++); - auto *opInst = cast(use.getOwner()); + for (auto &use : llvm::make_early_inc_range(oldMemRef->getUses())) { + auto *opInst = use.getOwner(); // Skip this use if it's not dominated by domInstFilter. if (domInstFilter && !domInfo->dominates(domInstFilter, opInst)) @@ -217,8 +216,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, /// uses besides this opInst; otherwise returns the list of affine_apply /// operations created in output argument `sliceOps`. void mlir::createAffineComputationSlice( - OperationInst *opInst, - SmallVectorImpl> *sliceOps) { + Instruction *opInst, SmallVectorImpl> *sliceOps) { // Collect all operands that are results of affine apply ops. SmallVector subOperands; subOperands.reserve(opInst->getNumOperands()); @@ -230,7 +228,7 @@ void mlir::createAffineComputationSlice( } // Gather sequence of AffineApplyOps reachable from 'subOperands'. - SmallVector affineApplyOps; + SmallVector affineApplyOps; getReachableAffineApplyOps(subOperands, affineApplyOps); // Skip transforming if there are no affine maps to compose. 
if (affineApplyOps.empty()) @@ -341,8 +339,7 @@ bool mlir::constantFoldBounds(OpPointer forInst) { } void mlir::remapFunctionAttrs( - OperationInst &op, - const DenseMap &remappingTable) { + Instruction &op, const DenseMap &remappingTable) { for (auto attr : op.getAttrs()) { // Do the remapping, if we got the same thing back, then it must contain // functions that aren't getting remapped. diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index a9b9752ef514..7d51637a6e10 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -110,17 +110,13 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { // Only filter instructions that operate on a strict super-vector and have one // return. This makes testing easier. auto filter = [subVectorType](const Instruction &inst) { - auto *opInst = dyn_cast(&inst); - if (!opInst) { - return false; - } assert(subVectorType.getElementType() == Type::getF32(subVectorType.getContext()) && "Only f32 supported for now"); - if (!matcher::operatesOnSuperVectors(*opInst, subVectorType)) { + if (!matcher::operatesOnSuperVectors(inst, subVectorType)) { return false; } - if (opInst->getNumResults() != 1) { + if (inst.getNumResults() != 1) { return false; } return true; @@ -129,7 +125,7 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { SmallVector matches; pat.match(f, &matches); for (auto m : matches) { - auto *opInst = cast(m.getMatchedInstruction()); + auto *opInst = m.getMatchedInstruction(); // This is a unit test that only checks and prints shape ratio. // As a consequence we write only Ops with a single return type for the // purpose of this test. If we need to test more intricate behavior in the @@ -159,8 +155,7 @@ static NestedPattern patternTestSlicingOps() { using matcher::Op; // Match all OpInstructions with the kTestSlicingOpName name. 
auto filter = [](const Instruction &inst) { - const auto &opInst = cast(inst); - return opInst.getName().getStringRef() == kTestSlicingOpName; + return inst.getName().getStringRef() == kTestSlicingOpName; }; return Op(filter); } @@ -209,8 +204,7 @@ void VectorizerTestPass::testSlicing(Function *f) { } static bool customOpWithAffineMapAttribute(const Instruction &inst) { - const auto &opInst = cast(inst); - return opInst.getName().getStringRef() == + return inst.getName().getStringRef() == VectorizerTestPass::kTestAffineMapOpName; } @@ -222,7 +216,7 @@ void VectorizerTestPass::testComposeMaps(Function *f) { SmallVector maps; maps.reserve(matches.size()); for (auto m : llvm::reverse(matches)) { - auto *opInst = cast(m.getMatchedInstruction()); + auto *opInst = m.getMatchedInstruction(); auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName) .cast() .getValue(); @@ -236,13 +230,11 @@ void VectorizerTestPass::testComposeMaps(Function *f) { } static bool affineApplyOp(const Instruction &inst) { - const auto &opInst = cast(inst); - return opInst.isa(); + return inst.isa(); } static bool singleResultAffineApplyOpWithoutUses(const Instruction &inst) { - const auto &opInst = cast(inst); - auto app = opInst.dyn_cast(); + auto app = inst.dyn_cast(); return app && app->use_empty(); } @@ -259,8 +251,7 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) { SmallVector matches; pattern.match(f, &matches); for (auto m : matches) { - auto app = - cast(m.getMatchedInstruction())->cast(); + auto app = m.getMatchedInstruction()->cast(); FuncBuilder b(m.getMatchedInstruction()); SmallVector operands(app->getOperands()); makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 661861dcfd48..5a8d5d246617 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -723,22 +723,22 @@ namespace { struct VectorizationState { /// Adds an entry of pre/post vectorization instructions in the state. - void registerReplacement(OperationInst *key, OperationInst *value); + void registerReplacement(Instruction *key, Instruction *value); /// When the current vectorization pattern is successful, this erases the /// instructions that were marked for erasure in the proper order and resets /// the internal state for the next pattern. void finishVectorizationPattern(); - // In-order tracking of original OperationInst that have been vectorized. + // In-order tracking of original Instruction that have been vectorized. // Erase in reverse order. - SmallVector toErase; - // Set of OperationInst that have been vectorized (the values in the + SmallVector toErase; + // Set of Instruction that have been vectorized (the values in the // vectorizationMap for hashed access). The vectorizedSet is used in // particular to filter the instructions that have already been vectorized by // this pattern, when iterating over nested loops in this pattern. - DenseSet vectorizedSet; - // Map of old scalar OperationInst to new vectorized OperationInst. - DenseMap vectorizationMap; + DenseSet vectorizedSet; + // Map of old scalar Instruction to new vectorized Instruction. + DenseMap vectorizationMap; // Map of old scalar Value to new vectorized Value. DenseMap replacementMap; // The strategy drives which loop to vectorize by which amount. @@ -747,17 +747,17 @@ struct VectorizationState { // vectorizeOperations function. They consist of the subset of load operations // that have been vectorized. 
They can be retrieved from `vectorizationMap` // but it is convenient to keep track of them in a separate data structure. - DenseSet roots; + DenseSet roots; // Terminator instructions for the worklist in the vectorizeOperations // function. They consist of the subset of store operations that have been // vectorized. They can be retrieved from `vectorizationMap` but it is // convenient to keep track of them in a separate data structure. Since they // do not necessarily belong to use-def chains starting from loads (e.g // storing a constant), we need to handle them in a post-pass. - DenseSet terminators; + DenseSet terminators; // Checks that the type of `inst` is StoreOp and adds it to the terminators // set. - void registerTerminator(OperationInst *inst); + void registerTerminator(Instruction *inst); private: void registerReplacement(const Value *key, Value *value); @@ -765,8 +765,8 @@ private: } // end namespace -void VectorizationState::registerReplacement(OperationInst *key, - OperationInst *value) { +void VectorizationState::registerReplacement(Instruction *key, + Instruction *value) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: "); LLVM_DEBUG(key->print(dbgs())); LLVM_DEBUG(dbgs() << " into "); @@ -785,7 +785,7 @@ void VectorizationState::registerReplacement(OperationInst *key, } } -void VectorizationState::registerTerminator(OperationInst *inst) { +void VectorizationState::registerTerminator(Instruction *inst) { assert(inst->isa() && "terminator must be a StoreOp"); assert(terminators.count(inst) == 0 && "terminator was already inserted previously"); @@ -867,17 +867,16 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step, if (!matcher::isLoadOrStore(inst)) { return false; } - auto *opInst = cast(&inst); - return state->vectorizationMap.count(opInst) == 0 && - state->vectorizedSet.count(opInst) == 0 && - state->roots.count(opInst) == 0 && - state->terminators.count(opInst) == 0; + return state->vectorizationMap.count(&inst) == 0 && + state->vectorizedSet.count(&inst) == 0 && + state->roots.count(&inst) == 0 && + state->terminators.count(&inst) == 0; }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); SmallVector loadAndStoresMatches; loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { - auto *opInst = cast(ls.getMatchedInstruction()); + auto *opInst = ls.getMatchedInstruction(); auto load = opInst->dyn_cast(); auto store = opInst->dyn_cast(); LLVM_DEBUG(opInst->print(dbgs())); @@ -900,7 +899,7 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step, static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { return [fastestVaryingMemRefDimension](const Instruction &forInst) { - auto loop = cast(forInst).cast(); + auto loop = forInst.cast(); return isVectorizableLoopAlongFastestVaryingMemRefDim( loop, fastestVaryingMemRefDimension); }; @@ -915,7 +914,7 @@ static bool vectorizeNonRoot(ArrayRef matches, /// recursively in DFS post-order. static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) { auto *loopInst = oneMatch.getMatchedInstruction(); - auto loop = cast(loopInst)->cast(); + auto loop = loopInst->cast(); auto childrenMatches = oneMatch.getMatchedChildren(); // 1. DFS postorder recursion, if any of my children fails, I fail too. 
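(Illustrative sketch, not part of the original patch.) For reference, the post-change form of the load/store dispatch in vectorizeAffineForOp above, with the op classes written out explicitly; LoadOp and StoreOp are assumed here as the dyn_cast targets, matching the matcher::isLoadOrStore filter used to build loadAndStoresMatches:

    for (auto ls : loadAndStoresMatches) {
      // The matched instruction is used directly; no cast<OperationInst> step
      // is needed now that OperationInst has been merged into Instruction.
      auto *opInst = ls.getMatchedInstruction();
      auto load = opInst->dyn_cast<LoadOp>();
      auto store = opInst->dyn_cast<StoreOp>();
      LLVM_DEBUG(opInst->print(dbgs()));
      // ... vectorize the load or store as before.
    }
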
@@ -977,15 +976,14 @@ static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant,
   Location loc = inst->getLoc();
   auto vectorType = type.cast<VectorType>();
   auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
-  auto *constantOpInst = cast<OperationInst>(constant.getInstruction());
+  auto *constantOpInst = constant.getInstruction();
   OperationState state(
       b.getContext(), loc, constantOpInst->getName().getStringRef(), {},
       {vectorType},
       {make_pair(Identifier::get("value", b.getContext()), attr)});
-  auto *splat = cast<OperationInst>(b.createOperation(state));
-  return splat->getResult(0);
+  return b.createOperation(state)->getResult(0);
 }

 /// Returns a uniqu'ed VectorType.
@@ -997,8 +995,7 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
   if (!VectorType::isValidElementType(v->getType())) {
     return Type();
   }
-  auto *definingOpInst = cast<OperationInst>(v->getDefiningInst());
-  if (state.vectorizedSet.count(definingOpInst) > 0) {
+  if (state.vectorizedSet.count(v->getDefiningInst()) > 0) {
     return v->getType().cast<VectorType>();
   }
   return VectorType::get(state.strategy->vectorSizes, v->getType());
@@ -1029,9 +1026,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
                                VectorizationState *state) {
   LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
   LLVM_DEBUG(operand->print(dbgs()));
-  auto *definingInstruction = cast<OperationInst>(operand->getDefiningInst());
   // 1. If this value has already been vectorized this round, we are done.
-  if (state->vectorizedSet.count(definingInstruction) > 0) {
+  if (state->vectorizedSet.count(operand->getDefiningInst()) > 0) {
     LLVM_DEBUG(dbgs() << " -> already vector operand");
     return operand;
   }
@@ -1062,7 +1058,7 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
   return nullptr;
 };

-/// Encodes OperationInst-specific behavior for vectorization. In general we
+/// Encodes Instruction-specific behavior for vectorization. In general we
 /// assume that all operands of an op must be vectorized but this is not always
 /// true. In the future, it would be nice to have a trait that describes how a
 /// particular operation vectorizes. For now we implement the case distinction
@@ -1071,9 +1067,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
 /// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized.
 /// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
 /// do one-off logic here; ideally it would be TableGen'd.
-static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
-                                                OperationInst *opInst,
-                                                VectorizationState *state) {
+static Instruction *vectorizeOneInstruction(FuncBuilder *b, Instruction *opInst,
+                                            VectorizationState *state) {
   // Sanity checks.
   assert(!opInst->isa<LoadOp>() &&
          "all loads must have already been fully vectorized independently");
@@ -1094,7 +1089,7 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
     LLVM_DEBUG(permutationMap.print(dbgs()));
     auto transfer = b.create<VectorTransferWriteOp>(
         opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
-    auto *res = cast<OperationInst>(transfer->getInstruction());
+    auto *res = transfer->getInstruction();
     LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
     // "Terminators" (i.e. StoreOps) are erased on the spot.
     opInst->erase();
@@ -1119,8 +1114,8 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
   // Create a clone of the op with the proper operands and return types.
   // TODO(ntv): The following assumes there is always an op with a fixed
   // name that works both in scalar mode and vector mode.
-  // TODO(ntv): Is it worth considering an OperationInst.clone operation
-  // which changes the type so we can promote an OperationInst with less
+  // TODO(ntv): Is it worth considering an Instruction.clone operation
+  // which changes the type so we can promote an Instruction with less
   // boilerplate?
   OperationState newOp(b->getContext(), opInst->getLoc(),
                        opInst->getName().getStringRef(), operands, types,
@@ -1129,22 +1124,22 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
   return b->createOperation(newOp);
 }

-/// Iterates over the OperationInst in the loop and rewrites them using their
+/// Iterates over the Instruction in the loop and rewrites them using their
 /// vectorized counterpart by:
-/// 1. iteratively building a worklist of uses of the OperationInst vectorized
+/// 1. iteratively building a worklist of uses of the Instruction vectorized
 /// so far by this pattern;
-/// 2. for each OperationInst in the worklist, create the vector form of this
+/// 2. for each Instruction in the worklist, create the vector form of this
 /// operation and replace all its uses by the vectorized form. For this step,
 /// the worklist must be traversed in order;
 /// 3. verify that all operands of the newly vectorized operation have been
 /// vectorized by this pattern.
 static bool vectorizeOperations(VectorizationState *state) {
   // 1. create initial worklist with the uses of the roots.
-  SetVector worklist;
-  auto insertUsesOf = [&worklist, state](OperationInst *vectorized) {
+  SetVector worklist;
+  auto insertUsesOf = [&worklist, state](Instruction *vectorized) {
     for (auto *r : vectorized->getResults())
       for (auto &u : r->getUses()) {
-        auto *inst = cast<OperationInst>(u.getOwner());
+        auto *inst = u.getOwner();
         // Don't propagate to terminals, a separate pass is needed for those.
         // TODO(ntv)[b/119759136]: use isa<> once Op is implemented.
         if (state->terminators.count(inst) > 0) {
@@ -1166,7 +1161,7 @@ static bool vectorizeOperations(VectorizationState *state) {
     // 2. Create vectorized form of the instruction.
     // Insert it just before inst, on success register inst as replaced.
     FuncBuilder b(inst);
-    auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state);
+    auto *vectorizedInst = vectorizeOneInstruction(&b, inst, state);
     if (!vectorizedInst) {
       return true;
     }
@@ -1179,7 +1174,7 @@ static bool vectorizeOperations(VectorizationState *state) {

     // 4. Augment the worklist with uses of the instruction we just vectorized.
     // This preserves the proper order in the worklist.
-    apply(insertUsesOf, ArrayRef{inst});
+    apply(insertUsesOf, ArrayRef{inst});
   }
   return false;
 }
@@ -1189,8 +1184,7 @@ static bool vectorizeOperations(VectorizationState *state) {
 /// Each root may succeed independently but will otherwise clean after itself if
 /// anything below it fails.
 static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
-  auto loop =
-      cast<OperationInst>(m.getMatchedInstruction())->cast<AffineForOp>();
+  auto loop = m.getMatchedInstruction()->cast<AffineForOp>();
   VectorizationState state;
   state.strategy = strategy;

@@ -1207,8 +1201,7 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
   }
   auto *loopInst = loop->getInstruction();
   FuncBuilder builder(loopInst);
-  auto clonedLoop =
-      cast<OperationInst>(builder.clone(*loopInst))->cast<AffineForOp>();
+  auto clonedLoop = builder.clone(*loopInst)->cast<AffineForOp>();
   auto fail = doVectorize(m, &state);
   /// Sets up error handling for this root loop. This is how the root match
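The worklist loop in vectorizeOperations above seeds itself with the uses of the already-vectorized roots and then keeps appending the uses of whatever it just rewrote, which keeps the traversal in topological order; terminators (stores) are skipped on purpose and handled in the later post-pass. A simplified standalone sketch of that propagation, with stand-in types and no failure handling:

// Sketch only: Node stands in for an instruction; users approximates
// iterating over the uses of its results.
#include <cstddef>
#include <vector>

struct Node {
  std::vector<Node *> users;
  bool isTerminator = false; // stores are handled in a separate post-pass
  bool vectorized = false;
};

static void propagateVectorization(const std::vector<Node *> &roots) {
  std::vector<Node *> worklist;
  auto enqueueUsersOf = [&worklist](Node *n) {
    for (Node *user : n->users)
      if (!user->isTerminator && !user->vectorized)
        worklist.push_back(user);
  };
  // 1. Seed the worklist with the uses of the already-vectorized roots.
  for (Node *root : roots)
    enqueueUsersOf(root);
  // 2./3. Rewrite each entry in order (the real pass may fail and bail out).
  for (size_t i = 0; i < worklist.size(); ++i) {
    Node *n = worklist[i];
    if (n->vectorized)
      continue;
    n->vectorized = true; // placeholder for building the vector form
    // 4. Append the uses of what was just rewritten to keep the order valid.
    enqueueUsersOf(n);
  }
}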
@@ -1248,12 +1241,12 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
   }

   // Finally, vectorize the terminators. If anything fails to vectorize, skip.
-  auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
+  auto vectorizeOrFail = [&fail, &state](Instruction *inst) {
     if (fail) {
       return;
     }
     FuncBuilder b(inst);
-    auto *res = vectorizeOneOperationInst(&b, inst, &state);
+    auto *res = vectorizeOneInstruction(&b, inst, &state);
     if (res == nullptr) {
       fail = true;
     }