diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h index 28c75fcfa653..e231bddfcc41 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -70,6 +70,10 @@ AffineMap extractOrIdentityMap(Optional maybeMap, unsigned rank, SmallVector concat(ArrayRef a, ArrayRef b); +/// Check if `permutation` is a permutation of the range +/// `[0, permutation.size())`. +bool isPermutation(ArrayRef permutation); + } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 4b83de12a410..9c2246e74dca 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -360,6 +360,78 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ } +//===----------------------------------------------------------------------===// +// Transpose op. +//===----------------------------------------------------------------------===// + +def TransposeOp : LinalgStructuredBase_Op<"transpose", [ + DeclareOpInterfaceMethods, + SameVariadicOperandSize, + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "Transpose operator"; + let description = [{ + Permutes the dimensions of `input` according to the given `permutation`. + `dim(result, i) = dim(input, permutation[i])` + + This op actually moves data, unlike `memref.transpose` which is a metadata + operation only that produces a transposed "view". + + Example: + ``` + %transpose = linalg.transpose + ins(%input:tensor<16x64xf32>) + outs(%init:tensor<64x16xf32>) + permutation = [1, 0] + ``` + }]; + + let arguments = (ins + // Input arg + TensorOrMemref:$input, + // Output arg + TensorOrMemref:$init, + + DenseI64ArrayAttr:$permutation + ); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "Value":$input, "Value":$init, + "DenseI64ArrayAttr":$permutation, CArg<"ArrayRef", + "{}">:$attributes)>, + OpBuilder<(ins "Value":$input, "Value":$init, + "ArrayRef":$permutation, CArg<"ArrayRef", + "{}">:$attributes)>, + ]; + + let extraClassDeclaration = structuredOpsBaseDecls # [{ + // Declare functions necessary for LinalgStructuredInterface. + SmallVector getIteratorTypesArray(); + ArrayAttr getIndexingMaps(); + std::string getLibraryCallName() { + return "op_has_no_registered_library_name"; + } + + // Implement functions necessary for DestinationStyleOpInterface. + std::pair getOutputsPositionRange() { + int64_t getNumOperands = this->getNumOperands(); + return {getNumOperands - 1, getNumOperands}; + } + + static std::function)> + getRegionBuilder(); + + static void createRegion(::mlir::OpBuilder &opBuilder, + ::mlir::OperationState & odsState); + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // Named Linalg ops, implemented as a declarative configurations of generic ops. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 305b859ac13d..6a10d4332e7e 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -41,10 +41,6 @@ bool hasOnlyScalarElementwiseOp(Region &r); /// Check if a LinalgOp is an element-wise operation. bool isElementwise(LinalgOp op); -/// Check if `permutation` is a permutation of the range -/// `[0, permutation.size())`. -bool isPermutation(ArrayRef permutation); - /// Check if iterator type has "parallel" semantics. bool isParallelIterator(StringRef iteratorType); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 2fcd21cb59f9..82e5024cf58b 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1601,6 +1601,142 @@ LogicalResult ReduceOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +std::function)> +TransposeOp::getRegionBuilder() { + return [](mlir::ImplicitLocOpBuilder &b, mlir::Block &block, + mlir::ArrayRef) { + b.create(block.getArguments().back()); + }; +} + +void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder, + ::mlir::OperationState &odsState) { + Region *region = odsState.addRegion(); + + SmallVector argTypes; + SmallVector argLocs; + for (auto t : odsState.operands) { + argTypes.push_back(getElementTypeOrSelf(t)); + argLocs.push_back(opBuilder.getUnknownLoc()); + } + + // RAII. + OpBuilder::InsertionGuard guard(opBuilder); + Block *body = + opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs); + + ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); + getRegionBuilder()(b, *body, odsState.attributes.getAttrs()); +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, Value input, + Value init, DenseI64ArrayAttr permutation, + ArrayRef attributes) { + odsState.addOperands(input); + odsState.addOperands(init); + odsState.addAttribute(getPermutationAttrName(odsState.name), permutation); + odsState.addAttributes(attributes); + odsState.addTypes(init.getType()); + + createRegion(odsBuilder, odsState); +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, Value input, + Value init, ArrayRef permutation, + ArrayRef attributes) { + build(odsBuilder, odsState, input, init, + odsBuilder.getDenseI64ArrayAttr(permutation), attributes); +} + +ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { + if (failed(parseDstStyleOp( + parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { + return parseDenseI64ArrayAttr(parser, attributes, "permutation"); + }))) + return failure(); + + OpBuilder opBuilder(parser.getContext()); + createRegion(opBuilder, result); + return success(); +} + +void TransposeOp::getAsmResultNames( + function_ref setNameFn) { + if (!getResults().empty()) + setNameFn(getResults().front(), "transposed"); +} + +void TransposeOp::print(OpAsmPrinter &p) { + printCommonStructuredOpParts(p, SmallVector(getInputOperands()), + SmallVector(getOutputOperands())); + printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation()); + p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); +} + +LogicalResult TransposeOp::verify() { + ArrayRef permutationRef = getPermutation(); + + if (!isPermutation(permutationRef)) + return emitOpError("permutation is not valid"); + + auto inputType = getInput().getType(); + auto initType = getInit().getType(); + + int64_t rank = inputType.getRank(); + + if (rank != initType.getRank()) + return emitOpError() << "input rank " << rank + << " does not match init rank " << initType.getRank(); + + if (rank != static_cast(permutationRef.size())) + return emitOpError() << "size of permutation " << permutationRef.size() + << " does not match the argument rank " << rank; + + auto inputDims = inputType.getShape(); + auto initDims = initType.getShape(); + + for (int64_t i = 0; i < rank; ++i) { + int64_t inputDim = inputDims[permutationRef[i]]; + int64_t initDim = initDims[i]; + + if (inputDim != initDim) { + return emitOpError() << "dim(result, " << i << ") = " << initDim + << " doesn't match dim(input, permutation[" << i + << "]) = " << inputDim; + } + } + + return success(); +} + +SmallVector TransposeOp::getIteratorTypesArray() { + int64_t rank = getInit().getType().getRank(); + return SmallVector(rank, getParallelIteratorTypeName()); +} + +ArrayAttr TransposeOp::getIndexingMaps() { + Builder builder(getContext()); + int64_t rank = getInit().getType().getRank(); + return builder.getAffineMapArrayAttr( + {builder.getMultiDimIdentityMap(rank), + AffineMap::getPermutationMap( + llvm::to_vector_of(getPermutation()), getContext())}); +} + +void TransposeOp::getEffects( + SmallVectorImpl> + &effects) { + getGenericEffectsImpl(effects, getOperation()->getResults(), + getInputOperands(), getOutputOperands()); +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// @@ -1710,6 +1846,19 @@ SmallVector mlir::linalg::concat(ArrayRef a, return llvm::to_vector<4>(concatRanges); } +bool mlir::linalg::isPermutation(ArrayRef permutation) { + // Count the number of appearances for all indices. + SmallVector indexCounts(permutation.size(), 0); + for (auto index : permutation) { + // Exit if the index is out-of-range. + if (index < 0 || index >= static_cast(permutation.size())) + return false; + ++indexCounts[index]; + } + // Return true if all indices appear once. + return count(indexCounts, 1) == static_cast(permutation.size()); +} + static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { if (auto memref = t.dyn_cast()) { ss << "view"; diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index aba2d5f5cd49..af5a2012429b 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -186,19 +186,6 @@ bool isElementwise(LinalgOp op) { return hasOnlyScalarElementwiseOp(op->getRegion(0)); } -bool isPermutation(ArrayRef permutation) { - // Count the number of appearances for all indices. - SmallVector indexCounts(permutation.size(), 0); - for (auto index : permutation) { - // Exit if the index is out-of-range. - if (index < 0 || index >= static_cast(permutation.size())) - return false; - indexCounts[index]++; - } - // Return true if all indices appear once. - return count(indexCounts, 1) == static_cast(permutation.size()); -} - bool isParallelIterator(StringRef iteratorType) { return iteratorType == getParallelIteratorTypeName(); } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 00352c43bb07..e6ab837141f1 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -624,3 +624,52 @@ func.func @reduce_different_output_shapes(%input1: tensor<16x32x64xf32>, } func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<17x64xf32> } + +// ----- + +func.func @transpose_invalid_permutation(%input: tensor<16x32x64xf32>, + %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { + // expected-error @+1 {{'linalg.transpose' op permutation is not valid}} + %transpose = linalg.transpose + ins(%input:tensor<16x32x64xf32>) + outs(%init:tensor<32x64x16xf32>) + permutation = [1, 1, 2] + func.return %transpose : tensor<32x64x16xf32> +} + +// ----- + +func.func @transpose_permutated_dims_mismatch(%input: tensor<16x32x64xf32>, + %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { + // expected-error @+1 {{'linalg.transpose' op dim(result, 0) = 32 doesn't match dim(input, permutation[0]) = 16}} + %transpose = linalg.transpose + ins(%input:tensor<16x32x64xf32>) + outs(%init:tensor<32x64x16xf32>) + permutation = [0, 1, 2] + func.return %transpose : tensor<32x64x16xf32> +} + +// ----- + +func.func @transpose_rank_permutation_size_mismatch( + %input: tensor<16x32x64xf32>, + %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { + // expected-error @+1 {{'linalg.transpose' op size of permutation 2 does not match the argument rank 3}} + %transpose = linalg.transpose + ins(%input:tensor<16x32x64xf32>) + outs(%init:tensor<32x64x16xf32>) + permutation = [1, 0] + func.return %transpose : tensor<32x64x16xf32> +} + +// ----- + +func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>, + %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { + // expected-error @+1 {{'linalg.transpose' op input rank 2 does not match init rank 3}} + %transpose = linalg.transpose + ins(%input:tensor<16x32xf32>) + outs(%init:tensor<32x64x16xf32>) + permutation = [1, 0, 2] + func.return %transpose : tensor<32x64x16xf32> +} diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index f751ddff7df0..4bea3f6d3837 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -67,11 +67,11 @@ func.func @fill_view(%arg0: memref>, %arg1: f32) // ----- -func.func @transpose(%arg0: memref>) { +func.func @memref_transpose(%arg0: memref>) { %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref> to memref> return } -// CHECK-LABEL: func @transpose +// CHECK-LABEL: func @memref_transpose // CHECK: memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : // CHECK-SAME: memref> to memref> @@ -457,3 +457,27 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>, } // CHECK-LABEL: func @variadic_reduce_memref // CHECK: linalg.reduce + +// ----- + +func.func @transpose(%input: tensor<16x32x64xf32>, + %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { + %transpose = linalg.transpose + ins(%input:tensor<16x32x64xf32>) + outs(%init:tensor<32x64x16xf32>) + permutation = [1, 2, 0] + func.return %transpose : tensor<32x64x16xf32> +} +// CHECK-LABEL: func @transpose + +// ----- + +func.func @transpose_memref(%input: memref<16x32x64xf32>, + %init: memref<32x64x16xf32>) { + linalg.transpose + ins(%input:memref<16x32x64xf32>) + outs(%init:memref<32x64x16xf32>) + permutation = [1, 2, 0] + func.return +} +// CHECK-LABEL: func @transpose_memref