[mlir] Add TransposeOp to Linalg structured ops.

RFC: https://discourse.llvm.org/t/rfc-primitive-ops-add-mapop-reductionop-transposeop-broadcastop-to-linalg/64184
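
For illustration, a minimal usage sketch (adapted from the op description and
the round-trip tests added in this patch; SSA value names are illustrative, and
the `linalg.generic` form is a hand-written equivalent based on the indexing
maps and iterator types defined below, not the output of a conversion pass):

```
%transposed = linalg.transpose
    ins(%input : tensor<16x64xf32>)
    outs(%init : tensor<64x16xf32>)
    permutation = [1, 0]

// Roughly the same computation as a linalg.generic: the input is read through
// the identity map, the init is written through the permutation map, and all
// loops are parallel.
%generic = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d1, d0)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%input : tensor<16x64xf32>)
    outs(%init : tensor<64x16xf32>) {
^bb0(%in: f32, %out: f32):
  linalg.yield %in : f32
} -> tensor<64x16xf32>
```

On memrefs the op writes in place into `outs` and has no results (see the
memref round-trip test).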

Differential Revision: https://reviews.llvm.org/D135854
Oleg Shyshkov 2022-10-19 11:42:25 +02:00
parent 1625224fbb
commit d261aa88f8
7 changed files with 300 additions and 19 deletions

View File

@ -70,6 +70,10 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
ArrayRef<AffineExpr> b);
/// Check if `permutation` is a permutation of the range
/// `[0, permutation.size())`.
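/// Example: isPermutation({1, 2, 0}) returns true; isPermutation({1, 1, 2})
/// and isPermutation({0, 3}) return false (duplicate / out-of-range index).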
bool isPermutation(ArrayRef<int64_t> permutation);
} // namespace linalg
} // namespace mlir

View File

@ -360,6 +360,78 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
}
//===----------------------------------------------------------------------===//
// Transpose op.
//===----------------------------------------------------------------------===//
def TransposeOp : LinalgStructuredBase_Op<"transpose", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
SameVariadicOperandSize,
SingleBlockImplicitTerminator<"YieldOp">]> {
let summary = "Transpose operator";
let description = [{
Permutes the dimensions of `input` according to the given `permutation`:
`dim(result, i) = dim(input, permutation[i])`.

This op actually moves data, unlike `memref.transpose`, which is a
metadata-only operation that produces a transposed "view".

Example:
```
%transpose = linalg.transpose
ins(%input:tensor<16x64xf32>)
outs(%init:tensor<64x16xf32>)
permutation = [1, 0]
```
}];
let arguments = (ins
// Input arg
TensorOrMemref:$input,
// Output arg
TensorOrMemref:$init,
DenseI64ArrayAttr:$permutation
);
let results = (outs Variadic<AnyTensor>:$result);
let regions = (region SizedRegion<1>:$region);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Value":$input, "Value":$init,
"DenseI64ArrayAttr":$permutation, CArg<"ArrayRef<NamedAttribute>",
"{}">:$attributes)>,
OpBuilder<(ins "Value":$input, "Value":$init,
"ArrayRef<int64_t>":$permutation, CArg<"ArrayRef<NamedAttribute>",
"{}">:$attributes)>,
];
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<StringRef> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
}
// Implement functions necessary for DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getOutputsPositionRange() {
// The single init operand is always last.
int64_t numOperands = this->getNumOperands();
return {numOperands - 1, numOperands};
}
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder();
static void createRegion(::mlir::OpBuilder &opBuilder,
::mlir::OperationState & odsState);
}];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as declarative configurations of generic ops.
//===----------------------------------------------------------------------===//

View File

@ -41,10 +41,6 @@ bool hasOnlyScalarElementwiseOp(Region &r);
/// Check if a LinalgOp is an element-wise operation.
bool isElementwise(LinalgOp op);
/// Check if `permutation` is a permutation of the range
/// `[0, permutation.size())`.
bool isPermutation(ArrayRef<int64_t> permutation);
/// Check if iterator type has "parallel" semantics.
bool isParallelIterator(StringRef iteratorType);

View File

@ -1601,6 +1601,142 @@ LogicalResult ReduceOp::verify() {
return success();
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
TransposeOp::getRegionBuilder() {
return [](mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
mlir::ArrayRef<mlir::NamedAttribute>) {
// Transpose only moves data: yield the input element (the first block
// argument); the init element is just the destination.
b.create<linalg::YieldOp>(block.getArgument(0));
};
}
void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder,
::mlir::OperationState &odsState) {
Region *region = odsState.addRegion();
SmallVector<Type> argTypes;
SmallVector<Location> argLocs;
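// The payload block gets one scalar argument per operand, in operand order:
// the input element first, then the init element.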
for (auto t : odsState.operands) {
argTypes.push_back(getElementTypeOrSelf(t));
argLocs.push_back(opBuilder.getUnknownLoc());
}
// RAII guard: restores opBuilder's insertion point when leaving this scope.
OpBuilder::InsertionGuard guard(opBuilder);
Block *body =
opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs);
ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
getRegionBuilder()(b, *body, odsState.attributes.getAttrs());
}
void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, Value input,
Value init, DenseI64ArrayAttr permutation,
ArrayRef<NamedAttribute> attributes) {
odsState.addOperands(input);
odsState.addOperands(init);
odsState.addAttribute(getPermutationAttrName(odsState.name), permutation);
odsState.addAttributes(attributes);
// Only the tensor form has a result; the memref form writes into `init`.
if (init.getType().isa<TensorType>())
odsState.addTypes(init.getType());
createRegion(odsBuilder, odsState);
}
void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, Value input,
Value init, ArrayRef<int64_t> permutation,
ArrayRef<NamedAttribute> attributes) {
build(odsBuilder, odsState, input, init,
odsBuilder.getDenseI64ArrayAttr(permutation), attributes);
}
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
if (failed(parseDstStyleOp(
parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
return parseDenseI64ArrayAttr(parser, attributes, "permutation");
})))
return failure();
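// The payload region is not part of the custom assembly; rebuild the implicit
// body (yielding the input element) after parsing ins/outs and `permutation`.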
OpBuilder opBuilder(parser.getContext());
createRegion(opBuilder, result);
return success();
}
void TransposeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
if (!getResults().empty())
setNameFn(getResults().front(), "transposed");
}
void TransposeOp::print(OpAsmPrinter &p) {
printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
SmallVector<Value>(getOutputOperands()));
printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
}
LogicalResult TransposeOp::verify() {
ArrayRef<int64_t> permutationRef = getPermutation();
if (!isPermutation(permutationRef))
return emitOpError("permutation is not valid");
auto inputType = getInput().getType();
auto initType = getInit().getType();
int64_t rank = inputType.getRank();
if (rank != initType.getRank())
return emitOpError() << "input rank " << rank
<< " does not match init rank " << initType.getRank();
if (rank != static_cast<int64_t>(permutationRef.size()))
return emitOpError() << "size of permutation " << permutationRef.size()
<< " does not match the argument rank " << rank;
auto inputDims = inputType.getShape();
auto initDims = initType.getShape();
for (int64_t i = 0; i < rank; ++i) {
int64_t inputDim = inputDims[permutationRef[i]];
int64_t initDim = initDims[i];
if (inputDim != initDim) {
return emitOpError() << "dim(result, " << i << ") = " << initDim
<< " doesn't match dim(input, permutation[" << i
<< "]) = " << inputDim;
}
}
return success();
}
SmallVector<StringRef> TransposeOp::getIteratorTypesArray() {
int64_t rank = getInit().getType().getRank();
return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
}
ArrayAttr TransposeOp::getIndexingMaps() {
Builder builder(getContext());
int64_t rank = getInit().getType().getRank();
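// The iteration domain follows the input: the input operand uses the identity
// map and the init/result is addressed through the permutation map, so that
// dim(result, i) == dim(input, permutation[i]).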
return builder.getAffineMapArrayAttr(
{builder.getMultiDimIdentityMap(rank),
AffineMap::getPermutationMap(
llvm::to_vector_of<unsigned>(getPermutation()), getContext())});
}
void TransposeOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getGenericEffectsImpl(effects, getOperation()->getResults(),
getInputOperands(), getOutputOperands());
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
@ -1710,6 +1846,19 @@ SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
return llvm::to_vector<4>(concatRanges);
}
bool mlir::linalg::isPermutation(ArrayRef<int64_t> permutation) {
// Count the number of appearances for all indices.
SmallVector<int64_t> indexCounts(permutation.size(), 0);
for (auto index : permutation) {
// Exit if the index is out-of-range.
if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
return false;
++indexCounts[index];
}
// Return true if all indices appear once.
return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
}
static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
if (auto memref = t.dyn_cast<MemRefType>()) {
ss << "view";

View File

@ -186,19 +186,6 @@ bool isElementwise(LinalgOp op) {
return hasOnlyScalarElementwiseOp(op->getRegion(0));
}
bool isPermutation(ArrayRef<int64_t> permutation) {
// Count the number of appearances for all indices.
SmallVector<int64_t> indexCounts(permutation.size(), 0);
for (auto index : permutation) {
// Exit if the index is out-of-range.
if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
return false;
indexCounts[index]++;
}
// Return true if all indices appear once.
return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
}
bool isParallelIterator(StringRef iteratorType) {
return iteratorType == getParallelIteratorTypeName();
}

View File

@ -624,3 +624,52 @@ func.func @reduce_different_output_shapes(%input1: tensor<16x32x64xf32>,
}
func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<17x64xf32>
}
// -----
func.func @transpose_invalid_permutation(%input: tensor<16x32x64xf32>,
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
// expected-error @+1 {{'linalg.transpose' op permutation is not valid}}
%transpose = linalg.transpose
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<32x64x16xf32>)
permutation = [1, 1, 2]
func.return %transpose : tensor<32x64x16xf32>
}
// -----
func.func @transpose_permutated_dims_mismatch(%input: tensor<16x32x64xf32>,
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
// expected-error @+1 {{'linalg.transpose' op dim(result, 0) = 32 doesn't match dim(input, permutation[0]) = 16}}
%transpose = linalg.transpose
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<32x64x16xf32>)
permutation = [0, 1, 2]
func.return %transpose : tensor<32x64x16xf32>
}
// -----
func.func @transpose_rank_permutation_size_mismatch(
%input: tensor<16x32x64xf32>,
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
// expected-error @+1 {{'linalg.transpose' op size of permutation 2 does not match the argument rank 3}}
%transpose = linalg.transpose
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<32x64x16xf32>)
permutation = [1, 0]
func.return %transpose : tensor<32x64x16xf32>
}
// -----
func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
// expected-error @+1 {{'linalg.transpose' op input rank 2 does not match init rank 3}}
%transpose = linalg.transpose
ins(%input:tensor<16x32xf32>)
outs(%init:tensor<32x64x16xf32>)
permutation = [1, 0, 2]
func.return %transpose : tensor<32x64x16xf32>
}

View File

@ -67,11 +67,11 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
// -----
func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
func.func @memref_transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
%0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
return
}
// CHECK-LABEL: func @transpose
// CHECK-LABEL: func @memref_transpose
// CHECK: memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
// CHECK-SAME: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
@ -457,3 +457,27 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
}
// CHECK-LABEL: func @variadic_reduce_memref
// CHECK: linalg.reduce
// -----
func.func @transpose(%input: tensor<16x32x64xf32>,
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
%transpose = linalg.transpose
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<32x64x16xf32>)
permutation = [1, 2, 0]
func.return %transpose : tensor<32x64x16xf32>
}
// CHECK-LABEL: func @transpose
// -----
func.func @transpose_memref(%input: memref<16x32x64xf32>,
%init: memref<32x64x16xf32>) {
linalg.transpose
ins(%input:memref<16x32x64xf32>)
outs(%init:memref<32x64x16xf32>)
permutation = [1, 2, 0]
func.return
}
// CHECK-LABEL: func @transpose_memref