//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto castOp = cast<tensor::CastOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        castOp.getSource(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    Attribute memorySpace = maybeSrcBufferType->getMemorySpace();

    // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
    // type in case the input is an unranked tensor type.

    // Case 1: Casting an unranked tensor.
    if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
      // When casting to a ranked tensor, we cannot infer any static offset or
      // strides from the source. Assume fully dynamic.
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(),
                                                 memorySpace);
    }

    // Case 2: Casting to an unranked tensor type.
    if (isa<UnrankedTensorType>(castOp.getType())) {
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(),
                                                 memorySpace);
    }

    // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
    // change.
    auto rankedResultType = cast<RankedTensorType>(castOp.getType());
    return MemRefType::get(
        rankedResultType.getShape(), rankedResultType.getElementType(),
        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();

    // Compute the new type.
    auto resultMemRefType =
        bufferization::getBufferType(castOp.getResult(), options);
    if (failed(resultMemRefType))
      return failure();
    if (resultBuffer->getType() == *resultMemRefType) {
      // This cast is a no-op.
      replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
      return success();
    }

    // Replace the op with a memref.cast.
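    // Illustrative example (not from the original source, assuming a fully
    // dynamic source layout): the tensor-level cast
    //   %1 = tensor.cast %0 : tensor<?xf32> to tensor<10xf32>
    // bufferizes to a metadata-only cast of the source buffer:
    //   %1 = memref.cast %m : memref<?xf32, strided<[?], offset: ?>>
    //                         to memref<10xf32, strided<[?], offset: ?>>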
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             *resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(
        rewriter, op, *resultMemRefType, *resultBuffer);
    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // TODO: CollapseShapeOp may allocate at runtime.
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        collapseShapeOp.getSrc(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        srcBufferType, collapseShapeOp.getReassociationIndices());

    if (!canBeCollapsed) {
      // If dims cannot be collapsed, this op bufferizes to a new allocation.
      RankedTensorType tensorResultType = collapseShapeOp.getResultType();
      return bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorResultType, srcBufferType.getMemorySpace());
    }

    return memref::CollapseShapeOp::computeCollapsedType(
        srcBufferType, collapseShapeOp.getReassociationIndices());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = cast<MemRefType>(buffer.getType());

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        resultType = MemRefType::get(
            {}, tensorResultType.getElementType(),
            StridedLayoutAttr::get(op->getContext(), offset, {}),
            bufferType.getMemorySpace());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer,
          collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
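      // Illustrative example (hypothetical IR): a non-contiguous view such as
      //   memref<4x4xf32, strided<[8, 1]>>
      // cannot be collapsed into memref<16xf32> because its rows are not
      // adjacent in memory. The allocation below yields an identity-layout
      // buffer, which is always collapsible.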
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpace());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
    return success();
  }
};

/// Bufferization of tensor.empty. This op does not bufferize, but we need an
/// interface implementation, so that the result of this op is considered
/// "writable" (default impl. of `isWritable`). Results of ops that do not
/// implement `BufferizableOpInterface` are not writable.
struct EmptyOpInterface
    : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
                                                    tensor::EmptyOp> {
  bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
                                     const AnalysisState &state) const {
    // The returned tensor does not have specified contents.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    if (op->getUses().empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    // tensor.empty ops are used to indicate the shape of a tensor. They have
    // no defined contents and cannot be bufferized. However, they can be
    // converted to bufferization.alloc_tensor ops, which then bufferize to an
    // allocation (--empty-tensor-to-alloc-tensor).
    return op->emitOpError("cannot be bufferized, but can be converted to "
                           "bufferization.alloc_tensor");
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
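///
/// Illustrative example (not from the original source):
///   %0 = tensor.expand_shape %t [[0, 1]]
///       : tensor<10xf32> into tensor<2x5xf32>
/// bufferizes to an equivalent view of the source buffer:
///   %0 = memref.expand_shape %m [[0, 1]]
///       : memref<10xf32> into memref<2x5xf32>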
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        expandShapeOp.getSrc(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
        srcBufferType, expandShapeOp.getResultType().getShape(),
        expandShapeOp.getReassociationIndices());
    if (failed(maybeResultType))
      return failure();
    return *maybeResultType;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Unknown}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();

    // Take a subview of the source buffer.
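    // The subview aliases the source buffer; no data is copied. Illustrative
    // example (not from the original source):
    //   %0 = tensor.extract_slice %t[4] [8] [1]
    //       : tensor<16xf32> to tensor<8xf32>
    // becomes:
    //   %0 = memref.subview %m[4] [8] [1]
    //       : memref<16xf32> to memref<8xf32, strided<[1], offset: 4>>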
    auto resultMemrefType =
        bufferization::getBufferType(extractSliceOp.getResult(), options);
    if (failed(resultMemrefType))
      return failure();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
        mixedOffsets, mixedSizes, mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    assert(value == extractSliceOp.getResult() && "invalid value");
    auto srcMemrefType = bufferization::getBufferType(
        extractSliceOp.getSource(), options, invocationStack);
    if (failed(srcMemrefType))
      return failure();
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
        extractSliceOp.getType().getShape(),
        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
        mixedStrides));
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
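    // The op bufferizes to a fresh identity-layout buffer into which every
    // element is stored individually. Illustrative sketch (not from the
    // original source):
    //   %0 = tensor.from_elements %a, %b : tensor<2xf32>
    // becomes roughly:
    //   %alloc = memref.alloc() : memref<2xf32>
    //   memref.store %a, %alloc[%c0] : memref<2xf32>
    //   memref.store %b, %alloc[%c1] : memref<2xf32>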
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Lower the body of a tensor.generate like op (one index-typed bbArg per
/// dim). Such ops are lowered to linalg.map with the given tensor as a
/// destination.
///
/// Example:
/// ```
/// %r = tensor.generate %x, %y {
///   ^bb0(%arg0: index, %arg1: index):
///   %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
///   tensor.yield %0 : index
/// } : tensor<?x?xindex>
/// ```
///
/// Is lowered to:
/// ```
/// linalg.map ins() outs(%dest) {
///   %d0 = linalg.index 0 : index
///   %d1 = linalg.index 1 : index
///   %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
///   linalg.yield %0 : index
/// }
/// ```
static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
                                     Value tensorDestination,
                                     ValueRange dynamicSizes,
                                     Region &generateBody) {
  assert(generateBody.hasOneBlock() && "expected body with single block");
  auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
  assert(generateBody.getNumArguments() == tensorType.getRank() &&
         "rank mismatch");

  // Create linalg::MapOp.
  OpBuilder::InsertionGuard g(rewriter);
  auto linalgOp =
      rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                     /*init=*/tensorDestination);
  Block &linalgBody = linalgOp.getMapper().emplaceBlock();

  // Create linalg::IndexOps.
  rewriter.setInsertionPointToStart(&linalgBody);
  SmallVector<Value> indices;
  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
    indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));

  // Move over body.
  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  return linalgOp.getResult()[0];
}

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate memory.
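    // The result buffer is allocated as a bufferization.alloc_tensor (without
    // a copy); the generate body is then lowered onto it as a linalg.map by
    // lowerGenerateLikeOpBody (see the example above that helper).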
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
                                           generateOp.getDynamicExtents(),
                                           generateOp.getBody());
    rewriter.replaceOp(generateOp, result);

    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
struct InsertOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
                                                     tensor::InsertOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }
};

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
struct InsertSliceOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    RankedTensorType destType = insertSliceOp.getDestType();

    // The source is always read.
    if (&opOperand == &insertSliceOp.getSourceMutable()[0])
      return true;

    // For the destination, it depends...
    assert(&opOperand == &insertSliceOp.getDestMutable()[0] &&
           "expected dest");

    // Dest is not read if it is entirely overwritten. E.g.:
    // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
    bool allOffsetsZero =
        llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
          return isConstantIntValue(ofr, 0);
        });
    bool sizesMatchDestSizes = llvm::all_of(
        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
          return getConstantIntValue(it.value()) ==
                 destType.getDimSize(it.index());
        });
    bool allStridesOne =
        llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
          return isConstantIntValue(ofr, 1);
        });
    return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
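    // Illustrative example (not from the original source):
    //   %0 = tensor.insert_slice %src into %dest[4] [8] [1]
    //       : tensor<8xf32> into tensor<16xf32>
    // becomes a subview of the destination buffer plus a copy into it
    // (memref.copy with the default createMemCpy implementation):
    //   %sv = memref.subview %d[4] [8] [1]
    //       : memref<16xf32> to memref<8xf32, strided<[1], offset: 4>>
    //   memref.copy %s, %sv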
    auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
    auto subviewMemRefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides));
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
/// linalg.map + insert_slice.
/// For best performance, vectorize before bufferization (better performance
/// in case of padding with a constant).
struct PadOpInterface
    : public BufferizableOpInterface::ExternalModel<PadOpInterface,
                                                    tensor::PadOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const {
    return true;
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    // Infer memory space from the source tensor.
    auto padOp = cast<tensor::PadOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        padOp.getSource(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(padOp.getResultType().getShape(),
                           padOp.getResultType().getElementType(), layout,
                           maybeSrcBufferType->getMemorySpace());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto padOp = cast<tensor::PadOp>(op);
    Location loc = padOp.getLoc();
    RankedTensorType resultType = padOp.getResultType();
    RankedTensorType srcType = padOp.getSourceType();

    auto toValue = [&](OpFoldResult ofr) {
      if (ofr.is<Value>())
        return ofr.get<Value>();
      return rewriter
          .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
          .getResult();
    };

    // Compute dynamic result dimensions.
    SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < resultType.getRank(); ++i) {
      if (!resultType.isDynamicDim(i))
        continue;
      Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
      Value lowPad = toValue(mixedLowPad[i]);
      Value highPad = toValue(mixedHighPad[i]);
      AffineExpr s0, s1, s2;
      bindSymbols(op->getContext(), s0, s1, s2);
      AffineExpr sumExpr = s0 + s1 + s2;
      Value sum = rewriter.create<affine::AffineApplyOp>(
          loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
      dynamicSizes.push_back(sum);
    }

    // Allocate a buffer for the padded result.
    FailureOr<Value> tensorAlloc =
        allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), options,
                                     /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // tensor::PadOp is like tensor::GenerateOp: The only difference is that
    // only a part of the generated tensor is needed. For simplicity, we reuse
    // the same functionality here.
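    // Note: The body fills the entire destination buffer, including the
    // region that the source subsequently overwrites via the
    // tensor.insert_slice created below; vectorizing before bufferization
    // (see the struct comment) avoids these redundant writes for constant
    // pads.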
    Value filledBuffer = lowerGenerateLikeOpBody(
        rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());

    // Create tensor::InsertSliceOp.
    SmallVector<OpFoldResult> sliceSizes =
        getMixedSizes(rewriter, loc, padOp.getSource());
    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
                                           rewriter.getIndexAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        padOp, padOp.getSource(), filledBuffer,
        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);

    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op,
                                                 rankOp.getType(), *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    return &opOperand == &reshapeOp.getShapeMutable()[0];
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto maybeResultMemRefType =
        bufferization::getBufferType(reshapeOp.getResult(), options);
    if (failed(maybeResultMemRefType))
      return failure();
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    assert(value == reshapeOp.getResult() && "unexpected value provided");
    auto maybeSourceBufferType = bufferization::getBufferType(
        reshapeOp.getSource(), options, invocationStack);
    if (failed(maybeSourceBufferType))
      return failure();
    return getMemRefTypeWithStaticIdentityLayout(
        reshapeOp.getResult().getType(),
        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
  }
};

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    return &opOperand == &parallelInsertSliceOp.getDestMutable()[0];
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();

    // Bufferize the op outside of the parallel combining terminator.
    rewriter.setInsertionPoint(parallelCombiningParent);

    // Get source and destination buffers.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
    if (failed(destBuffer))
      return failure();
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = cast<MemRefType>(destBuffer->getType());
    auto subviewMemRefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides()));
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // In case the source was allocated in the same block, make sure that the
    // deallocation op (if any) appears after the memcpy. By default, deallocs
    // are placed before the terminator, but this does not work for ForallOp
    // because the terminator does more than just yielding a value.
    //
    // Note: This is not a problem for the destination buffer because these
    // are assumed to always bufferize in-place.
    for (Operation *user : srcBuffer->getUsers()) {
      if (hasEffect<MemoryEffects::Free>(user)) {
        if (user->getBlock() == parallelCombiningParent->getBlock())
          user->moveBefore(user->getBlock()->getTerminator());
        break;
      }
    }

    // Delete the op.
    rewriter.eraseOp(op);
    return success();
  }
};

/// Bufferization of tensor.splat. Bufferizes to a new allocation that is
/// filled with a linalg.map. Similar to tensor.generate.
struct SplatOpInterface
    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
                                                    tensor::SplatOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto splatOp = cast<tensor::SplatOp>(op);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate memory.
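    // Illustrative example (not from the original source):
    //   %0 = tensor.splat %v : tensor<8xf32>
    // becomes an alloc_tensor filled by a linalg.map whose body simply yields
    // %v for every element.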
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, splatOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // Create linalg::MapOp.
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
    auto linalgOp =
        rewriter.create<linalg::MapOp>(loc, tensorType,
                                       /*inputs=*/ValueRange(),
                                       /*init=*/*tensorAlloc);
    Block &linalgBody = linalgOp.getMapper().emplaceBlock();

    // Create the linalg::YieldOp that yields the splatted value.
    rewriter.setInsertionPointToStart(&linalgBody);
    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);

    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    PadOp::attachInterface<PadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
    SplatOp::attachInterface<SplatOpInterface>(*ctx);

    // Load additional dialects of which ops may get created.
    ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
  });

  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
  // they are registered.
  tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
}