[mlir][Linalg] Add a linalg.reshape op

Summary:
This diff adds a new operation to linalg to allow reshaping of an
existing view into a new view in the same buffer at the same offset.

More specifically:
The `linalg.reshape` op produces a new view whose sizes are a reassociation
of the original `view`. Depending on whether or not the reassociated
MemRefType is contiguous, the resulting memref may require explicit alloc
and copies.

A reassociation is defined as a contiguous grouping of dimensions and is
represented with an affine map array attribute. In the future, non-contiguous
groupings may be allowed (e.g. permutations, reindexings, etc.).

For now, it is assumed that either:
  1. a reassociation produces and consumes contiguous MemRefType or,
  2. the reshape op will be folded into its consumers (by changing the shape
     of the computations).
All other cases are undefined behavior and a reshape op may not lower to
LLVM if it cannot be proven statically that it does not require alloc+copy.

A reshape may either collapse or expand dimensions, depending on the
relationship between source and target memref ranks. The verification rule
is that the reassociation maps are applied to the memref with the larger
rank to obtain the memref with the smaller rank. In the case of a dimension
expansion, the reassociation maps can be interpreted as inverse maps.

Examples:

```mlir
   // Dimension collapse (i, j) -> i' and k -> k'
   %1 = linalg.reshape %0 [(i, j, k) -> (i, j),
                           (i, j, k) -> (k)] :
     memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
```

```mlir
   // Dimension expansion i -> (i', j') and (k) -> (k')
   %1 = linalg.reshape %0 [(i, j, k) -> (i, j),
                           (i, j, k) -> (k)] :
     memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
```
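To make the contiguity condition concrete (editorial sketch, not part of the patch): for a static case such as collapsing `memref<3x4x5xf32>` into `memref<12x5xf32>` (as in the roundtrip test below), the grouped band (i, j) is reshapable without copies because stride(i) == stride(j) * size(j). A minimal standalone C++ sketch of that check and of the folded sizes/strides; `collapseGroups` and `CollapsedDim` are illustrative names, not MLIR APIs:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// One collapsed dimension: its folded size and the stride it inherits.
struct CollapsedDim {
  int64_t size;
  int64_t stride;
};

// For each group of source dimensions, verify the band is contiguous
// (stride[d] == stride[d + 1] * size[d + 1]) and fold its sizes.
std::vector<CollapsedDim>
collapseGroups(const std::vector<int64_t> &sizes,
               const std::vector<int64_t> &strides,
               const std::vector<std::vector<unsigned>> &groups) {
  assert(sizes.size() == strides.size() && "mismatched ranks");
  std::vector<CollapsedDim> result;
  for (const auto &group : groups) {
    for (size_t i = 0; i + 1 < group.size(); ++i) {
      unsigned d = group[i];
      assert(strides[d] == strides[d + 1] * sizes[d + 1] &&
             "non-contiguous band: an alloc + copy would be required");
    }
    int64_t size = 1;
    for (unsigned d : group)
      size *= sizes[d];
    // The collapsed dimension keeps the stride of the innermost grouped dim.
    result.push_back({size, strides[group.back()]});
  }
  return result;
}

int main() {
  // memref<3x4x5xf32> has canonical strides [20, 5, 1].
  // Grouping (i, j) and (k) yields sizes [12, 5] with strides [5, 1],
  // i.e. a contiguous memref<12x5xf32>.
  for (CollapsedDim d : collapseGroups({3, 4, 5}, {20, 5, 1}, {{0, 1}, {2}}))
    std::cout << d.size << " (stride " << d.stride << ")\n";
  return 0;
}
```

Running it prints `12 (stride 5)` and `5 (stride 1)`, i.e. the contiguous `memref<12x5xf32>` result.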

The relevant invalid and roundtripping tests are added.

Reviewers: AlexEichenberger, ftynse, rriddle, asaadaldien, yangjunpro

Subscribers: kiszk, merge_guards_bot, mehdi_amini, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72168
Nicolas Vasilache 2020-01-06 22:14:14 -05:00
parent d877229b5b
commit e3750cafdb
8 changed files with 458 additions and 15 deletions


@@ -17,6 +17,7 @@ namespace edsc {
namespace intrinsics {
using linalg_fill = OperationBuilder<linalg::FillOp>;
using linalg_reshape = OperationBuilder<linalg::ReshapeOp>;
using linalg_yield = OperationBuilder<linalg::YieldOp>;
} // namespace intrinsics


@@ -58,6 +58,58 @@ def Linalg_RangeOp :
let verifier = ?;
}
def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
Arguments<(ins AnyStridedMemRef:$view, AffineMapArrayAttr:$reassociation)>,
Results<(outs AnyStridedMemRef)> {
let summary = "linalg.reshape produces a new view into the operand view";
let description = [{
The `linalg.reshape` op produces a new view whose sizes are a reassociation
of the original `view`. Depending on whether or not the reassociated
MemRefType is contiguous, the resulting memref may require explicit alloc
and copies.
A reassociation is defined as a contiguous grouping of dimensions and is
represented with an affine map array attribute. In the future, non-contiguous
groupings may be allowed (e.g. permutations, reindexings, etc.).
For now, it is assumed that either:
1. a reassociation produces and consumes contiguous MemRefType or,
2. the reshape op will be folded into its consumers (by changing the shape
of the computations).
All other cases are undefined behavior and a reshape op may not lower to
LLVM if it cannot be proven statically that it does not require alloc+copy.
A reshape may either collapse or expand dimensions, depending on the
relationship between source and target memref ranks. The verification rule
is that the reassociation maps are applied to the memref with the larger
rank to obtain the memref with the smaller rank. In the case of a dimension
expansion, the reassociation maps can be interpreted as inverse maps.
Examples:
```mlir
// Dimension collapse (i, j) -> i' and k -> k'
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
```
```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
```
}];
let builders = [OpBuilder<
"Builder *b, OperationState &result, Value view, "
"ArrayAttr reassociation, ArrayRef<NamedAttribute> attrs = {}">];
let extraClassDeclaration = [{
static StringRef getReassociationAttrName() { return "reassociation"; }
MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
}];
}
def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
Arguments<(ins AnyStridedMemRef:$view,
Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,


@@ -87,6 +87,7 @@ public:
template <typename U> bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
MLIRContext *getContext() const;
@@ -226,25 +227,23 @@ AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr);
template <typename U> bool AffineExpr::isa() const {
if (std::is_same<U, AffineBinaryOpExpr>::value)
return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
if (std::is_same<U, AffineDimExpr>::value)
return getKind() == AffineExprKind::DimId;
if (std::is_same<U, AffineSymbolExpr>::value)
return getKind() == AffineExprKind::SymbolId;
if (std::is_same<U, AffineConstantExpr>::value)
return getKind() == AffineExprKind::Constant;
}
template <typename U> U AffineExpr::dyn_cast() const {
if (isa<U>())
return U(expr);
return U(nullptr);
}
template <typename U> U AffineExpr::dyn_cast_or_null() const {
return (!*this || !isa<U>()) ? U(nullptr) : U(expr);
}
template <typename U> U AffineExpr::cast() const {
assert(isa<U>());
return U(expr);


@@ -16,6 +16,7 @@ struct fltSemantics;
} // namespace llvm
namespace mlir {
class AffineExpr;
class AffineMap;
class FloatType;
class IndexType;
@@ -245,6 +246,9 @@ public:
/// Whether the given dimension size indicates a dynamic dimension.
static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }
static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
return dStrideOrOffset == kDynamicStrideOrOffset;
}
};
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
@@ -548,6 +552,9 @@ public:
LogicalResult getStridesAndOffset(MemRefType t,
SmallVectorImpl<int64_t> &strides,
int64_t &offset);
LogicalResult getStridesAndOffset(MemRefType t,
SmallVectorImpl<AffineExpr> &strides,
AffineExpr &offset);
/// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset()
/// represents a dynamic value), return the single result AffineMap which
@@ -569,6 +576,13 @@ LogicalResult getStridesAndOffset(MemRefType t,
AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
MLIRContext *context);
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
MemRefType canonicalizeStridedLayout(MemRefType t);
/// Return true if the layout for `t` is compatible with strided semantics.
bool isStrided(MemRefType t);
} // end namespace mlir


@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//
@@ -23,6 +23,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h" #include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/FoldUtils.h"
@ -332,6 +333,206 @@ static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(type, result.types)); parser.addTypeToList(type, result.types));
} }
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
/// Return true if the reassociation specification is valid, false otherwise.
/// When false, the `invalidIndex` integer pointer is optionally filled with the
/// index of the offending reassociation map.
static bool isReassociationValid(ArrayRef<AffineMap> reassociation,
int *invalidIndex = nullptr) {
if (reassociation.empty())
return true;
unsigned nDims = reassociation[0].getNumDims();
unsigned nextExpectedDim = 0;
for (auto it : llvm::enumerate(reassociation)) {
auto m = it.value();
if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
if (invalidIndex)
*invalidIndex = it.index();
return false;
}
for (auto e : m.getResults()) {
auto d = e.dyn_cast<AffineDimExpr>();
if (!d || d.getPosition() != nextExpectedDim++) {
if (invalidIndex)
*invalidIndex = it.index();
return false;
}
}
}
if (nextExpectedDim != nDims) {
if (invalidIndex)
*invalidIndex = reassociation.size() - 1;
return false;
}
return true;
}
/// Detect whether memref dims [dim, dim + extent) can be reshaped without
/// copies.
static bool isReshapableDimBand(unsigned dim, unsigned extent,
ArrayRef<int64_t> sizes,
ArrayRef<AffineExpr> strides) {
assert(sizes.size() == strides.size() && "mismatched ranks");
// off by 1 indexing to avoid out of bounds
// V
for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
// Only bands of static shapes are reshapable. This is due to the fact that
// there is no relation between dynamic sizes and dynamic strides: we do not
// have enough information to know whether a "-1" size corresponds to the
// proper symbol in the AffineExpr of a stride.
if (ShapedType::isDynamic(sizes[dim + 1]))
return false;
// TODO(ntv) Refine this by passing the proper nDims and nSymbols so we can
// simplify on the fly and catch more reshapable cases.
if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
return false;
}
return true;
}
/// Compute the MemRefType obtained by applying the `reassociation` (which is
/// expected to be valid) to `type`.
/// If `type` is a contiguous MemRefType, this always produces a contiguous
/// MemRefType.
static MemRefType
computeReshapeCollapsedType(MemRefType type,
ArrayRef<AffineMap> reassociation) {
auto sizes = type.getShape();
AffineExpr offset;
SmallVector<AffineExpr, 4> strides;
auto status = getStridesAndOffset(type, strides, offset);
(void)status;
assert(succeeded(status) && "expected strided memref");
SmallVector<int64_t, 4> newSizes;
newSizes.reserve(reassociation.size());
SmallVector<AffineExpr, 4> newStrides;
newStrides.reserve(reassociation.size());
// Use the fact that reassociation is valid to simplify the logic: only use
// each map's rank.
assert(isReassociationValid(reassociation) && "invalid reassociation");
unsigned currentDim = 0;
for (AffineMap m : reassociation) {
unsigned dim = m.getNumResults();
int64_t size = 1;
AffineExpr stride = strides[currentDim + dim - 1];
if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
size = ShapedType::kDynamicSize;
stride = AffineExpr();
} else {
for (unsigned d = 0; d < dim; ++d)
size *= sizes[currentDim + d];
}
newSizes.push_back(size);
newStrides.push_back(stride);
currentDim += dim;
}
// Early-exit: if `type` is contiguous, the result must be contiguous.
if (canonicalizeStridedLayout(type).getAffineMaps().empty())
return MemRefType::get(newSizes, type.getElementType(), {});
// Convert back to int64_t because we don't have enough information to create
// new strided layouts from AffineExpr only. This corresponds to a case where
// copies may be necessary.
int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
if (auto o = offset.dyn_cast<AffineConstantExpr>())
intOffset = o.getValue();
SmallVector<int64_t, 4> intStrides;
intStrides.reserve(strides.size());
for (auto stride : newStrides) {
if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
intStrides.push_back(cst.getValue());
else
intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
}
auto layout =
makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
return canonicalizeStridedLayout(
MemRefType::get(newSizes, type.getElementType(), {layout}));
}
/// Helper function that asserts each Attribute in `attrs` is an AffineMapAttr
/// and returns the corresponding vector of AffineMaps.
/// TODO(rridle,ntv) this should be evolved into a generic
/// `getRangeOfType<AffineMap>(ArrayAttr attrs)` that does not copy.
static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) {
return functional::map(
[](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }, attrs);
}
void mlir::linalg::ReshapeOp::build(Builder *b, OperationState &result,
Value view, ArrayAttr reassociation,
ArrayRef<NamedAttribute> attrs) {
auto maps = getAffineMaps(reassociation);
auto memRefType = view.getType().cast<MemRefType>();
auto resultType = computeReshapeCollapsedType(memRefType, maps);
build(b, result, resultType, view, attrs);
result.addAttribute(ReshapeOp::getReassociationAttrName(), reassociation);
}
static void print(OpAsmPrinter &p, ReshapeOp op) {
p << op.getOperationName() << " " << op.view() << " " << op.reassociation();
p.printOptionalAttrDict(op.getAttrs(),
{ReshapeOp::getReassociationAttrName()});
p << " : " << op.getViewType() << " into " << op.getResult().getType();
}
static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType view;
ArrayAttr reassociation;
MemRefType type, resultType;
return failure(parser.parseOperand(view) ||
parser.parseAttribute(reassociation,
ReshapeOp::getReassociationAttrName(),
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.parseKeywordType("into", resultType) ||
parser.resolveOperand(view, type, result.operands) ||
parser.addTypeToList(resultType, result.types));
}
static LogicalResult verify(ReshapeOp op) {
MemRefType expandedType = op.getViewType();
MemRefType collapsedType = op.getResult().getType().cast<MemRefType>();
unsigned expandedRank = expandedType.getRank();
unsigned collapsedRank = collapsedType.getRank();
bool isCollapse = expandedRank > collapsedRank;
if (!isCollapse) {
std::swap(expandedRank, collapsedRank);
std::swap(expandedType, collapsedType);
}
if (expandedRank == 0 || collapsedRank == 0)
return op.emitOpError("expected non-zero memref ranks");
if (expandedRank == collapsedRank)
return op.emitOpError("expected to collapse or expand dims");
if (collapsedRank != op.reassociation().size())
return op.emitOpError("expected rank of the collapsed view(")
<< collapsedRank << ") to be the number of reassociation maps("
<< op.reassociation().size() << ")";
auto maps = getAffineMaps(op.reassociation());
for (auto it : llvm::enumerate(maps))
if (it.value().getNumDims() != expandedRank)
return op.emitOpError("expected reassociation map #")
<< it.index() << " of same rank as expanded memref("
<< expandedRank << "), but got " << it.value().getNumDims();
int invalidIdx = 0;
if (!isReassociationValid(maps, &invalidIdx))
return op.emitOpError("expected reassociation map #")
<< invalidIdx << " to be valid and contiguous";
MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
if (collapsedType != expectedType)
return op.emitOpError("expected collapsed type to be ")
<< expectedType << ", but got " << collapsedType;
return success();
}
//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
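As a reading aid for `isReassociationValid` and the verifier above, here is a minimal standalone C++ sketch of the validity rule; dimension positions stand in for the `AffineDimExpr` results of each map, and `isValidReassociation` is an illustrative name, not the MLIR function:

```cpp
#include <iostream>
#include <vector>

// A reassociation is valid when, walking the maps in order, their results
// enumerate the dimensions of the larger-rank memref exactly once, in order,
// with no gaps: 0, 1, ..., nDims - 1.
bool isValidReassociation(const std::vector<std::vector<unsigned>> &maps,
                          unsigned nDims) {
  unsigned nextExpectedDim = 0;
  for (const auto &m : maps)
    for (unsigned pos : m)
      if (pos != nextExpectedDim++)
        return false;
  return nextExpectedDim == nDims; // every dimension must be covered
}

int main() {
  // [(i, j, k) -> (i, j), (i, j, k) -> (k)]: valid collapse of 3 dims into 2.
  std::cout << isValidReassociation({{0, 1}, {2}}, 3) << "\n"; // prints 1
  // [(i, j, k) -> (i, j), (i, j, k) -> (k, j)]: rejected, j is reused and the
  // dimensions are no longer covered contiguously and in order.
  std::cout << isValidReassociation({{0, 1}, {2, 1}}, 3) << "\n"; // prints 0
  return 0;
}
```

The second call fails for the same reason as the `expected reassociation map #1 to be valid and contiguous` case in the invalid tests below.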


@@ -520,9 +520,9 @@ static LogicalResult extractStrides(AffineExpr e,
llvm_unreachable("unexpected binary operation");
}
LogicalResult mlir::getStridesAndOffset(MemRefType t,
SmallVectorImpl<AffineExpr> &strides,
AffineExpr &offset) {
auto affineMaps = t.getAffineMaps();
// For now strides are only computed on a single affine map with a single
// result (i.e. the closed subset of linearization maps that are compatible
@@ -699,6 +699,38 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
return AffineMap::get(strides.size(), nSymbols, expr);
}
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
auto affineMaps = t.getAffineMaps();
// Already in canonical form.
if (affineMaps.empty())
return t;
// Can't reduce to canonical identity form, return in canonical form.
if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
return t;
// If the canonical strided layout for the sizes of `t` is equal to the
// simplified layout of `t` we can just return an empty layout. Otherwise,
// just simplify the existing layout.
AffineExpr expr =
makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
auto m = affineMaps[0];
auto simplifiedLayoutExpr =
simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
if (expr != simplifiedLayoutExpr)
return MemRefType::get(t.getShape(), t.getElementType(),
{AffineMap::get(m.getNumDims(), m.getNumSymbols(),
{simplifiedLayoutExpr})});
return MemRefType::get(t.getShape(), t.getElementType(), {});
}
/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
int64_t offset;
SmallVector<int64_t, 4> stridesAndOffset;
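To make the canonicalization above concrete, a standalone sketch under simplifying assumptions (static shapes only; the real `canonicalizeStridedLayout` works on AffineExprs and also handles dynamic sizes, symbols, and offsets): compute the canonical row-major strides implied by the shape and compare them against the memref's layout; if they match and the offset is zero, the explicit layout map can be dropped in favor of the identity layout. `canonicalStrides` is an illustrative helper, not an MLIR API:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Canonical contiguous (row-major) strides for a static shape: the innermost
// dimension has stride 1, each outer stride is inner stride * inner size.
std::vector<int64_t> canonicalStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

int main() {
  // For memref<3x4x5xf32> the canonical layout is d0 * 20 + d1 * 5 + d2, so a
  // layout with strides [20, 5, 1] and offset 0 folds to the identity layout,
  // while e.g. strides [40, 5, 1] would have to keep an explicit layout map.
  for (int64_t s : canonicalStrides({3, 4, 5}))
    std::cout << s << " "; // prints: 20 5 1
  std::cout << "\n";
  return 0;
}
```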


@@ -482,3 +482,49 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
// expected-error @+1 {{expected valid keyword}}
!invalid_type = type !linalg<"?">
// -----
func @reshape(%arg0: memref<f32>) {
// expected-error @+1 {{expected non-zero memref ranks}}
%0 = linalg.reshape %arg0 [()->(0)] : memref<f32> into memref<f32>
}
// -----
func @reshape(%arg0: memref<?xf32>) {
// expected-error @+1 {{expected to collapse or expand dims}}
%0 = linalg.reshape %arg0 [(i)->(i)] : memref<?xf32> into memref<?xf32>
}
// -----
func @reshape(%arg0: memref<?x?x?xf32>) {
// expected-error @+1 {{expected rank of the collapsed view(2) to be the number of reassociation maps(1)}}
%0 = linalg.reshape %arg0 [(i, j, k) -> (i, j)] :
memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
}
// -----
func @reshape(%arg0: memref<?x?x?xf32>) {
// expected-error @+1 {{expected reassociation map #0 of same rank as expanded memref(3), but got 1}}
%0 = linalg.reshape %arg0 [(i) -> (i), (i, j, k) -> (k)] :
memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
}
// -----
func @reshape(%arg0: memref<?x?x?xf32>) {
// expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}}
%0 = linalg.reshape %arg0 [(i, j, k) -> (i, j), (i, j, k) -> (k, j)] :
memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
}
// -----
func @reshape(%arg0: memref<?x?x?xf32>) {
// expected-error @+1 {{expected collapsed type to be 'memref<?x?xf32>', but got 'memref<?x?xf32, (d0, d1)[s0] -> (d0 * s0 + d1)>'}}
%0 = linalg.reshape %arg0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
memref<?x?x?xf32> into memref<?x?xf32, (d0, d1)[s0] -> (d0 * s0 + d1)>
}


@@ -7,12 +7,23 @@
// CHECK-DAG: #[[strided1D:.*]] = (d0)[s0] -> (d0 + s0)
// CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
// CHECK-DAG: #[[strided2DOFF0:.*]] = (d0, d1)[s0] -> (d0 * s0 + d1)
// CHECK-DAG: #[[strided3D:.*]] = (d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)
// CHECK-DAG: #[[strided3DOFF0:.*]] = (d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)
// CHECK-DAG: #[[strided6D:.*]] = (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5)
// CHECK-DAG: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1)
// CHECK-DAG: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0)
// CHECK-DAG: #[[reshapeD01:.*]] = (d0, d1, d2) -> (d0, d1)
// CHECK-DAG: #[[reshapeD2:.*]] = (d0, d1, d2) -> (d2)
// CHECK-DAG: #[[reshapeD0:.*]] = (d0, d1, d2) -> (d0)
// CHECK-DAG: #[[reshapeD12:.*]] = (d0, d1, d2) -> (d1, d2)
// CHECK-DAG: #[[reshapeD012:.*]] = (d0, d1, d2) -> (d0, d1, d2)
// CHECK-DAG: #[[reshape5D01:.*]] = (d0, d1, d2, d3, d4) -> (d0, d1)
// CHECK-DAG: #[[reshape5D2:.*]] = (d0, d1, d2, d3, d4) -> (d2)
// CHECK-DAG: #[[reshape5D34:.*]] = (d0, d1, d2, d3, d4) -> (d3, d4)
func @range(%arg0: index, %arg1: index, %arg2: index) {
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
return
@@ -181,7 +192,6 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // no predecessors
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.indexed_generic #trait2 %arg0, %arg1 {
@@ -195,3 +205,91 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?,
// CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
func @reshape_static(%arg0: memref<3x4x5xf32>) {
// Reshapes that collapse and expand back a contiguous tensor.
%0 = linalg.reshape %arg0 [(i, j, k) -> (i, j),
(i, j, k) -> (k)] :
memref<3x4x5xf32> into memref<12x5xf32>
%r0 = linalg.reshape %0 [(i, j, k) -> (i, j),
(i, j, k) -> (k)] :
memref<12x5xf32> into memref<3x4x5xf32>
%1 = linalg.reshape %arg0 [(i, j, k) -> (i),
(i, j, k) -> (j, k)] :
memref<3x4x5xf32> into memref<3x20xf32>
%r1 = linalg.reshape %1 [(i, j, k) -> (i),
(i, j, k) -> (j, k)] :
memref<3x20xf32> into memref<3x4x5xf32>
%2 = linalg.reshape %arg0 [(i, j, k) -> (i, j, k)] :
memref<3x4x5xf32> into memref<60xf32>
%r2 = linalg.reshape %2 [(i, j, k) -> (i, j, k)] :
memref<60xf32> into memref<3x4x5xf32>
// Reshapes that expand and collapse back a contiguous tensor with some 1's.
%3 = linalg.reshape %arg0 [(i, j, k, l, m) -> (i, j),
(i, j, k, l, m) -> (k),
(i, j, k, l, m) -> (l, m)] :
memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
%r3 = linalg.reshape %3 [(i, j, k, l, m) -> (i, j),
(i, j, k, l, m) -> (k),
(i, j, k, l, m) -> (l, m)] :
memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
return
}
// CHECK-LABEL: func @reshape_static
// CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
// CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32>
// CHECK: linalg.reshape {{.*}} [#[[reshapeD0]], #[[reshapeD12]]]
// CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32>
// CHECK: linalg.reshape {{.*}} [#[[reshapeD0]], #[[reshapeD12]]]
// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32>
// CHECK: linalg.reshape {{.*}} [#[[reshapeD012]]]
// CHECK-SAME: memref<3x4x5xf32> into memref<60xf32>
// CHECK: linalg.reshape {{.*}} [#[[reshapeD012]]]
// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32>
// CHECK: linalg.reshape {{.*}} [#[[reshape5D01]], #[[reshape5D2]], #[[reshape5D34]]]
// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
// CHECK: linalg.reshape {{.*}} [#[[reshape5D01]], #[[reshape5D2]], #[[reshape5D34]]]
// CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
%arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>) {
%0 = linalg.reshape %arg0 [(i, j, k) -> (i, j),
(i, j, k) -> (k)] :
memref<?x?x?xf32> into memref<?x?xf32>
%r0 = linalg.reshape %0 [(i, j, k) -> (i, j),
(i, j, k) -> (k)] :
memref<?x?xf32> into memref<?x?x?xf32>
%1 = linalg.reshape %arg1 [(i, j, k) -> (i, j),
(i, j, k) -> (k)] :
memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
memref<?x?xf32, offset : 0, strides : [?, 1]>
%r1 = linalg.reshape %1 [(i, j, k) -> (i, j),
(i, j, k) -> (k)] :
memref<?x?xf32, offset : 0, strides : [?, 1]> into
memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>
%2 = linalg.reshape %arg2 [(i, j, k) -> (i, j),
(i, j, k) -> (k)] :
memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
memref<?x?xf32, offset : ?, strides : [?, 1]>
%r2 = linalg.reshape %2 [(i, j, k) -> (i, j),
(i, j, k) -> (k)] :
memref<?x?xf32, offset : ?, strides : [?, 1]> into
memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>
return
}
// CHECK-LABEL: func @reshape
// CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
// CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
// CHECK-SAME: memref<?x?xf32> into memref<?x?x?xf32>
// CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
// CHECK-SAME: memref<?x?x?xf32, #[[strided3DOFF0]]> into memref<?x?xf32, #[[strided2DOFF0]]>
// CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
// CHECK-SAME: memref<?x?xf32, #[[strided2DOFF0]]> into memref<?x?x?xf32, #[[strided3DOFF0]]>
// CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
// CHECK-SAME: memref<?x?x?xf32, #[[strided3D]]> into memref<?x?xf32, #[[strided2D]]>
// CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
// CHECK-SAME: memref<?x?xf32, #[[strided2D]]> into memref<?x?x?xf32, #[[strided3D]]>