diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h index 17fa57d341ca..363d3ffcdcf3 100644 --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -80,9 +80,9 @@ public: if (!owner) return llvm::None; if (OpOperand *operand = opView.dyn_cast()) - return owner.getIndexingMap(operand->getOperandNumber()); - return owner.getOutputIndexingMap( - opView.get().cast().getResultNumber()); + return owner.getTiedIndexingMap(operand); + return owner.getTiedIndexingMap(owner.getOutputOperand( + opView.get().cast().getResultNumber())); } // Return the operand number if the `opView` is an OpOperand *. Otherwise // return llvm::None. diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index 3e37979b68ec..5423a158a80c 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -165,46 +165,46 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation() << " and " << *dst.getOperation() << "\n"); if (src.hasTensorSemantics() && dst.hasTensorSemantics()) { - for (OpOperand &dstOpOperand : dst.getInputOpOperands()) { + for (OpOperand *dstOpOperand : dst.getInputOperands()) { // Check if the operand is defined by the src. - auto definingOp = dstOpOperand.get().getDefiningOp(); + auto definingOp = dstOpOperand->get().getDefiningOp(); if (definingOp && definingOp == src) - addDependenceElem(DependenceType::RAW, dstOpOperand.get(), - &dstOpOperand); + addDependenceElem(DependenceType::RAW, dstOpOperand->get(), + dstOpOperand); } - for (OpOperand &dstOpOperand : dst.getOutputOpOperands()) { + for (OpOperand *dstOpOperand : dst.getOutputOperands()) { // Check if the operand is defined by the src. - auto definingOp = dstOpOperand.get().getDefiningOp(); + auto definingOp = dstOpOperand->get().getDefiningOp(); if (definingOp && definingOp == src) { - if (dst.isInitTensor(&dstOpOperand)) { - addDependenceElem(DependenceType::RAW, dstOpOperand.get(), - &dstOpOperand); + if (dst.isInitTensor(dstOpOperand)) { + addDependenceElem(DependenceType::RAW, dstOpOperand->get(), + dstOpOperand); } - addDependenceElem(DependenceType::WAW, dstOpOperand.get(), - &dstOpOperand); + addDependenceElem(DependenceType::WAW, dstOpOperand->get(), + dstOpOperand); } } return; } assert(src.hasBufferSemantics() && dst.hasBufferSemantics() && "unhandled dependence tracking for mixed buffer/tensor operations"); - for (OpOperand *srcOpOperand : src.getOutputBuffersOpOperands()) { // W + for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W // RAW graph - for (OpOperand *dstOpOperand : dst.getInputBuffersOpOperands()) // R - if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias + for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R + if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand); // WAW graph - for (OpOperand *dstOpOperand : dst.getOutputBuffersOpOperands()) // W + for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand); } - for (OpOperand *srcOpOperand : src.getInputBuffersOpOperands()) { // R + for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R // RAR graph - for (OpOperand *dstOpOperand : dst.getInputBuffersOpOperands()) // R - if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias + for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R + if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand); // WAR graph - for (OpOperand *dstOpOperand : dst.getOutputBuffersOpOperands()) // W + for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand); }