[mlir][interfaces][NFC] Move DestinationStyleOpInterface to mlir/Interfaces

This is the second (and final) step of making "destination style" usable without depending on the Linalg dialect. (The first step was D135129.)

This change allows us to provide default bufferization implementations for all destination-style ops. It also allows us to simplify `TilingInterface`. (E.g., `getDestinationOperands` can be removed.)

Differential Revision: https://reviews.llvm.org/D136179
This commit is contained in:
Matthias Springer 2022-10-18 17:23:42 +02:00
parent 44027f3908
commit cfc9ddaafc
17 changed files with 467 additions and 357 deletions

View File

@ -20,6 +20,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"

View File

@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
@ -26,11 +27,6 @@ namespace mlir {
namespace linalg {
class LinalgOp;
/// OpOperand vector that implicitly converts to a Value vector.
struct OpOperandVector : public SmallVector<OpOperand *> {
operator SmallVector<Value>();
};
namespace detail {
/// Implementation of the method that that check if given operands
/// can be dropped, i.e. the remaining operands can compute the loop
@ -57,9 +53,6 @@ LogicalResult verifyFillInterface(Operation *op);
/// Verify that `op` conforms to the invariants of StructuredOpInterface
LogicalResult verifyStructuredOpInterface(Operation *op);
/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface
LogicalResult verifyDestinationStyleOpInterface(Operation *op);
} // namespace detail
} // namespace linalg
} // namespace mlir

View File

@ -879,291 +879,4 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
let verifyWithRegions = 1;
}
// Ops that are in destination style have designated output operands, which act
// as initial tensor values for the results of the operation or the output
// buffers to which the results of the op will be written.
//
// Output operands must be tensors or memrefs. Input operands can have any
// type. All non-output operands are inputs.
// It is assumed that the output operands of the op are the operands at
// position [start, end). The positions are defined by getOutputsPositionRange
// method. All non-output operands are "inputs" of the DPS op.
// If the op has "tensor semantics", then the input operands are either scalars
// or tensors. The output operands are tensors and every tensor output is tied
// to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output
// tensor is tied to the i-th OpResult. The op may not have any additional
// OpResults. Output operands and their tied OpResults have the same type.
//
// If the op has "buffer semantics", then the input operands are either memrefs
// or other non-tensor types, e.g. scalar types. Furthermore, the output
// operands are memrefs and the op has no results.
//
// Destination-passing style abstraction makes certain transformations easier.
// For example, tiling implementation can extract/insert slices from/into the
// destination of an op and use the resulting shaped value as an iter_arg in
// the surrounding loop structure. As another example, bufferization does not
// have to allocate new buffers for destinations (in case of in-place
// bufferization) and can directly reuse the existing destination buffer.
//
// Example of a destination style op: `%r = tensor.insert_slice %t into %d`,
// where `%t` is the single input and `%d` is the single output. `%d` is tied
// to `%r`.
//
// Example of an op that is not in destination style: `%r = tensor.pad %t`.
// This op is not in destination style because `%r` and `%t` have different
// shape.
//
// Each op that wants to implement DestinationStyleOpInterface needs to define
// the getOutputsPositionRange() method.
def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
let cppNamespace = "::mlir::linalg";
let methods = [
// This method has to be defined for every DPS op.
InterfaceMethod<
/*desc=*/"Return start and end indices of the output operands range.",
/*retTy=*/"std::pair<int64_t, int64_t>",
/*methodName=*/"getOutputsPositionRange",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
//===------------------------------------------------------------------===//
// Operands handling.
//===------------------------------------------------------------------===//
// The operand list is assumed to start with the input operands and end
// with the output operands. Therefore, all methods to access the inputs
// and outputs can be expressed if the number of output operands is know.
InterfaceMethod<
/*desc=*/"Return the number of outputs.",
/*retTy=*/"int64_t",
/*methodName=*/"getNumOutputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto [start, end] = $_op.getOutputsPositionRange();
return end - start;
}]
>,
InterfaceMethod<
/*desc=*/"Return the output operands.",
/*retTy=*/"OpOperandVector",
/*methodName=*/"getOutputOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto [start, end] = $_op.getOutputsPositionRange();
OpOperandVector result;
result.reserve(end - start);
for (int i = start; i < end; ++i)
result.push_back(&$_op->getOpOperand(i));
return result;
}]
>,
InterfaceMethod<
/*desc=*/"Return the `i`-th output operand.",
/*retTy=*/"OpOperand*",
/*methodName=*/"getOutputOperand",
/*args=*/(ins "int64_t":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i >= 0 && i < $_op.getNumOutputs());
auto [start, end] = $_op.getOutputsPositionRange();
return &$_op->getOpOperand(start + i);
}]
>,
InterfaceMethod<
/*desc=*/"Set the `i`-th output operand.",
/*retTy=*/"void",
/*methodName=*/"setOutputOperand",
/*args=*/(ins "int64_t":$i, "Value":$value),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i >= 0 && i < $_op.getNumOutputs());
auto [start, end] = $_op.getOutputsPositionRange();
$_op->setOperand(start + i, value);
}]
>,
InterfaceMethod<
/*desc=*/"Return the number of inputs.",
/*retTy=*/"int64_t",
/*methodName=*/"getNumInputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getNumOperands() - $_op.getNumOutputs();
}]
>,
InterfaceMethod<
/*desc=*/"Return the input operands.",
/*retTy=*/"OpOperandVector",
/*methodName=*/"getInputOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto [start, end] = $_op.getOutputsPositionRange();
int64_t numOutputs = end - start;
int64_t numOperands = $_op.getNumOperands();
OpOperandVector result;
result.reserve(numOperands - numOutputs);
for (int i = 0; i < start; ++i)
result.push_back(&$_op->getOpOperand(i));
for (int i = end; i < numOperands; ++i)
result.push_back(&$_op->getOpOperand(end + i));
return result;
}]
>,
InterfaceMethod<
/*desc=*/[{ Return the `i`-th input operand. }],
/*retTy=*/"OpOperand*",
/*methodName=*/"getInputOperand",
/*args=*/(ins "int64_t":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i >= 0 && i < getNumInputs());
auto [start, end] = $_op.getOutputsPositionRange();
return &$_op->getOpOperand(i < start ? i : i + end - start) ;
}]
>,
//===------------------------------------------------------------------===//
// Input and Output arguments handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/"Return true if `opOperand` is an input.",
/*retTy=*/"bool",
/*methodName=*/"isInput",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto [start, end] = $_op.getOutputsPositionRange();
auto operandNumber = opOperand->getOperandNumber();
return operandNumber < start || operandNumber >= end;
}]
>,
InterfaceMethod<
/*desc=*/"Return true if `opOperand` is an output.",
/*retTy=*/"bool",
/*methodName=*/"isOutput",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto [start, end] = $_op.getOutputsPositionRange();
auto operandNumber = opOperand->getOperandNumber();
return operandNumber >= start && operandNumber < end;
}]
>,
InterfaceMethod<
/*desc=*/"Return true if the `opOperand` is a scalar value.",
/*retTy=*/"bool",
/*methodName=*/"isScalar",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
return !opOperand->get().getType().template isa<ShapedType>();
}]
>,
InterfaceMethod<
/*desc=*/"Return the result tied to `opOperand`.",
/*retTy=*/"OpResult",
/*methodName=*/"getTiedOpResult",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
auto [start, end] = $_op.getOutputsPositionRange();
int64_t resultIndex = opOperand->getOperandNumber() - start;
assert(resultIndex >= 0 &&
resultIndex < $_op->getNumResults() );
return $_op->getResult(resultIndex);
}]
>,
//===------------------------------------------------------------------===//
// Other interface methods.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/"Return whether the op has only MemRef input and outputs.",
/*retTy=*/"bool",
/*methodName=*/"hasBufferSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op->getNumResults() == 0 &&
llvm::all_of($_op->getOpOperands(),
[&](OpOperand &opOperand) {
return isScalar(&opOperand) ||
opOperand.get().getType().template isa<MemRefType>();
});
}]
>,
InterfaceMethod<
/*desc=*/"Return whether the op has only RankedTensor input and outputs.",
/*retTy=*/"bool",
/*methodName=*/"hasTensorSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::all_of($_op->getOpOperands(),
[&](OpOperand &opOperand) {
return isScalar(&opOperand) ||
opOperand.get().getType().template isa<RankedTensorType>();
});
}]
>,
//===------------------------------------------------------------------===//
// Other static interface methods.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
Clone the current operation with the given location and operands. This
is used to abstract away the optional underlying region creation. This
does not change the balance between input, output_buffer and
init_tensors operands.
}],
/*retTy=*/"Operation *",
/*methodName=*/"clone",
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands),
[{
BlockAndValueMapping bvm;
OperationState state(
loc, ConcreteOp::getOperationName(), operands, resultTypes,
$_op->getAttrs());
for (Region &r : $_op->getRegions())
r.cloneInto(state.addRegion(), bvm);
return b.create(state);
}]
>,
InterfaceMethod<
/*desc=*/[{
Clone the current operation with the given location, operands
and BlockAndValueMapping but leave the regions empty. This is
used to abstract away the optional underlying region creation.
This does not change the balance between input, output_buffer
and init_tensors operands.
}],
/*retTy=*/"Operation *",
/*methodName=*/"cloneWithoutRegions",
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands),
[{
OperationState state(
loc, ConcreteOp::getOperationName(), operands, resultTypes,
$_op->getAttrs());
for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt)
state.addRegion();
return b.create(state);
}]
>
];
let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }];
let verifyWithRegions = 1;
}
#endif // LINALG_IR_LINALGINTERFACES

View File

@ -17,6 +17,7 @@
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
@ -279,7 +280,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
}
linalg::OpOperandVector getOpOperandsMatchingBBargs() {
OpOperandVector getOpOperandsMatchingBBargs() {
return getInputOperands();
}

View File

@ -3,6 +3,7 @@ add_mlir_interface(CastInterfaces)
add_mlir_interface(ControlFlowInterfaces)
add_mlir_interface(CopyOpInterface)
add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)

View File

@ -0,0 +1,34 @@
//===- DestinationStyleOpInterface.h ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_
#define MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
/// OpOperand vector that implicitly converts to a Value vector.
struct OpOperandVector : public llvm::SmallVector<OpOperand *> {
operator SmallVector<Value>();
};
namespace detail {
/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface
LogicalResult verifyDestinationStyleOpInterface(Operation *op);
} // namespace detail
} // namespace mlir
/// Include the generated interface declarations.
#include "mlir/Interfaces/DestinationStyleOpInterface.h.inc"
#endif // MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_

View File

@ -0,0 +1,306 @@
//===- DestinationStyleOpInterface.td ----------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DESTINATIONSTYLEOPINTERFACE
#define MLIR_DESTINATIONSTYLEOPINTERFACE
include "mlir/IR/OpBase.td"
def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
let description = [{
Ops that are in destination style have designated output operands, which act
as initial tensor values for the results of the operation or the output
buffers to which the results of the op will be written.
Output operands must be tensors or memrefs. Input operands can have any
type. All non-output operands are inputs.
It is assumed that the output operands of the op are the operands at
position [start, end). The positions are defined by getOutputsPositionRange
method. All non-output operands are "inputs" of the DPS op.
If the op has "tensor semantics", then the input operands are either scalars
or tensors. The output operands are tensors and every tensor output is tied
to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output
tensor is tied to the i-th OpResult. The op may not have any additional
OpResults. Output operands and their tied OpResults have the same type.
If the op has "buffer semantics", then the input operands are either memrefs
or other non-tensor types, e.g. scalar types. Furthermore, the output
operands are memrefs and the op has no results.
Destination-passing style abstraction makes certain transformations easier.
For example, tiling implementation can extract/insert slices from/into the
destination of an op and use the resulting shaped value as an iter_arg in
the surrounding loop structure. As another example, bufferization does not
have to allocate new buffers for destinations (in case of in-place
bufferization) and can directly reuse the existing destination buffer.
Example of a destination style op: `%r = tensor.insert_slice %t into %d`,
where `%t` is the single input and `%d` is the single output. `%d` is tied
to `%r`.
Example of an op that is not in destination style: `%r = tensor.pad %t`.
This op is not in destination style because `%r` and `%t` have different
shape.
Each op that wants to implement DestinationStyleOpInterface needs to define
the getOutputsPositionRange() method.
}];
let cppNamespace = "::mlir";
let methods = [
// This method has to be defined for every DPS op.
InterfaceMethod<
/*desc=*/"Return start and end indices of the output operands range.",
/*retTy=*/"std::pair<int64_t, int64_t>",
/*methodName=*/"getOutputsPositionRange",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
//===------------------------------------------------------------------===//
// Operands handling.
//===------------------------------------------------------------------===//
// The operand list is assumed to start with the input operands and end
// with the output operands. Therefore, all methods to access the inputs
// and outputs can be expressed if the number of output operands is know.
InterfaceMethod<
/*desc=*/"Return the number of outputs.",
/*retTy=*/"int64_t",
/*methodName=*/"getNumOutputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto [start, end] = $_op.getOutputsPositionRange();
return end - start;
}]
>,
InterfaceMethod<
/*desc=*/"Return the output operands.",
/*retTy=*/"OpOperandVector",
/*methodName=*/"getOutputOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto [start, end] = $_op.getOutputsPositionRange();
OpOperandVector result;
result.reserve(end - start);
for (int i = start; i < end; ++i)
result.push_back(&$_op->getOpOperand(i));
return result;
}]
>,
InterfaceMethod<
/*desc=*/"Return the `i`-th output operand.",
/*retTy=*/"OpOperand *",
/*methodName=*/"getOutputOperand",
/*args=*/(ins "int64_t":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i >= 0 && i < $_op.getNumOutputs());
auto [start, end] = $_op.getOutputsPositionRange();
return &$_op->getOpOperand(start + i);
}]
>,
InterfaceMethod<
/*desc=*/"Set the `i`-th output operand.",
/*retTy=*/"void",
/*methodName=*/"setOutputOperand",
/*args=*/(ins "int64_t":$i, "Value":$value),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i >= 0 && i < $_op.getNumOutputs());
auto [start, end] = $_op.getOutputsPositionRange();
$_op->setOperand(start + i, value);
}]
>,
InterfaceMethod<
/*desc=*/"Return the number of inputs.",
/*retTy=*/"int64_t",
/*methodName=*/"getNumInputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getNumOperands() - $_op.getNumOutputs();
}]
>,
InterfaceMethod<
/*desc=*/"Return the input operands.",
/*retTy=*/"OpOperandVector",
/*methodName=*/"getInputOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto [start, end] = $_op.getOutputsPositionRange();
int64_t numOutputs = end - start;
int64_t numOperands = $_op.getNumOperands();
OpOperandVector result;
result.reserve(numOperands - numOutputs);
for (int i = 0; i < start; ++i)
result.push_back(&$_op->getOpOperand(i));
for (int i = end; i < numOperands; ++i)
result.push_back(&$_op->getOpOperand(end + i));
return result;
}]
>,
InterfaceMethod<
/*desc=*/[{ Return the `i`-th input operand. }],
/*retTy=*/"OpOperand *",
/*methodName=*/"getInputOperand",
/*args=*/(ins "int64_t":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i >= 0 && i < getNumInputs());
auto [start, end] = $_op.getOutputsPositionRange();
return &$_op->getOpOperand(i < start ? i : i + end - start) ;
}]
>,
//===------------------------------------------------------------------===//
// Input and Output arguments handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/"Return true if `opOperand` is an input.",
/*retTy=*/"bool",
/*methodName=*/"isInput",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto [start, end] = $_op.getOutputsPositionRange();
auto operandNumber = opOperand->getOperandNumber();
return operandNumber < start || operandNumber >= end;
}]
>,
InterfaceMethod<
/*desc=*/"Return true if `opOperand` is an output.",
/*retTy=*/"bool",
/*methodName=*/"isOutput",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto [start, end] = $_op.getOutputsPositionRange();
auto operandNumber = opOperand->getOperandNumber();
return operandNumber >= start && operandNumber < end;
}]
>,
InterfaceMethod<
/*desc=*/"Return true if the `opOperand` is a scalar value.",
/*retTy=*/"bool",
/*methodName=*/"isScalar",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
return !opOperand->get().getType().template isa<ShapedType>();
}]
>,
InterfaceMethod<
/*desc=*/"Return the result tied to `opOperand`.",
/*retTy=*/"OpResult",
/*methodName=*/"getTiedOpResult",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
auto [start, end] = $_op.getOutputsPositionRange();
int64_t resultIndex = opOperand->getOperandNumber() - start;
assert(resultIndex >= 0 &&
resultIndex < $_op->getNumResults() );
return $_op->getResult(resultIndex);
}]
>,
//===------------------------------------------------------------------===//
// Other interface methods.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/"Return whether the op has only MemRef input and outputs.",
/*retTy=*/"bool",
/*methodName=*/"hasBufferSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op->getNumResults() == 0 &&
llvm::all_of($_op->getOpOperands(),
[&](OpOperand &opOperand) {
return isScalar(&opOperand) ||
opOperand.get().getType().template isa<MemRefType>();
});
}]
>,
InterfaceMethod<
/*desc=*/"Return whether the op has only RankedTensor input and outputs.",
/*retTy=*/"bool",
/*methodName=*/"hasTensorSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::all_of($_op->getOpOperands(),
[&](OpOperand &opOperand) {
return isScalar(&opOperand) ||
opOperand.get().getType().template isa<RankedTensorType>();
});
}]
>,
//===------------------------------------------------------------------===//
// Other static interface methods.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
Clone the current operation with the given location and operands. This
is used to abstract away the optional underlying region creation. This
does not change the balance between input, output_buffer and
init_tensors operands.
}],
/*retTy=*/"Operation *",
/*methodName=*/"clone",
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands),
[{
BlockAndValueMapping bvm;
OperationState state(
loc, ConcreteOp::getOperationName(), operands, resultTypes,
$_op->getAttrs());
for (Region &r : $_op->getRegions())
r.cloneInto(state.addRegion(), bvm);
return b.create(state);
}]
>,
InterfaceMethod<
/*desc=*/[{
Clone the current operation with the given location, operands
and BlockAndValueMapping but leave the regions empty. This is
used to abstract away the optional underlying region creation.
This does not change the balance between input, output_buffer
and init_tensors operands.
}],
/*retTy=*/"Operation *",
/*methodName=*/"cloneWithoutRegions",
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands),
[{
OperationState state(
loc, ConcreteOp::getOperationName(), operands, resultTypes,
$_op->getAttrs());
for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt)
state.addRegion();
return b.create(state);
}]
>
];
let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }];
let verifyWithRegions = 1;
}
#endif // MLIR_DESTINATIONSTYLEOPINTERFACE

View File

@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
MLIRArithDialect
MLIRArithUtils
MLIRBufferizationDialect
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRInferTypeOpInterface
MLIRIR

View File

@ -462,14 +462,6 @@ LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//
OpOperandVector::operator SmallVector<Value>() {
SmallVector<Value> result;
result.reserve(this->size());
llvm::transform(*this, std::back_inserter(result),
[](OpOperand *opOperand) { return opOperand->get(); });
return result;
}
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
@ -770,55 +762,3 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
return success();
}
LogicalResult
mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) {
DestinationStyleOpInterface dstStyleOp =
cast<DestinationStyleOpInterface>(op);
SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
for (OpOperand *operand : dstStyleOp.getOutputOperands()) {
Type type = operand->get().getType();
if (type.isa<MemRefType>())
outputBufferOperands.push_back(operand);
if (type.isa<RankedTensorType>())
outputTensorOperands.push_back(operand);
}
// Expect at least one output operand.
// This means an op that constructs a tensor out of indices cannot be a
// LinalgOp at the moment. For now this will have to be a special op until we
// have output shape operands that are not tensors.
int64_t numInputs = dstStyleOp.getNumInputs();
int64_t numOutputs = dstStyleOp.getNumOutputs();
if (numOutputs == 0)
return op->emitOpError("expected at least one output operand");
if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
return failure();
// Verify the number of results matches the number of output tensors.
if (op->getNumResults() != outputTensorOperands.size())
return op->emitOpError("expected the number of results (")
<< op->getNumResults()
<< ") to be equal to the number of output tensors ("
<< outputTensorOperands.size() << ")";
// Simplifying assumption: either full tensor or full buffer mode.
// This allows simpler verification of output operands vs result types
// without premature tracking of which operand is what in mixed-mode.
// TODO: relax when mixed-mode needs to pass verification.
if (!outputBufferOperands.empty() && !outputTensorOperands.empty())
return op->emitOpError(
"expected output operands to all have tensor type or "
"all have buffer type");
for (OpOperand *opOperand : outputTensorOperands) {
OpResult result = dstStyleOp.getTiedOpResult(opOperand);
if (result.getType() != opOperand->get().getType())
return op->emitOpError("expected type of operand #")
<< opOperand->getOperandNumber() << " ("
<< opOperand->get().getType() << ")"
<< " to match type of corresponding result (" << result.getType()
<< ")";
}
return success();
}

View File

@ -13,6 +13,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
using namespace mlir;
using namespace linalg;
@ -115,7 +116,7 @@ struct LinalgOpInterface
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto genericOp = cast<linalg::DestinationStyleOpInterface>(op);
auto genericOp = cast<DestinationStyleOpInterface>(op);
// The i-th "out" tensor may alias with the i-th OpResult.
if (genericOp.isOutput(&opOperand))

View File

@ -43,6 +43,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRComplexDialect
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRFuncDialect
MLIRFuncToLLVM

View File

@ -5,6 +5,7 @@ set(LLVM_OPTIONAL_SOURCES
CopyOpInterface.cpp
DataLayoutInterfaces.cpp
DerivedAttributeOpInterface.cpp
DestinationStyleOpInterface.cpp
InferIntRangeInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
@ -38,6 +39,7 @@ add_mlir_interface_library(ControlFlowInterfaces)
add_mlir_interface_library(CopyOpInterface)
add_mlir_interface_library(DataLayoutInterfaces)
add_mlir_interface_library(DerivedAttributeOpInterface)
add_mlir_interface_library(DestinationStyleOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_interface_library(LoopLikeInterface)

View File

@ -0,0 +1,71 @@
//===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
using namespace mlir;
namespace mlir {
#include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
} // namespace mlir
OpOperandVector::operator SmallVector<Value>() {
SmallVector<Value> result;
result.reserve(this->size());
llvm::transform(*this, std::back_inserter(result),
[](OpOperand *opOperand) { return opOperand->get(); });
return result;
}
LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
DestinationStyleOpInterface dstStyleOp =
cast<DestinationStyleOpInterface>(op);
SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
for (OpOperand *operand : dstStyleOp.getOutputOperands()) {
Type type = operand->get().getType();
if (type.isa<MemRefType>())
outputBufferOperands.push_back(operand);
if (type.isa<RankedTensorType>())
outputTensorOperands.push_back(operand);
}
// Expect at least one output operand.
int64_t numInputs = dstStyleOp.getNumInputs();
int64_t numOutputs = dstStyleOp.getNumOutputs();
if (numOutputs == 0)
return op->emitOpError("expected at least one output operand");
if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
return failure();
// Verify the number of results matches the number of output tensors.
if (op->getNumResults() != outputTensorOperands.size())
return op->emitOpError("expected the number of results (")
<< op->getNumResults()
<< ") to be equal to the number of output tensors ("
<< outputTensorOperands.size() << ")";
// Simplifying assumption: either full tensor or full buffer mode.
// This allows simpler verification of output operands vs result types
// without premature tracking of which operand is what in mixed-mode.
// TODO: relax when mixed-mode needs to pass verification.
if (!outputBufferOperands.empty() && !outputTensorOperands.empty())
return op->emitOpError(
"expected output operands to all have tensor type or "
"all have buffer type");
for (OpOperand *opOperand : outputTensorOperands) {
OpResult result = dstStyleOp.getTiedOpResult(opOperand);
if (result.getType() != opOperand->get().getType())
return op->emitOpError("expected type of operand #")
<< opOperand->getOperandNumber() << " ("
<< opOperand->get().getType() << ")"
<< " to match type of corresponding result (" << result.getType()
<< ")";
}
return success();
}

View File

@ -54,6 +54,7 @@ add_mlir_library(MLIRTestDialect
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRDerivedAttributeOpInterface
MLIRDestinationStyleOpInterface
MLIRDialect
MLIRDLTIDialect
MLIRFuncDialect

View File

@ -23,6 +23,7 @@ include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"

View File

@ -995,6 +995,13 @@ td_library(
deps = [":OpBaseTdFiles"],
)
td_library(
name = "DestinationStyleOpInterfaceTdFiles",
srcs = ["include/mlir/Interfaces/DestinationStyleOpInterface.td"],
includes = ["include"],
deps = [":OpBaseTdFiles"],
)
td_library(
name = "InferIntRangeInterfaceTdFiles",
srcs = ["include/mlir/Interfaces/InferIntRangeInterface.td"],
@ -5321,6 +5328,36 @@ cc_library(
],
)
gentbl_cc_library(
name = "DestinationStyleOpInterfaceIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-op-interface-decls"],
"include/mlir/Interfaces/DestinationStyleOpInterface.h.inc",
),
(
["-gen-op-interface-defs"],
"include/mlir/Interfaces/DestinationStyleOpInterface.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Interfaces/DestinationStyleOpInterface.td",
deps = [":DestinationStyleOpInterfaceTdFiles"],
)
cc_library(
name = "DestinationStyleOpInterface",
srcs = ["lib/Interfaces/DestinationStyleOpInterface.cpp"],
hdrs = ["include/mlir/Interfaces/DestinationStyleOpInterface.h"],
includes = ["include"],
deps = [
":DestinationStyleOpInterfaceIncGen",
":IR",
"//llvm:Support",
],
)
gentbl_cc_library(
name = "InferIntRangeInterfaceIncGen",
strip_include_prefix = "include",
@ -7437,6 +7474,7 @@ td_library(
includes = ["include"],
deps = [
":ControlFlowInterfacesTdFiles",
":DestinationStyleOpInterfaceTdFiles",
":DialectUtilsTdFiles",
":InferTypeOpInterfaceTdFiles",
":LoopLikeInterfaceTdFiles",
@ -7571,6 +7609,7 @@ td_library(
includes = ["include"],
deps = [
":CopyOpInterfaceTdFiles",
":DestinationStyleOpInterface",
":LinalgOpsTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
@ -7768,6 +7807,7 @@ cc_library(
":ComplexDialect",
":ControlFlowInterfaces",
":CopyOpInterface",
":DestinationStyleOpInterface",
":DialectUtils",
":FuncDialect",
":IR",
@ -7925,6 +7965,7 @@ cc_library(
":BufferizationTransforms",
":ComplexDialect",
":ControlFlowDialect",
":DestinationStyleOpInterface",
":DialectUtils",
":FuncDialect",
":FuncTransforms",

View File

@ -94,6 +94,7 @@ td_library(
"//mlir:CopyOpInterfaceTdFiles",
"//mlir:DLTIDialectTdFiles",
"//mlir:DataLayoutInterfacesTdFiles",
"//mlir:DestinationStyleOpInterfaceTdFiles",
"//mlir:InferIntRangeInterfaceTdFiles",
"//mlir:InferTypeOpInterfaceTdFiles",
"//mlir:LinalgStructuredOpsTdFiles",
@ -325,6 +326,7 @@ cc_library(
"//mlir:DLTIDialect",
"//mlir:DataLayoutInterfaces",
"//mlir:DerivedAttributeOpInterface",
"//mlir:DestinationStyleOpInterface",
"//mlir:Dialect",
"//mlir:FuncDialect",
"//mlir:FuncTransforms",