mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-30 17:21:10 +00:00
[mlir] Add TransposeOp to Linalg structured ops.
RFC: https://discourse.llvm.org/t/rfc-primitive-ops-add-mapop-reductionop-transposeop-broadcastop-to-linalg/64184
Differential Revision: https://reviews.llvm.org/D135854
This commit is contained in:
parent
1625224fbb
commit
d261aa88f8
@ -70,6 +70,10 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
|
||||
SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
|
||||
ArrayRef<AffineExpr> b);
|
||||
|
||||
/// Return true if `permutation` is a permutation of the range
|
||||
/// `[0, permutation.size())`, i.e. each index in that range occurs exactly once.
|
||||
bool isPermutation(ArrayRef<int64_t> permutation);
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -360,6 +360,78 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transpose op.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// `linalg.transpose`: a destination-style structured op with a single `input`,
// a single `init` (output) operand and a `permutation` attribute. The region
// is built programmatically (see `createRegion`) rather than written by the
// user, and the assembly format and verifier are hand-written in C++.
def TransposeOp : LinalgStructuredBase_Op<"transpose", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    SameVariadicOperandSize,
    SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "Transpose operator";
  let description = [{
    Permutes the dimensions of `input` according to the given `permutation`.
      `dim(result, i) = dim(input, permutation[i])`

    This op actually moves data, unlike `memref.transpose` which is a metadata
    operation only that produces a transposed "view".

    Example:
    ```
      %transpose = linalg.transpose
          ins(%input:tensor<16x64xf32>)
          outs(%init:tensor<64x16xf32>)
          permutation = [1, 0]
    ```
  }];

  let arguments = (ins
    // Input arg
    TensorOrMemref:$input,
    // Output arg
    TensorOrMemref:$init,

    DenseI64ArrayAttr:$permutation
  );
  // Results are tensors only; the memref form of the op has no result.
  let results = (outs Variadic<AnyTensor>:$result);
  let regions = (region SizedRegion<1>:$region);

  // Builders are written by hand because they must also create the region.
  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "Value":$input, "Value":$init,
        "DenseI64ArrayAttr":$permutation, CArg<"ArrayRef<NamedAttribute>",
        "{}">:$attributes)>,
    OpBuilder<(ins "Value":$input, "Value":$init,
        "ArrayRef<int64_t>":$permutation, CArg<"ArrayRef<NamedAttribute>",
        "{}">:$attributes)>,
  ];

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    // Declare functions necessary for LinalgStructuredInterface.
    SmallVector<StringRef> getIteratorTypesArray();
    ArrayAttr getIndexingMaps();
    std::string getLibraryCallName() {
      return "op_has_no_registered_library_name";
    }

    // Implement functions necessary for DestinationStyleOpInterface.
    // The single output (`init`) is always the last operand, so the range is
    // [numOperands - 1, numOperands).
    std::pair<int64_t, int64_t> getOutputsPositionRange() {
      int64_t numOperands = this->getNumOperands();
      return {numOperands - 1, numOperands};
    }

    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
        mlir::ArrayRef<mlir::NamedAttribute>)>
      getRegionBuilder();

    static void createRegion(::mlir::OpBuilder &opBuilder,
                             ::mlir::OperationState & odsState);
  }];

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Named Linalg ops, implemented as a declarative configurations of generic ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -41,10 +41,6 @@ bool hasOnlyScalarElementwiseOp(Region &r);
|
||||
/// Check if a LinalgOp is an element-wise operation.
|
||||
bool isElementwise(LinalgOp op);
|
||||
|
||||
/// Return true if `permutation` is a permutation of the range
|
||||
/// `[0, permutation.size())`, i.e. each index in that range occurs exactly once.
|
||||
bool isPermutation(ArrayRef<int64_t> permutation);
|
||||
|
||||
/// Check if iterator type has "parallel" semantics.
|
||||
bool isParallelIterator(StringRef iteratorType);
|
||||
|
||||
|
@ -1601,6 +1601,142 @@ LogicalResult ReduceOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TransposeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the callback that fills in the body block of a `linalg.transpose`.
///
/// `createRegion` creates one scalar block argument per operand, in operand
/// order, so the first argument is the `input` element and the last one is the
/// `init` element. A transpose only moves data, hence the body must yield the
/// *input* element; yielding the last argument would just write `init` back
/// onto itself.
std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
                   mlir::ArrayRef<mlir::NamedAttribute>)>
TransposeOp::getRegionBuilder() {
  return [](mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
            mlir::ArrayRef<mlir::NamedAttribute>) {
    b.create<linalg::YieldOp>(block.getArguments().front());
  };
}
|
||||
|
||||
/// Adds the op's single region to `odsState` and populates it.
///
/// The region gets one block whose arguments mirror the operands already
/// stored in `odsState` (one argument per operand, typed with the operand's
/// element type). The block contents are produced by `getRegionBuilder()`.
void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder,
                               ::mlir::OperationState &odsState) {
  Region *bodyRegion = odsState.addRegion();

  // Collect the scalar type of every operand; all arguments share an unknown
  // location.
  SmallVector<Type> elementTypes;
  for (Value operand : odsState.operands)
    elementTypes.push_back(getElementTypeOrSelf(operand));
  SmallVector<Location> elementLocs(elementTypes.size(),
                                    opBuilder.getUnknownLoc());

  // The guard puts the caller's insertion point back when it goes out of
  // scope, since `createBlock` moves it into the new block.
  OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder);
  Block *bodyBlock = opBuilder.createBlock(bodyRegion, /*insertPt=*/{},
                                           elementTypes, elementLocs);

  ImplicitLocOpBuilder bodyBuilder(opBuilder.getUnknownLoc(), opBuilder);
  auto populateBody = getRegionBuilder();
  populateBody(bodyBuilder, *bodyBlock, odsState.attributes.getAttrs());
}
|
||||
|
||||
/// Builds a `linalg.transpose` from `input`, `init` and a ready-made
/// `permutation` attribute, then creates the op's body region.
///
/// \param input       operand whose dimensions are permuted.
/// \param init        destination operand (tensor or memref).
/// \param permutation dense i64 array attribute holding the permutation.
/// \param attributes  extra attributes to attach to the op.
void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
                        ::mlir::OperationState &odsState, Value input,
                        Value init, DenseI64ArrayAttr permutation,
                        ArrayRef<NamedAttribute> attributes) {
  odsState.addOperands(input);
  odsState.addOperands(init);
  odsState.addAttribute(getPermutationAttrName(odsState.name), permutation);
  odsState.addAttributes(attributes);

  // The op only has tensor results: in the memref form the data is written
  // into `init` and there is no result, so a result type is added only when
  // `init` is a tensor.
  Type initType = init.getType();
  if (initType.isa<TensorType>())
    odsState.addTypes(initType);

  createRegion(odsBuilder, odsState);
}
|
||||
|
||||
/// Convenience builder taking the permutation as a plain index list; it wraps
/// the list in a `DenseI64ArrayAttr` and forwards to the attribute-based
/// builder.
void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
                        ::mlir::OperationState &odsState, Value input,
                        Value init, ArrayRef<int64_t> permutation,
                        ArrayRef<NamedAttribute> attributes) {
  DenseI64ArrayAttr permutationAttr =
      odsBuilder.getDenseI64ArrayAttr(permutation);
  build(odsBuilder, odsState, input, init, permutationAttr, attributes);
}
|
||||
|
||||
/// Parses the custom assembly form of `linalg.transpose`.
///
/// The common destination-style parts are handled by `parseDstStyleOp`; the
/// op-specific `permutation` attribute is parsed through the callback. The
/// body region is not part of the textual form, so it is created afterwards.
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  auto parsePermutation = [](OpAsmParser &attrParser,
                             NamedAttrList &attrList) {
    return parseDenseI64ArrayAttr(attrParser, attrList, "permutation");
  };
  if (failed(parseDstStyleOp(parser, result, parsePermutation)))
    return failure();

  OpBuilder regionBuilder(parser.getContext());
  createRegion(regionBuilder, result);
  return success();
}
|
||||
|
||||
/// Suggests the name "transposed" for the first result, if the op has one.
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto results = getResults();
  // Nothing to name when the op has no results.
  if (results.empty())
    return;
  setNameFn(results.front(), "transposed");
}
|
||||
|
||||
/// Prints the custom assembly form: the common structured-op parts, then the
/// `permutation` attribute, then any remaining attributes. The permutation is
/// elided from the trailing attribute dictionary because it was printed
/// already.
void TransposeOp::print(OpAsmPrinter &p) {
  SmallVector<Value> inputValues(getInputOperands());
  SmallVector<Value> outputValues(getOutputOperands());
  printCommonStructuredOpParts(p, inputValues, outputValues);

  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
}
|
||||
|
||||
/// Verifies the op. Checks are performed in this order, and the first failure
/// is reported:
///   1. `permutation` is a valid permutation of `[0, permutation.size())`;
///   2. `input` and `init` have the same rank;
///   3. the permutation has exactly one entry per dimension;
///   4. for every `i`, `dim(init, i) == dim(input, permutation[i])`.
LogicalResult TransposeOp::verify() {
  ArrayRef<int64_t> perm = getPermutation();
  if (!isPermutation(perm))
    return emitOpError("permutation is not valid");

  auto inputTy = getInput().getType();
  auto initTy = getInit().getType();

  const int64_t inputRank = inputTy.getRank();
  if (inputRank != initTy.getRank())
    return emitOpError() << "input rank " << inputRank
                         << " does not match init rank " << initTy.getRank();

  if (inputRank != static_cast<int64_t>(perm.size()))
    return emitOpError() << "size of permutation " << perm.size()
                         << " does not match the argument rank " << inputRank;

  auto inputShape = inputTy.getShape();
  auto initShape = initTy.getShape();

  // Indexing `inputShape` with `perm[dim]` is safe here: the permutation was
  // validated above and has `inputRank` entries.
  for (int64_t dim = 0; dim < inputRank; ++dim) {
    const int64_t expectedSize = inputShape[perm[dim]];
    const int64_t actualSize = initShape[dim];
    if (expectedSize == actualSize)
      continue;

    return emitOpError() << "dim(result, " << dim << ") = " << actualSize
                         << " doesn't match dim(input, permutation[" << dim
                         << "]) = " << expectedSize;
  }

  return success();
}
|
||||
|
||||
/// Returns one "parallel" iterator type per dimension of `init`.
SmallVector<StringRef> TransposeOp::getIteratorTypesArray() {
  const int64_t numLoops = getInit().getType().getRank();
  SmallVector<StringRef> iteratorTypes(numLoops,
                                       getParallelIteratorTypeName());
  return iteratorTypes;
}
|
||||
|
||||
/// Returns the indexing maps of the op as an array attribute with two
/// entries: a multi-dimensional identity map for `input`, followed by the
/// permutation map built from `permutation` for `init`.
ArrayAttr TransposeOp::getIndexingMaps() {
  auto *ctx = getContext();
  Builder attrBuilder(ctx);

  const int64_t numDims = getInit().getType().getRank();
  AffineMap inputMap = attrBuilder.getMultiDimIdentityMap(numDims);
  AffineMap initMap = AffineMap::getPermutationMap(
      llvm::to_vector_of<unsigned>(getPermutation()), ctx);

  return attrBuilder.getAffineMapArrayAttr({inputMap, initMap});
}
|
||||
|
||||
/// Appends the op's memory effects to `effects` by delegating to
/// `getGenericEffectsImpl` with the op's results, input operands and output
/// operands.
void TransposeOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  Operation *op = getOperation();
  auto inputOperands = getInputOperands();
  auto outputOperands = getOutputOperands();
  getGenericEffectsImpl(effects, op->getResults(), inputOperands,
                        outputOperands);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// YieldOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1710,6 +1846,19 @@ SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
|
||||
return llvm::to_vector<4>(concatRanges);
|
||||
}
|
||||
|
||||
/// Returns true iff `permutation` contains every index in
/// `[0, permutation.size())` exactly once. An empty list is a permutation.
bool mlir::linalg::isPermutation(ArrayRef<int64_t> permutation) {
  const int64_t numIndices = static_cast<int64_t>(permutation.size());

  // seen[i] becomes true the first time index `i` is encountered.
  SmallVector<bool> seen(permutation.size(), false);
  for (int64_t index : permutation) {
    // Reject indices outside of the valid range.
    if (index < 0 || index >= numIndices)
      return false;
    // Reject repeated indices. With `numIndices` in-range entries and no
    // repetition, every index necessarily appears exactly once.
    if (seen[index])
      return false;
    seen[index] = true;
  }
  return true;
}
|
||||
|
||||
static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
|
||||
if (auto memref = t.dyn_cast<MemRefType>()) {
|
||||
ss << "view";
|
||||
|
@ -186,19 +186,6 @@ bool isElementwise(LinalgOp op) {
|
||||
return hasOnlyScalarElementwiseOp(op->getRegion(0));
|
||||
}
|
||||
|
||||
/// Returns true iff every index in `[0, permutation.size())` occurs exactly
/// once in `permutation`.
bool isPermutation(ArrayRef<int64_t> permutation) {
  const int64_t numIndices = static_cast<int64_t>(permutation.size());

  // occurrences[i] holds how many times index `i` was seen.
  SmallVector<int64_t> occurrences(permutation.size(), 0);
  for (int64_t index : permutation) {
    const bool inRange = index >= 0 && index < numIndices;
    if (!inRange)
      return false;
    occurrences[index] += 1;
  }

  // A permutation has no missing and no repeated index.
  return llvm::all_of(occurrences,
                      [](int64_t timesSeen) { return timesSeen == 1; });
}
|
||||
|
||||
/// Returns true if `iteratorType` equals the parallel iterator type name.
bool isParallelIterator(StringRef iteratorType) {
  const StringRef parallelName = getParallelIteratorTypeName();
  return parallelName == iteratorType;
}
|
||||
|
@ -624,3 +624,52 @@ func.func @reduce_different_output_shapes(%input1: tensor<16x32x64xf32>,
|
||||
}
|
||||
func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<17x64xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: `permutation = [1, 1, 2]` repeats index 1 and omits index 0,
// so it is not a permutation of [0, 3) and the verifier must reject the op.
func.func @transpose_invalid_permutation(%input: tensor<16x32x64xf32>,
|
||||
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
|
||||
// expected-error @+1 {{'linalg.transpose' op permutation is not valid}}
|
||||
%transpose = linalg.transpose
|
||||
ins(%input:tensor<16x32x64xf32>)
|
||||
outs(%init:tensor<32x64x16xf32>)
|
||||
permutation = [1, 1, 2]
|
||||
func.return %transpose : tensor<32x64x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: the identity permutation [0, 1, 2] requires `init` to have
// the same shape as `input`, but dim 0 of `init` is 32 while dim 0 of `input`
// is 16, so the per-dimension shape check must fail.
func.func @transpose_permutated_dims_mismatch(%input: tensor<16x32x64xf32>,
|
||||
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
|
||||
// expected-error @+1 {{'linalg.transpose' op dim(result, 0) = 32 doesn't match dim(input, permutation[0]) = 16}}
|
||||
%transpose = linalg.transpose
|
||||
ins(%input:tensor<16x32x64xf32>)
|
||||
outs(%init:tensor<32x64x16xf32>)
|
||||
permutation = [0, 1, 2]
|
||||
func.return %transpose : tensor<32x64x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: `permutation = [1, 0]` is a valid permutation, but it has
// 2 entries while the operands are rank 3, so the size check must fail.
func.func @transpose_rank_permutation_size_mismatch(
|
||||
%input: tensor<16x32x64xf32>,
|
||||
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
|
||||
// expected-error @+1 {{'linalg.transpose' op size of permutation 2 does not match the argument rank 3}}
|
||||
%transpose = linalg.transpose
|
||||
ins(%input:tensor<16x32x64xf32>)
|
||||
outs(%init:tensor<32x64x16xf32>)
|
||||
permutation = [1, 0]
|
||||
func.return %transpose : tensor<32x64x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: `input` is rank 2 while `init` is rank 3, so the rank check
// must fail (the permutation itself is valid).
func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
|
||||
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
|
||||
// expected-error @+1 {{'linalg.transpose' op input rank 2 does not match init rank 3}}
|
||||
%transpose = linalg.transpose
|
||||
ins(%input:tensor<16x32xf32>)
|
||||
outs(%init:tensor<32x64x16xf32>)
|
||||
permutation = [1, 0, 2]
|
||||
func.return %transpose : tensor<32x64x16xf32>
|
||||
}
|
||||
|
@ -67,11 +67,11 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
|
||||
|
||||
// -----
|
||||
|
||||
func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
|
||||
func.func @memref_transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
|
||||
%0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @transpose
|
||||
// CHECK-LABEL: func @memref_transpose
|
||||
// CHECK: memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
|
||||
// CHECK-SAME: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
|
||||
|
||||
@ -457,3 +457,27 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
|
||||
}
|
||||
// CHECK-LABEL: func @variadic_reduce_memref
|
||||
// CHECK: linalg.reduce
|
||||
|
||||
// -----
|
||||
|
||||
// Round-trip test for the tensor form of `linalg.transpose`: the op returns a
// result whose type is the type of `init`.
func.func @transpose(%input: tensor<16x32x64xf32>,
|
||||
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
|
||||
%transpose = linalg.transpose
|
||||
ins(%input:tensor<16x32x64xf32>)
|
||||
outs(%init:tensor<32x64x16xf32>)
|
||||
permutation = [1, 2, 0]
|
||||
func.return %transpose : tensor<32x64x16xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @transpose
|
||||
|
||||
// -----
|
||||
|
||||
// Round-trip test for the memref form of `linalg.transpose`: the op has no
// result and the function returns nothing.
func.func @transpose_memref(%input: memref<16x32x64xf32>,
|
||||
%init: memref<32x64x16xf32>) {
|
||||
linalg.transpose
|
||||
ins(%input:memref<16x32x64xf32>)
|
||||
outs(%init:memref<32x64x16xf32>)
|
||||
permutation = [1, 2, 0]
|
||||
func.return
|
||||
}
|
||||
// CHECK-LABEL: func @transpose_memref
|
||||
|
Loading…
Reference in New Issue
Block a user