[MLIR] Give AffineStoreOp and AffineLoadOp Memory SideEffects.

Summary:
This change results in tests also being changed to prevent dead
affine.load operations from being folded away during rewrites.

Also move AffineStoreOp and AffineLoadOp to an ODS file.

Differential Revision: https://reviews.llvm.org/D78930
This commit is contained in:
Tres Popp 2020-04-27 16:40:00 +02:00
parent d0846b432c
commit f66c87637a
8 changed files with 266 additions and 240 deletions

View File

@ -316,147 +316,6 @@ public:
SmallVectorImpl<OpFoldResult> &results);
};
/// The "affine.load" op reads an element from a memref, where the index
/// for each memref dimension is an affine expression of loop induction
/// variables and symbols. The output of 'affine.load' is a new value with the
/// same type as the elements of the memref. An affine expression of loop IVs
/// and symbols must be specified for each dimension of the memref. The keyword
/// 'symbol' can be used to indicate SSA identifiers which are symbolic.
//
// Example 1:
//
// %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
//
// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
//
// %1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)]
// : memref<100x100xf32>
//
class AffineLoadOp : public Op<AffineLoadOp, OpTrait::OneResult,
OpTrait::AtLeastNOperands<1>::Impl> {
public:
using Op::Op;
/// Builds an affine load op with the specified map and operands.
static void build(OpBuilder &builder, OperationState &result, AffineMap map,
ValueRange operands);
/// Builds an affine load op with an identity map and operands.
static void build(OpBuilder &builder, OperationState &result, Value memref,
ValueRange indices = {});
/// Builds an affine load op with the specified map and its operands.
static void build(OpBuilder &builder, OperationState &result, Value memref,
AffineMap map, ValueRange mapOperands);
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 0; }
/// Get memref operand.
Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
MemRefType getMemRefType() {
return getMemRef().getType().cast<MemRefType>();
}
/// Get affine map operands.
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
/// Returns the affine map used to index the memref for this operation.
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
AffineMapAttr getAffineMapAttr() {
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
assert(memref == getMemRef());
return {Identifier::get(getMapAttrName(), getContext()),
getAffineMapAttr()};
}
static StringRef getMapAttrName() { return "map"; }
static StringRef getOperationName() { return "affine.load"; }
// Hooks to customize behavior of this op.
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult verify();
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
OpFoldResult fold(ArrayRef<Attribute> operands);
};
/// The "affine.store" op writes an element to a memref, where the index
/// for each memref dimension is an affine expression of loop induction
/// variables and symbols. The 'affine.store' op stores a new value which is the
/// same type as the elements of the memref. An affine expression of loop IVs
/// and symbols must be specified for each dimension of the memref. The keyword
/// 'symbol' can be used to indicate SSA identifiers which are symbolic.
//
// Example 1:
//
// affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
//
// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
//
// affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)]
// : memref<100x100xf32>
//
class AffineStoreOp : public Op<AffineStoreOp, OpTrait::ZeroResult,
OpTrait::AtLeastNOperands<1>::Impl> {
public:
using Op::Op;
/// Builds an affine store operation with the provided indices (identity map).
static void build(OpBuilder &builder, OperationState &result,
Value valueToStore, Value memref, ValueRange indices);
/// Builds an affine store operation with the specified map and its operands.
static void build(OpBuilder &builder, OperationState &result,
Value valueToStore, Value memref, AffineMap map,
ValueRange mapOperands);
/// Get value to be stored by store operation.
Value getValueToStore() { return getOperand(0); }
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 1; }
/// Get memref operand.
Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
MemRefType getMemRefType() {
return getMemRef().getType().cast<MemRefType>();
}
/// Get affine map operands.
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
/// Returns the affine map used to index the memref for this operation.
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
AffineMapAttr getAffineMapAttr() {
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
assert(memref == getMemRef());
return {Identifier::get(getMapAttrName(), getContext()),
getAffineMapAttr()};
}
static StringRef getMapAttrName() { return "map"; }
static StringRef getOperationName() { return "affine.store"; }
// Hooks to customize behavior of this op.
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult verify();
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
};
/// Returns true if the given Value can be used as a dimension id.
bool isValidDim(Value value);

View File

@ -360,6 +360,80 @@ def AffineIfOp : Affine_Op<"if",
let hasFolder = 1;
}
def AffineLoadOp : Affine_Op<"load", []> {
let summary = "affine load operation";
let description = [{
The "affine.load" op reads an element from a memref, where the index
for each memref dimension is an affine expression of loop induction
variables and symbols. The output of 'affine.load' is a new value with the
same type as the elements of the memref. An affine expression of loop IVs
and symbols must be specified for each dimension of the memref. The keyword
'symbol' can be used to indicate SSA identifiers which are symbolic.
Example 1:
```mlir
%1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
```
Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
```mlir
%1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>
```
}];
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref,
Variadic<Index>:$indices);
let results = (outs AnyType:$result);
let builders = [
/// Builds an affine load op with the specified map and operands.
OpBuilder<"OpBuilder &builder, OperationState &result, AffineMap map, "
"ValueRange operands">,
/// Builds an affine load op with an identity map and operands.
OpBuilder<"OpBuilder &builder, OperationState &result, Value memref, "
"ValueRange indices = {}">,
/// Builds an affine load op with the specified map and its operands.
OpBuilder<"OpBuilder &builder, OperationState &result, Value memref, "
"AffineMap map, ValueRange mapOperands">
];
let extraClassDeclaration = [{
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 0; }
/// Get memref operand.
Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
MemRefType getMemRefType() {
return getMemRef().getType().cast<MemRefType>();
}
/// Get affine map operands.
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
/// Returns the affine map used to index the memref for this operation.
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
AffineMapAttr getAffineMapAttr() {
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
assert(memref == getMemRef());
return {Identifier::get(getMapAttrName(), getContext()),
getAffineMapAttr()};
}
static StringRef getMapAttrName() { return "map"; }
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
}
class AffineMinMaxOpBase<string mnemonic, list<OpTrait> traits = []> :
Op<Affine_Dialect, mnemonic, traits> {
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
@ -575,6 +649,81 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
let hasFolder = 1;
}
def AffineStoreOp : Affine_Op<"store", []> {
let summary = "affine store operation";
let description = [{
The "affine.store" op writes an element to a memref, where the index
for each memref dimension is an affine expression of loop induction
variables and symbols. The 'affine.store' op stores a new value which is the
same type as the elements of the memref. An affine expression of loop IVs
and symbols must be specified for each dimension of the memref. The keyword
'symbol' can be used to indicate SSA identifiers which are symbolic.
Example 1:
```mlir
affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
```
Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
```mlir
affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>
```
}];
let arguments = (ins AnyType:$value,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$memref,
Variadic<Index>:$indices);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, "
"Value valueToStore, Value memref, ValueRange indices">,
OpBuilder<"OpBuilder &builder, OperationState &result, "
"Value valueToStore, Value memref, AffineMap map, "
"ValueRange mapOperands">
];
let extraClassDeclaration = [{
/// Get value to be stored by store operation.
Value getValueToStore() { return getOperand(0); }
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 1; }
/// Get memref operand.
Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
MemRefType getMemRefType() {
return getMemRef().getType().cast<MemRefType>();
}
/// Get affine map operands.
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
/// Returns the affine map used to index the memref for this operation.
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
AffineMapAttr getAffineMapAttr() {
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
assert(memref == getMemRef());
return {Identifier::get(getMapAttrName(), getContext()),
getAffineMapAttr()};
}
static StringRef getMapAttrName() { return "map"; }
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def AffineTerminatorOp :
Affine_Op<"terminator", [NoSideEffect, Terminator]> {
let summary = "affine terminator operation";

View File

@ -69,7 +69,7 @@ struct AffineInlinerInterface : public DialectInlinerInterface {
AffineDialect::AffineDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<AffineDmaStartOp, AffineDmaWaitOp, AffineLoadOp, AffineStoreOp,
addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
>();
@ -1765,7 +1765,7 @@ void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, memref, map, indices);
}
ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
ParseResult parseAffineLoadOp(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
auto indexTy = builder.getIndexType();
@ -1775,7 +1775,8 @@ ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 1> mapOperands;
return failure(
parser.parseOperand(memrefInfo) ||
parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(),
parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
AffineLoadOp::getMapAttrName(),
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
@ -1784,38 +1785,40 @@ ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(type.getElementType(), result.types));
}
void AffineLoadOp::print(OpAsmPrinter &p) {
p << "affine.load " << getMemRef() << '[';
if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()))
p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
void print(OpAsmPrinter &p, AffineLoadOp op) {
p << "affine.load " << op.getMemRef() << '[';
if (AffineMapAttr mapAttr =
op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
p << ']';
p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
p << " : " << getMemRefType();
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
p << " : " << op.getMemRefType();
}
LogicalResult AffineLoadOp::verify() {
if (getType() != getMemRefType().getElementType())
return emitOpError("result type must match element type of memref");
LogicalResult verify(AffineLoadOp op) {
if (op.getType() != op.getMemRefType().getElementType())
return op.emitOpError("result type must match element type of memref");
auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
if (mapAttr) {
AffineMap map = getAttrOfType<AffineMapAttr>(getMapAttrName()).getValue();
if (map.getNumResults() != getMemRefType().getRank())
return emitOpError("affine.load affine map num results must equal"
" memref rank");
if (map.getNumInputs() != getNumOperands() - 1)
return emitOpError("expects as many subscripts as affine map inputs");
AffineMap map =
op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()).getValue();
if (map.getNumResults() != op.getMemRefType().getRank())
return op.emitOpError("affine.load affine map num results must equal"
" memref rank");
if (map.getNumInputs() != op.getNumOperands() - 1)
return op.emitOpError("expects as many subscripts as affine map inputs");
} else {
if (getMemRefType().getRank() != getNumOperands() - 1)
return emitOpError(
if (op.getMemRefType().getRank() != op.getNumOperands() - 1)
return op.emitOpError(
"expects the number of subscripts to be equal to memref rank");
}
for (auto idx : getMapOperands()) {
for (auto idx : op.getMapOperands()) {
if (!idx.getType().isIndex())
return emitOpError("index to load must have 'index' type");
return op.emitOpError("index to load must have 'index' type");
if (!isValidAffineIndexOperand(idx))
return emitOpError("index must be a dimension or symbol identifier");
return op.emitOpError("index must be a dimension or symbol identifier");
}
return success();
}
@ -1859,7 +1862,7 @@ void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, valueToStore, memref, map, indices);
}
ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
ParseResult parseAffineStoreOp(OpAsmParser &parser, OperationState &result) {
auto indexTy = parser.getBuilder().getIndexType();
MemRefType type;
@ -1870,7 +1873,7 @@ ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
parser.parseOperand(memrefInfo) ||
parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
getMapAttrName(),
AffineStoreOp::getMapAttrName(),
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
@ -1880,40 +1883,42 @@ ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
parser.resolveOperands(mapOperands, indexTy, result.operands));
}
void AffineStoreOp::print(OpAsmPrinter &p) {
p << "affine.store " << getValueToStore();
p << ", " << getMemRef() << '[';
if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()))
p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
void print(OpAsmPrinter &p, AffineStoreOp op) {
p << "affine.store " << op.getValueToStore();
p << ", " << op.getMemRef() << '[';
if (AffineMapAttr mapAttr =
op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
p << ']';
p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
p << " : " << getMemRefType();
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
p << " : " << op.getMemRefType();
}
LogicalResult AffineStoreOp::verify() {
LogicalResult verify(AffineStoreOp op) {
// First operand must have same type as memref element type.
if (getValueToStore().getType() != getMemRefType().getElementType())
return emitOpError("first operand must have same type memref element type");
if (op.getValueToStore().getType() != op.getMemRefType().getElementType())
return op.emitOpError(
"first operand must have same type memref element type");
auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
if (mapAttr) {
AffineMap map = mapAttr.getValue();
if (map.getNumResults() != getMemRefType().getRank())
return emitOpError("affine.store affine map num results must equal"
" memref rank");
if (map.getNumInputs() != getNumOperands() - 2)
return emitOpError("expects as many subscripts as affine map inputs");
if (map.getNumResults() != op.getMemRefType().getRank())
return op.emitOpError("affine.store affine map num results must equal"
" memref rank");
if (map.getNumInputs() != op.getNumOperands() - 2)
return op.emitOpError("expects as many subscripts as affine map inputs");
} else {
if (getMemRefType().getRank() != getNumOperands() - 2)
return emitOpError(
if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
return op.emitOpError(
"expects the number of subscripts to be equal to memref rank");
}
for (auto idx : getMapOperands()) {
for (auto idx : op.getMapOperands()) {
if (!idx.getType().isIndex())
return emitOpError("index to store must have 'index' type");
return op.emitOpError("index to store must have 'index' type");
if (!isValidAffineIndexOperand(idx))
return emitOpError("index must be a dimension or symbol identifier");
return op.emitOpError("index must be a dimension or symbol identifier");
}
return success();
}

View File

@ -481,19 +481,19 @@ func @canonicalize_bounds(%M : index, %N : index) {
// CHECK-LABEL: @compose_into_affine_load_store
func @compose_into_affine_load_store(%A : memref<1024xf32>, %u : index) {
%cf1 = constant 1.0 : f32
// CHECK: affine.for %[[IV:.*]] = 0 to 1024
affine.for %i = 0 to 1024 {
// Make sure the unused operand (%u below) gets dropped as well.
%idx = affine.apply affine_map<(d0, d1) -> (d0 + 1)> (%i, %u)
affine.load %A[%idx] : memref<1024xf32>
affine.store %cf1, %A[%idx] : memref<1024xf32>
%0 = affine.load %A[%idx] : memref<1024xf32>
affine.store %0, %A[%idx] : memref<1024xf32>
// CHECK-NEXT: affine.load %{{.*}}[%[[IV]] + 1]
// CHECK-NEXT: affine.store %cst, %{{.*}}[%[[IV]] + 1]
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[IV]] + 1]
// Map remains the same, but operand changes on composition.
%copy = affine.apply affine_map<(d0) -> (d0)> (%i)
affine.load %A[%copy] : memref<1024xf32>
%1 = affine.load %A[%copy] : memref<1024xf32>
"prevent.dce"(%1) : (f32) -> ()
// CHECK-NEXT: affine.load %{{.*}}[%[[IV]]]
}
return

View File

@ -6,35 +6,35 @@
func @if_else_imperfect(%A : memref<100xi32>, %B : memref<100xi32>, %v : i32) {
// CHECK: %[[A:.*]]: memref<100xi32>, %[[B:.*]]: memref
affine.for %i = 0 to 100 {
affine.load %A[%i] : memref<100xi32>
affine.store %v, %A[%i] : memref<100xi32>
affine.for %j = 0 to 100 {
affine.load %A[%j] : memref<100xi32>
affine.store %v, %A[%j] : memref<100xi32>
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%i) {
affine.load %B[%j] : memref<100xi32>
affine.store %v, %B[%j] : memref<100xi32>
}
call @external() : () -> ()
}
affine.load %A[%i] : memref<100xi32>
affine.store %v, %A[%i] : memref<100xi32>
}
return
}
func @external()
// CHECK: affine.for %[[I:.*]] = 0 to 100 {
// CHECK-NEXT: affine.load %[[A]][%[[I]]]
// CHECK-NEXT: affine.store %{{.*}}, %[[A]][%[[I]]]
// CHECK-NEXT: affine.if #[[SET]](%[[I]]) {
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 100 {
// CHECK-NEXT: affine.load %[[A]][%[[J]]]
// CHECK-NEXT: affine.load %[[B]][%[[J]]]
// CHECK-NEXT: affine.store %{{.*}}, %[[A]][%[[J]]]
// CHECK-NEXT: affine.store %{{.*}}, %[[B]][%[[J]]]
// CHECK-NEXT: call
// CHECK-NEXT: }
// CHECK-NEXT: } else {
// CHECK-NEXT: affine.for %[[JJ:.*]] = 0 to 100 {
// CHECK-NEXT: affine.load %[[A]][%[[JJ]]]
// CHECK-NEXT: affine.store %{{.*}}, %[[A]][%[[J]]]
// CHECK-NEXT: call
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: affine.load %[[A]][%[[I]]]
// CHECK-NEXT: affine.store %{{.*}}, %[[A]][%[[I]]]
// CHECK-NEXT: }
// CHECK-NEXT: return
@ -51,7 +51,7 @@ func @if_then_perfect(%A : memref<100xi32>, %v : i32) {
affine.for %j = 0 to 100 {
affine.for %k = 0 to 100 {
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%i) {
affine.load %A[%i] : memref<100xi32>
affine.store %v, %A[%i] : memref<100xi32>
}
}
}
@ -72,10 +72,10 @@ func @if_else_perfect(%A : memref<100xi32>, %v : i32) {
affine.for %k = 0 to 100 {
call @foo() : () -> ()
affine.if affine_set<(d0, d1) : (d0 - 2 >= 0, -d1 + 80 >= 0)>(%i, %j) {
affine.load %A[%i] : memref<100xi32>
affine.store %v, %A[%i] : memref<100xi32>
call @abc() : () -> ()
} else {
affine.load %A[%i + 1] : memref<100xi32>
affine.store %v, %A[%i + 1] : memref<100xi32>
call @xyz() : () -> ()
}
call @bar() : () -> ()
@ -89,14 +89,14 @@ func @if_else_perfect(%A : memref<100xi32>, %v : i32) {
// CHECK-NEXT: affine.if
// CHECK-NEXT: affine.for
// CHECK-NEXT: call @foo
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}]
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}]
// CHECK-NEXT: call @abc
// CHECK-NEXT: call @bar
// CHECK-NEXT: }
// CHECK-NEXT: else
// CHECK-NEXT: affine.for
// CHECK-NEXT: call @foo
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}} + 1]
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}} + 1]
// CHECK-NEXT: call @xyz
// CHECK-NEXT: call @bar
// CHECK-NEXT: }
@ -105,23 +105,23 @@ func @if_else_perfect(%A : memref<100xi32>, %v : i32) {
// CHECK-NEXT: }
// CHECK-LABEL: func @if_then_imperfect
func @if_then_imperfect(%A : memref<100xi32>, %N : index) {
func @if_then_imperfect(%A : memref<100xi32>, %N : index, %v: i32) {
affine.for %i = 0 to 100 {
affine.load %A[0] : memref<100xi32>
affine.store %v, %A[0] : memref<100xi32>
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%N) {
affine.load %A[%i] : memref<100xi32>
affine.store %v, %A[%i] : memref<100xi32>
}
}
return
}
// CHECK: affine.if
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.load
// CHECK-NEXT: affine.load
// CHECK-NEXT: affine.store
// CHECK-NEXT: affine.store
// CHECK-NEXT: }
// CHECK-NEXT: } else {
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.load
// CHECK-NEXT: affine.store
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
@ -182,21 +182,21 @@ func @handle_dead_if(%N : index) {
#set0 = affine_set<(d0, d1)[s0, s1] : (d0 * -16 + s0 - 16 >= 0, d1 * -3 + s1 - 3 >= 0)>
// CHECK-LABEL: func @perfect_if_else
func @perfect_if_else(%arg0 : memref<?x?xf64>, %arg1 : memref<?x?xf64>, %arg4 : index,
%arg5 : index, %arg6 : index, %sym : index) {
func @perfect_if_else(%arg0 : memref<?x?xf64>, %arg1 : memref<?x?xf64>, %v : f64,
%arg4 : index, %arg5 : index, %arg6 : index, %sym : index) {
affine.for %arg7 = #lb0(%arg5) to min #ub0(%arg5)[%sym] {
affine.parallel (%i0, %j0) = (0, 0) to (symbol(%sym), 100) step (10, 10) {
affine.for %arg8 = #lb1(%arg4) to min #ub1(%arg4)[%sym] {
affine.if #set0(%arg6, %arg7)[%sym, %sym] {
affine.for %arg9 = #flb0(%arg6) to #fub0(%arg6) {
affine.for %arg10 = #flb1(%arg7) to #fub1(%arg7) {
affine.load %arg0[0, 0] : memref<?x?xf64>
affine.store %v, %arg0[0, 0] : memref<?x?xf64>
}
}
} else {
affine.for %arg9 = #lb0(%arg6) to min #pub0(%arg6)[%sym] {
affine.for %arg10 = #lb1(%arg7) to min #pub1(%arg7)[%sym] {
affine.load %arg0[0, 0] : memref<?x?xf64>
affine.store %v, %arg0[0, 0] : memref<?x?xf64>
}
}
}
@ -212,7 +212,7 @@ func @perfect_if_else(%arg0 : memref<?x?xf64>, %arg1 : memref<?x?xf64>, %arg4 :
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.load
// CHECK-NEXT: affine.store
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
@ -222,7 +222,7 @@ func @perfect_if_else(%arg0 : memref<?x?xf64>, %arg1 : memref<?x?xf64>, %arg4 :
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.load
// CHECK-NEXT: affine.store
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

View File

@ -310,26 +310,26 @@ func @xor_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
}
// CHECK-LABEL: func @memref_cast_folding
func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
%1 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> (f32, f32) {
%0 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
// CHECK-NEXT: %c0 = constant 0 : index
%c0 = constant 0 : index
%dim = dim %1, 0 : memref<? x f32>
%dim = dim %0, 0 : memref<? x f32>
// CHECK-NEXT: affine.load %arg0[3]
affine.load %1[%dim - 1] : memref<?xf32>
%1 = affine.load %0[%dim - 1] : memref<?xf32>
// CHECK-NEXT: store %arg1, %arg0[%c0] : memref<4xf32>
store %arg1, %1[%c0] : memref<?xf32>
store %arg1, %0[%c0] : memref<?xf32>
// CHECK-NEXT: %{{.*}} = load %arg0[%c0] : memref<4xf32>
%0 = load %1[%c0] : memref<?xf32>
%2 = load %0[%c0] : memref<?xf32>
// CHECK-NEXT: dealloc %arg0 : memref<4xf32>
dealloc %1: memref<?xf32>
dealloc %0: memref<?xf32>
// CHECK-NEXT: return %{{.*}}
return %0 : f32
return %1, %2 : f32, f32
}
// CHECK-LABEL: func @alloc_const_fold
@ -869,7 +869,8 @@ func @remove_dead_else(%M : memref<100 x i32>) {
affine.load %M[%i] : memref<100xi32>
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%i) {
affine.for %j = 0 to 100 {
affine.load %M[%j] : memref<100xi32>
%1 = affine.load %M[%j] : memref<100xi32>
"prevent.dce"(%1) : (i32) -> ()
}
} else {
// Nothing
@ -881,9 +882,9 @@ func @remove_dead_else(%M : memref<100 x i32>) {
// CHECK: affine.if
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.load
// CHECK-NEXT: "prevent.dce"
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: affine.load
// -----

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-loop-fusion -test-loop-fusion-transformation -split-input-file -canonicalize | FileCheck %s
// RUN: mlir-opt %s -allow-unregistered-dialect -test-loop-fusion -test-loop-fusion-transformation -split-input-file -canonicalize | FileCheck %s
// CHECK-LABEL: func @slice_depth1_loop_nest() {
func @slice_depth1_loop_nest() {
@ -9,10 +9,12 @@ func @slice_depth1_loop_nest() {
}
affine.for %i1 = 0 to 5 {
%1 = affine.load %0[%i1] : memref<100xf32>
"prevent.dce"(%1) : (f32) -> ()
}
// CHECK: affine.for %[[IV0:.*]] = 0 to 5 {
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[IV0]]] : memref<100xf32>
// CHECK-NEXT: affine.load %{{.*}}[%[[IV0]]] : memref<100xf32>
// CHECK-NEXT: "prevent.dce"(%1) : (f32) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return
return
@ -74,15 +76,16 @@ func @should_fuse_avoiding_dependence_cycle() {
// 3) loop1 -> loop2 on memref '%{{.*}}'
affine.for %i0 = 0 to 10 {
%v0 = affine.load %a[%i0] : memref<10xf32>
affine.store %cf7, %b[%i0] : memref<10xf32>
affine.store %v0, %b[%i0] : memref<10xf32>
}
affine.for %i1 = 0 to 10 {
affine.store %cf7, %a[%i1] : memref<10xf32>
%v1 = affine.load %c[%i1] : memref<10xf32>
"prevent.dce"(%v1) : (f32) -> ()
}
affine.for %i2 = 0 to 10 {
%v2 = affine.load %b[%i2] : memref<10xf32>
affine.store %cf7, %c[%i2] : memref<10xf32>
affine.store %v2, %c[%i2] : memref<10xf32>
}
// Fusing loop first loop into last would create a cycle:
// {1} <--> {0, 2}
@ -97,6 +100,7 @@ func @should_fuse_avoiding_dependence_cycle() {
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: "prevent.dce"
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }

View File

@ -1,11 +1,12 @@
// RUN: mlir-opt -simplify-affine-structures %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -simplify-affine-structures %s | FileCheck %s
// CHECK-LABEL: func @permute()
func @permute() {
%A = alloc() : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
affine.for %i = 0 to 64 {
affine.for %j = 0 to 256 {
affine.load %A[%i, %j] : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
%1 = affine.load %A[%i, %j] : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
"prevent.dce"(%1) : (f32) -> ()
}
}
dealloc %A : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
@ -17,6 +18,7 @@ func @permute() {
// CHECK-NEXT: affine.for %[[I:arg[0-9]+]] = 0 to 64 {
// CHECK-NEXT: affine.for %[[J:arg[0-9]+]] = 0 to 256 {
// CHECK-NEXT: affine.load [[MEM]][%[[J]], %[[I]]] : memref<256x64xf32>
// CHECK-NEXT: "prevent.dce"
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: dealloc [[MEM]]
@ -29,7 +31,8 @@ func @shift(%idx : index) {
// CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
affine.for %i = 0 to 64 {
affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
%1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
"prevent.dce"(%1) : (f32) -> ()
// CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
}
return
@ -45,8 +48,9 @@ func @high_dim_permute() {
affine.for %j = 0 to 128 {
// CHECK: %[[K:arg[0-9]+]]
affine.for %k = 0 to 256 {
affine.load %A[%i, %j, %k] : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
%1 = affine.load %A[%i, %j, %k] : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
// CHECK: %{{.*}} = affine.load %{{.*}}[%[[K]], %[[I]], %[[J]]] : memref<256x64x128xf32>
"prevent.dce"(%1) : (f32) -> ()
}
}
}
@ -66,7 +70,8 @@ func @data_tiling(%idx : index) {
// CHECK: alloc() : memref<8x32x8x16xf32>
%A = alloc() : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
// CHECK: affine.load %{{.*}}[symbol(%arg0) floordiv 8, symbol(%arg0) floordiv 16, symbol(%arg0) mod 8, symbol(%arg0) mod 16]
affine.load %A[%idx, %idx] : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
%1 = affine.load %A[%idx, %idx] : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
"prevent.dce"(%1) : (f32) -> ()
return
}
@ -79,7 +84,8 @@ func @strided() {
// CHECK: affine.for %[[IV1:.*]] =
affine.for %j = 0 to 128 {
// CHECK: affine.load %{{.*}}[%[[IV0]] * 2, %[[IV1]] * 4] : memref<127x509xf32>
affine.load %A[%i, %j] : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
%1 = affine.load %A[%i, %j] : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
"prevent.dce"(%1) : (f32) -> ()
}
}
return
@ -94,7 +100,8 @@ func @strided_cumulative() {
// CHECK: affine.for %[[IV1:.*]] =
affine.for %j = 0 to 5 {
// CHECK: affine.load %{{.*}}[%[[IV0]] * 3 + %[[IV1]] * 17] : memref<72xf32>
affine.load %A[%i, %j] : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
%1 = affine.load %A[%i, %j] : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
"prevent.dce"(%1) : (f32) -> ()
}
}
return
@ -109,7 +116,8 @@ func @symbolic_operands(%s : index) {
affine.for %i = 0 to 10 {
affine.for %j = 0 to 10 {
// CHECK: affine.load %{{.*}}[%{{.*}} * 10 + %{{.*}}] : memref<100xf32>
affine.load %A[%i, %j] : memref<10x10xf32, affine_map<(d0,d1)[s0] -> (10*d0 + d1)>>
%1 = affine.load %A[%i, %j] : memref<10x10xf32, affine_map<(d0,d1)[s0] -> (10*d0 + d1)>>
"prevent.dce"(%1) : (f32) -> ()
}
}
return