mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-15 12:39:19 +00:00
[mlir] support max/min lower/upper bounds in affine.parallel
This enables to express more complex parallel loops in the affine framework, for example, in cases of tiling by sizes not dividing loop trip counts perfectly or inner wavefront parallelism, among others. One can't use affine.max/min and supply values to the nested loop bounds since the results of such affine.max/min operations aren't valid symbols. Making them valid symbols isn't an option since they would introduce selection trees into memref subscript arithmetic as an unintended and undesired consequence. Also add support for converting such loops to SCF. Drop some API that isn't used in the core repo from AffineParallelOp since its semantics becomes ambiguous in presence of max/min bounds. Loop normalization is currently unavailable for such loops. Depends On D101171 Reviewed By: bondhugula Differential Revision: https://reviews.llvm.org/D101172
This commit is contained in:
parent
545fa37834
commit
6841e6afba
@ -613,7 +613,13 @@ def AffineParallelOp : Affine_Op<"parallel",
|
|||||||
The lower and upper bounds of a parallel operation are represented as an
|
The lower and upper bounds of a parallel operation are represented as an
|
||||||
application of an affine mapping to a list of SSA values passed to the map.
|
application of an affine mapping to a list of SSA values passed to the map.
|
||||||
The same restrictions hold for these SSA values as for all bindings of SSA
|
The same restrictions hold for these SSA values as for all bindings of SSA
|
||||||
values to dimensions and symbols.
|
values to dimensions and symbols. The list of expressions in each map is
|
||||||
|
interpreted according to the respective bounds group attribute. If a single
|
||||||
|
expression belongs to the group, then the result of this expression is taken
|
||||||
|
as a lower(upper) bound of the corresponding loop induction variable. If
|
||||||
|
multiple expressions belong to the group, then the lower(upper) bound is the
|
||||||
|
max(min) of these values obtained from these expressions. The loop band has
|
||||||
|
as many loops as elements in the group bounds attributes.
|
||||||
|
|
||||||
Each value yielded by affine.yield will be accumulated/reduced via one of
|
Each value yielded by affine.yield will be accumulated/reduced via one of
|
||||||
the reduction methods defined in the AtomicRMWKind enum. The order of
|
the reduction methods defined in the AtomicRMWKind enum. The order of
|
||||||
@ -644,12 +650,25 @@ def AffineParallelOp : Affine_Op<"parallel",
|
|||||||
return %O
|
return %O
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Example (tiling by potentially imperfectly dividing sizes):
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
affine.parallel (%ii, %jj) = (0, 0) to (%N, %M) step (32, 32) {
|
||||||
|
affine.parallel (%i, %j) = (%ii, %jj)
|
||||||
|
to (min(%ii + 32, %N), min(%jj + 32, %M)) {
|
||||||
|
call @f(%i, %j) : (index, index) -> ()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TypedArrayAttrBase<AtomicRMWKindAttr, "Reduction ops">:$reductions,
|
TypedArrayAttrBase<AtomicRMWKindAttr, "Reduction ops">:$reductions,
|
||||||
AffineMapAttr:$lowerBoundsMap,
|
AffineMapAttr:$lowerBoundsMap,
|
||||||
|
I32ElementsAttr:$lowerBoundsGroups,
|
||||||
AffineMapAttr:$upperBoundsMap,
|
AffineMapAttr:$upperBoundsMap,
|
||||||
|
I32ElementsAttr:$upperBoundsGroups,
|
||||||
I64ArrayAttr:$steps,
|
I64ArrayAttr:$steps,
|
||||||
Variadic<Index>:$mapOperands);
|
Variadic<Index>:$mapOperands);
|
||||||
let results = (outs Variadic<AnyType>:$results);
|
let results = (outs Variadic<AnyType>:$results);
|
||||||
@ -659,11 +678,8 @@ def AffineParallelOp : Affine_Op<"parallel",
|
|||||||
OpBuilder<(ins "TypeRange":$resultTypes,
|
OpBuilder<(ins "TypeRange":$resultTypes,
|
||||||
"ArrayRef<AtomicRMWKind>":$reductions, "ArrayRef<int64_t>":$ranges)>,
|
"ArrayRef<AtomicRMWKind>":$reductions, "ArrayRef<int64_t>":$ranges)>,
|
||||||
OpBuilder<(ins "TypeRange":$resultTypes,
|
OpBuilder<(ins "TypeRange":$resultTypes,
|
||||||
"ArrayRef<AtomicRMWKind>":$reductions, "AffineMap":$lbMap,
|
"ArrayRef<AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
|
||||||
"ValueRange":$lbArgs, "AffineMap":$ubMap, "ValueRange":$ubArgs)>,
|
"ValueRange":$lbArgs, "ArrayRef<AffineMap>":$ubMaps, "ValueRange":$ubArgs,
|
||||||
OpBuilder<(ins "TypeRange":$resultTypes,
|
|
||||||
"ArrayRef<AtomicRMWKind>":$reductions, "AffineMap":$lbMap,
|
|
||||||
"ValueRange":$lbArgs, "AffineMap":$ubMap, "ValueRange":$ubArgs,
|
|
||||||
"ArrayRef<int64_t>":$steps)>
|
"ArrayRef<int64_t>":$steps)>
|
||||||
];
|
];
|
||||||
|
|
||||||
@ -671,8 +687,6 @@ def AffineParallelOp : Affine_Op<"parallel",
|
|||||||
/// Get the number of dimensions.
|
/// Get the number of dimensions.
|
||||||
unsigned getNumDims();
|
unsigned getNumDims();
|
||||||
|
|
||||||
AffineValueMap getRangesValueMap();
|
|
||||||
|
|
||||||
/// Get ranges as constants, may fail in dynamic case.
|
/// Get ranges as constants, may fail in dynamic case.
|
||||||
Optional<SmallVector<int64_t, 8>> getConstantRanges();
|
Optional<SmallVector<int64_t, 8>> getConstantRanges();
|
||||||
|
|
||||||
@ -682,23 +696,45 @@ def AffineParallelOp : Affine_Op<"parallel",
|
|||||||
return getBody()->getArguments();
|
return getBody()->getArguments();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns elements of the loop lower bound.
|
||||||
|
AffineMap getLowerBoundMap(unsigned pos);
|
||||||
operand_range getLowerBoundsOperands();
|
operand_range getLowerBoundsOperands();
|
||||||
AffineValueMap getLowerBoundsValueMap();
|
AffineValueMap getLowerBoundsValueMap();
|
||||||
|
|
||||||
|
/// Sets elements of the loop lower bound.
|
||||||
void setLowerBounds(ValueRange operands, AffineMap map);
|
void setLowerBounds(ValueRange operands, AffineMap map);
|
||||||
void setLowerBoundsMap(AffineMap map);
|
void setLowerBoundsMap(AffineMap map);
|
||||||
|
|
||||||
|
/// Returns elements of the loop upper bound.
|
||||||
|
AffineMap getUpperBoundMap(unsigned pos);
|
||||||
operand_range getUpperBoundsOperands();
|
operand_range getUpperBoundsOperands();
|
||||||
AffineValueMap getUpperBoundsValueMap();
|
AffineValueMap getUpperBoundsValueMap();
|
||||||
|
|
||||||
|
/// Sets elements fo the loop upper bound.
|
||||||
void setUpperBounds(ValueRange operands, AffineMap map);
|
void setUpperBounds(ValueRange operands, AffineMap map);
|
||||||
void setUpperBoundsMap(AffineMap map);
|
void setUpperBoundsMap(AffineMap map);
|
||||||
|
|
||||||
SmallVector<int64_t, 8> getSteps();
|
SmallVector<int64_t, 8> getSteps();
|
||||||
void setSteps(ArrayRef<int64_t> newSteps);
|
void setSteps(ArrayRef<int64_t> newSteps);
|
||||||
|
|
||||||
|
/// Returns attribute names to use in op construction. Not expected to be
|
||||||
|
/// used directly.
|
||||||
static StringRef getReductionsAttrName() { return "reductions"; }
|
static StringRef getReductionsAttrName() { return "reductions"; }
|
||||||
static StringRef getLowerBoundsMapAttrName() { return "lowerBoundsMap"; }
|
static StringRef getLowerBoundsMapAttrName() { return "lowerBoundsMap"; }
|
||||||
|
static StringRef getLowerBoundsGroupsAttrName() {
|
||||||
|
return "lowerBoundsGroups";
|
||||||
|
}
|
||||||
static StringRef getUpperBoundsMapAttrName() { return "upperBoundsMap"; }
|
static StringRef getUpperBoundsMapAttrName() { return "upperBoundsMap"; }
|
||||||
|
static StringRef getUpperBoundsGroupsAttrName() {
|
||||||
|
return "upperBoundsGroups";
|
||||||
|
}
|
||||||
static StringRef getStepsAttrName() { return "steps"; }
|
static StringRef getStepsAttrName() { return "steps"; }
|
||||||
|
|
||||||
|
/// Returns `true` if the loop bounds have min/max expressions.
|
||||||
|
bool hasMinMaxBounds() {
|
||||||
|
return lowerBoundsMap().getNumResults() != getNumDims() ||
|
||||||
|
upperBoundsMap().getNumResults() != getNumDims();
|
||||||
|
}
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
@ -262,6 +262,9 @@ public:
|
|||||||
/// Returns the map consisting of the `resultPos` subset.
|
/// Returns the map consisting of the `resultPos` subset.
|
||||||
AffineMap getSubMap(ArrayRef<unsigned> resultPos) const;
|
AffineMap getSubMap(ArrayRef<unsigned> resultPos) const;
|
||||||
|
|
||||||
|
/// Returns the map consisting of `length` expressions starting from `start`.
|
||||||
|
AffineMap getSliceMap(unsigned start, unsigned length) const;
|
||||||
|
|
||||||
/// Returns the map consisting of the most major `numResults` results.
|
/// Returns the map consisting of the most major `numResults` results.
|
||||||
/// Returns the null AffineMap if `numResults` == 0.
|
/// Returns the null AffineMap if `numResults` == 0.
|
||||||
/// Returns `*this` if `numResults` >= `this->getNumResults()`.
|
/// Returns `*this` if `numResults` >= `this->getNumResults()`.
|
||||||
|
@ -113,6 +113,13 @@ public:
|
|||||||
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
|
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
|
||||||
ValueRange operands) = 0;
|
ValueRange operands) = 0;
|
||||||
|
|
||||||
|
/// Prints an affine expression of SSA ids with SSA id names used instead of
|
||||||
|
/// dims and symbols.
|
||||||
|
/// Operand values must come from single-result sources, and be valid
|
||||||
|
/// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
|
||||||
|
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
|
||||||
|
ValueRange symOperands) = 0;
|
||||||
|
|
||||||
/// Print an optional arrow followed by a type list.
|
/// Print an optional arrow followed by a type list.
|
||||||
template <typename TypeRange>
|
template <typename TypeRange>
|
||||||
void printOptionalArrowTypeList(TypeRange &&types) {
|
void printOptionalArrowTypeList(TypeRange &&types) {
|
||||||
@ -680,6 +687,14 @@ public:
|
|||||||
StringRef attrName, NamedAttrList &attrs,
|
StringRef attrName, NamedAttrList &attrs,
|
||||||
Delimiter delimiter = Delimiter::Square) = 0;
|
Delimiter delimiter = Delimiter::Square) = 0;
|
||||||
|
|
||||||
|
/// Parses an affine expression where dims and symbols are SSA operands.
|
||||||
|
/// Operand values must come from single-result sources, and be valid
|
||||||
|
/// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
|
||||||
|
virtual ParseResult
|
||||||
|
parseAffineExprOfSSAIds(SmallVectorImpl<OperandType> &dimOperands,
|
||||||
|
SmallVectorImpl<OperandType> &symbOperands,
|
||||||
|
AffineExpr &expr) = 0;
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Region Parsing
|
// Region Parsing
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
@ -423,20 +423,28 @@ public:
|
|||||||
SmallVector<Value, 8> upperBoundTuple;
|
SmallVector<Value, 8> upperBoundTuple;
|
||||||
SmallVector<Value, 8> lowerBoundTuple;
|
SmallVector<Value, 8> lowerBoundTuple;
|
||||||
SmallVector<Value, 8> identityVals;
|
SmallVector<Value, 8> identityVals;
|
||||||
// Finding lower and upper bound by expanding the map expression.
|
// Emit IR computing the lower and upper bound by expanding the map
|
||||||
// Checking if expandAffineMap is not giving NULL.
|
// expression.
|
||||||
Optional<SmallVector<Value, 8>> lowerBound = expandAffineMap(
|
lowerBoundTuple.reserve(op.getNumDims());
|
||||||
rewriter, loc, op.lowerBoundsMap(), op.getLowerBoundsOperands());
|
upperBoundTuple.reserve(op.getNumDims());
|
||||||
Optional<SmallVector<Value, 8>> upperBound = expandAffineMap(
|
for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
|
||||||
rewriter, loc, op.upperBoundsMap(), op.getUpperBoundsOperands());
|
Value lower = lowerAffineMapMax(rewriter, loc, op.getLowerBoundMap(i),
|
||||||
if (!lowerBound || !upperBound)
|
op.getLowerBoundsOperands());
|
||||||
return failure();
|
if (!lower)
|
||||||
upperBoundTuple = *upperBound;
|
return rewriter.notifyMatchFailure(op, "couldn't convert lower bounds");
|
||||||
lowerBoundTuple = *lowerBound;
|
lowerBoundTuple.push_back(lower);
|
||||||
|
|
||||||
|
Value upper = lowerAffineMapMin(rewriter, loc, op.getUpperBoundMap(i),
|
||||||
|
op.getUpperBoundsOperands());
|
||||||
|
if (!upper)
|
||||||
|
return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
|
||||||
|
upperBoundTuple.push_back(upper);
|
||||||
|
}
|
||||||
steps.reserve(op.steps().size());
|
steps.reserve(op.steps().size());
|
||||||
for (Attribute step : op.steps())
|
for (Attribute step : op.steps())
|
||||||
steps.push_back(rewriter.create<ConstantIndexOp>(
|
steps.push_back(rewriter.create<ConstantIndexOp>(
|
||||||
loc, step.cast<IntegerAttr>().getInt()));
|
loc, step.cast<IntegerAttr>().getInt()));
|
||||||
|
|
||||||
// Get the terminator op.
|
// Get the terminator op.
|
||||||
Operation *affineParOpTerminator = op.getBody()->getTerminator();
|
Operation *affineParOpTerminator = op.getBody()->getTerminator();
|
||||||
scf::ParallelOp parOp;
|
scf::ParallelOp parOp;
|
||||||
|
@ -2604,45 +2604,46 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
|
|||||||
TypeRange resultTypes,
|
TypeRange resultTypes,
|
||||||
ArrayRef<AtomicRMWKind> reductions,
|
ArrayRef<AtomicRMWKind> reductions,
|
||||||
ArrayRef<int64_t> ranges) {
|
ArrayRef<int64_t> ranges) {
|
||||||
SmallVector<AffineExpr, 8> lbExprs(ranges.size(),
|
SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
|
||||||
builder.getAffineConstantExpr(0));
|
auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
|
||||||
auto lbMap = AffineMap::get(0, 0, lbExprs, builder.getContext());
|
return builder.getConstantAffineMap(value);
|
||||||
SmallVector<AffineExpr, 8> ubExprs;
|
}));
|
||||||
for (int64_t range : ranges)
|
SmallVector<int64_t> steps(ranges.size(), 1);
|
||||||
ubExprs.push_back(builder.getAffineConstantExpr(range));
|
build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
|
||||||
auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext());
|
/*ubArgs=*/{}, steps);
|
||||||
build(builder, result, resultTypes, reductions, lbMap, /*lbArgs=*/{}, ubMap,
|
|
||||||
/*ubArgs=*/{});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
|
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
|
||||||
TypeRange resultTypes,
|
TypeRange resultTypes,
|
||||||
ArrayRef<AtomicRMWKind> reductions,
|
ArrayRef<AtomicRMWKind> reductions,
|
||||||
AffineMap lbMap, ValueRange lbArgs,
|
ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
|
||||||
AffineMap ubMap, ValueRange ubArgs) {
|
ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
|
||||||
auto numDims = lbMap.getNumResults();
|
|
||||||
// Verify that the dimensionality of both maps are the same.
|
|
||||||
assert(numDims == ubMap.getNumResults() &&
|
|
||||||
"num dims and num results mismatch");
|
|
||||||
// Make default step sizes of 1.
|
|
||||||
SmallVector<int64_t, 8> steps(numDims, 1);
|
|
||||||
build(builder, result, resultTypes, reductions, lbMap, lbArgs, ubMap, ubArgs,
|
|
||||||
steps);
|
|
||||||
}
|
|
||||||
|
|
||||||
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
|
|
||||||
TypeRange resultTypes,
|
|
||||||
ArrayRef<AtomicRMWKind> reductions,
|
|
||||||
AffineMap lbMap, ValueRange lbArgs,
|
|
||||||
AffineMap ubMap, ValueRange ubArgs,
|
|
||||||
ArrayRef<int64_t> steps) {
|
ArrayRef<int64_t> steps) {
|
||||||
auto numDims = lbMap.getNumResults();
|
assert(!lbMaps.empty() && "expected the lower bound map to be non-empty");
|
||||||
// Verify that the dimensionality of the maps matches the number of steps.
|
assert(!ubMaps.empty() && "expected the upper bound map to be non-empty");
|
||||||
assert(numDims == ubMap.getNumResults() &&
|
assert(llvm::all_of(lbMaps,
|
||||||
"num dims and num results mismatch");
|
[lbMaps](AffineMap m) {
|
||||||
assert(numDims == steps.size() && "num dims and num steps mismatch");
|
return m.getNumDims() == lbMaps[0].getNumDims() &&
|
||||||
|
m.getNumSymbols() == lbMaps[0].getNumSymbols();
|
||||||
|
}) &&
|
||||||
|
"expected all lower bounds maps to have the same number of dimensions "
|
||||||
|
"and symbols");
|
||||||
|
assert(llvm::all_of(ubMaps,
|
||||||
|
[ubMaps](AffineMap m) {
|
||||||
|
return m.getNumDims() == ubMaps[0].getNumDims() &&
|
||||||
|
m.getNumSymbols() == ubMaps[0].getNumSymbols();
|
||||||
|
}) &&
|
||||||
|
"expected all upper bounds maps to have the same number of dimensions "
|
||||||
|
"and symbols");
|
||||||
|
assert(lbMaps[0].getNumInputs() == lbArgs.size() &&
|
||||||
|
"expected lower bound maps to have as many inputs as lower bound "
|
||||||
|
"operands");
|
||||||
|
assert(ubMaps[0].getNumInputs() == ubArgs.size() &&
|
||||||
|
"expected upper bound maps to have as many inputs as upper bound "
|
||||||
|
"operands");
|
||||||
|
|
||||||
result.addTypes(resultTypes);
|
result.addTypes(resultTypes);
|
||||||
|
|
||||||
// Convert the reductions to integer attributes.
|
// Convert the reductions to integer attributes.
|
||||||
SmallVector<Attribute, 4> reductionAttrs;
|
SmallVector<Attribute, 4> reductionAttrs;
|
||||||
for (AtomicRMWKind reduction : reductions)
|
for (AtomicRMWKind reduction : reductions)
|
||||||
@ -2650,16 +2651,42 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
|
|||||||
builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
|
builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
|
||||||
result.addAttribute(getReductionsAttrName(),
|
result.addAttribute(getReductionsAttrName(),
|
||||||
builder.getArrayAttr(reductionAttrs));
|
builder.getArrayAttr(reductionAttrs));
|
||||||
|
|
||||||
|
// Concatenates maps defined in the same input space (same dimensions and
|
||||||
|
// symbols), assumes there is at least one map.
|
||||||
|
auto concatMapsSameInput = [](ArrayRef<AffineMap> maps,
|
||||||
|
SmallVectorImpl<int32_t> &groups) {
|
||||||
|
SmallVector<AffineExpr> exprs;
|
||||||
|
groups.reserve(groups.size() + maps.size());
|
||||||
|
exprs.reserve(maps.size());
|
||||||
|
for (AffineMap m : maps) {
|
||||||
|
llvm::append_range(exprs, m.getResults());
|
||||||
|
groups.push_back(m.getNumResults());
|
||||||
|
}
|
||||||
|
assert(!maps.empty() && "expected a non-empty list of maps");
|
||||||
|
return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
|
||||||
|
maps[0].getContext());
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set up the bounds.
|
||||||
|
SmallVector<int32_t> lbGroups, ubGroups;
|
||||||
|
AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
|
||||||
|
AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
|
||||||
result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
|
result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
|
||||||
|
result.addAttribute(getLowerBoundsGroupsAttrName(),
|
||||||
|
builder.getI32VectorAttr(lbGroups));
|
||||||
result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
|
result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
|
||||||
|
result.addAttribute(getUpperBoundsGroupsAttrName(),
|
||||||
|
builder.getI32VectorAttr(ubGroups));
|
||||||
result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps));
|
result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps));
|
||||||
result.addOperands(lbArgs);
|
result.addOperands(lbArgs);
|
||||||
result.addOperands(ubArgs);
|
result.addOperands(ubArgs);
|
||||||
|
|
||||||
// Create a region and a block for the body.
|
// Create a region and a block for the body.
|
||||||
auto *bodyRegion = result.addRegion();
|
auto *bodyRegion = result.addRegion();
|
||||||
auto *body = new Block();
|
auto *body = new Block();
|
||||||
// Add all the block arguments.
|
// Add all the block arguments.
|
||||||
for (unsigned i = 0; i < numDims; ++i)
|
for (unsigned i = 0, e = steps.size(); i < e; ++i)
|
||||||
body->addArgument(IndexType::get(builder.getContext()));
|
body->addArgument(IndexType::get(builder.getContext()));
|
||||||
bodyRegion->push_back(body);
|
bodyRegion->push_back(body);
|
||||||
if (resultTypes.empty())
|
if (resultTypes.empty())
|
||||||
@ -2688,6 +2715,22 @@ AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
|
|||||||
return getOperands().drop_front(lowerBoundsMap().getNumInputs());
|
return getOperands().drop_front(lowerBoundsMap().getNumInputs());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
|
||||||
|
unsigned start = 0;
|
||||||
|
for (unsigned i = 0; i < pos; ++i)
|
||||||
|
start += lowerBoundsGroups().getValue<int32_t>(i);
|
||||||
|
return lowerBoundsMap().getSliceMap(
|
||||||
|
start, lowerBoundsGroups().getValue<int32_t>(pos));
|
||||||
|
}
|
||||||
|
|
||||||
|
AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
|
||||||
|
unsigned start = 0;
|
||||||
|
for (unsigned i = 0; i < pos; ++i)
|
||||||
|
start += upperBoundsGroups().getValue<int32_t>(i);
|
||||||
|
return upperBoundsMap().getSliceMap(
|
||||||
|
start, upperBoundsGroups().getValue<int32_t>(pos));
|
||||||
|
}
|
||||||
|
|
||||||
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
|
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
|
||||||
return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands());
|
return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands());
|
||||||
}
|
}
|
||||||
@ -2696,17 +2739,15 @@ AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
|
|||||||
return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands());
|
return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands());
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineValueMap AffineParallelOp::getRangesValueMap() {
|
|
||||||
AffineValueMap out;
|
|
||||||
AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
|
|
||||||
&out);
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
|
Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
|
||||||
|
if (hasMinMaxBounds())
|
||||||
|
return llvm::None;
|
||||||
|
|
||||||
// Try to convert all the ranges to constant expressions.
|
// Try to convert all the ranges to constant expressions.
|
||||||
SmallVector<int64_t, 8> out;
|
SmallVector<int64_t, 8> out;
|
||||||
AffineValueMap rangesValueMap = getRangesValueMap();
|
AffineValueMap rangesValueMap;
|
||||||
|
AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
|
||||||
|
&rangesValueMap);
|
||||||
out.reserve(rangesValueMap.getNumResults());
|
out.reserve(rangesValueMap.getNumResults());
|
||||||
for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
|
for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
|
||||||
auto expr = rangesValueMap.getResult(i);
|
auto expr = rangesValueMap.getResult(i);
|
||||||
@ -2780,12 +2821,32 @@ void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
|
|||||||
|
|
||||||
static LogicalResult verify(AffineParallelOp op) {
|
static LogicalResult verify(AffineParallelOp op) {
|
||||||
auto numDims = op.getNumDims();
|
auto numDims = op.getNumDims();
|
||||||
if (op.lowerBoundsMap().getNumResults() != numDims ||
|
if (op.lowerBoundsGroups().getNumElements() != numDims ||
|
||||||
op.upperBoundsMap().getNumResults() != numDims ||
|
op.upperBoundsGroups().getNumElements() != numDims ||
|
||||||
op.steps().size() != numDims ||
|
op.steps().size() != numDims ||
|
||||||
op.getBody()->getNumArguments() != numDims)
|
op.getBody()->getNumArguments() != numDims) {
|
||||||
return op.emitOpError("region argument count and num results of upper "
|
return op.emitOpError()
|
||||||
"bounds, lower bounds, and steps must all match");
|
<< "the number of region arguments ("
|
||||||
|
<< op.getBody()->getNumArguments()
|
||||||
|
<< ") and the number of map groups for lower ("
|
||||||
|
<< op.lowerBoundsGroups().getNumElements() << ") and upper bound ("
|
||||||
|
<< op.upperBoundsGroups().getNumElements()
|
||||||
|
<< "), and the number of steps (" << op.steps().size()
|
||||||
|
<< ") must all match";
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned expectedNumLBResults = 0;
|
||||||
|
for (APInt v : op.lowerBoundsGroups())
|
||||||
|
expectedNumLBResults += v.getZExtValue();
|
||||||
|
if (expectedNumLBResults != op.lowerBoundsMap().getNumResults())
|
||||||
|
return op.emitOpError() << "expected lower bounds map to have "
|
||||||
|
<< expectedNumLBResults << " results";
|
||||||
|
unsigned expectedNumUBResults = 0;
|
||||||
|
for (APInt v : op.upperBoundsGroups())
|
||||||
|
expectedNumUBResults += v.getZExtValue();
|
||||||
|
if (expectedNumUBResults != op.upperBoundsMap().getNumResults())
|
||||||
|
return op.emitOpError() << "expected upper bounds map to have "
|
||||||
|
<< expectedNumUBResults << " results";
|
||||||
|
|
||||||
if (op.reductions().size() != op.getNumResults())
|
if (op.reductions().size() != op.getNumResults())
|
||||||
return op.emitOpError("a reduction must be specified for each output");
|
return op.emitOpError("a reduction must be specified for each output");
|
||||||
@ -2844,13 +2905,44 @@ LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
|
|||||||
return canonicalizeLoopBounds(*this);
|
return canonicalizeLoopBounds(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Prints a lower(upper) bound of an affine parallel loop with max(min)
|
||||||
|
/// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
|
||||||
|
/// identifies which of the those expressions form max/min groups. `operands`
|
||||||
|
/// are the SSA values of dimensions and symbols and `keyword` is either "min"
|
||||||
|
/// or "max".
|
||||||
|
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
|
||||||
|
DenseIntElementsAttr group, ValueRange operands,
|
||||||
|
StringRef keyword) {
|
||||||
|
AffineMap map = mapAttr.getValue();
|
||||||
|
unsigned numDims = map.getNumDims();
|
||||||
|
ValueRange dimOperands = operands.take_front(numDims);
|
||||||
|
ValueRange symOperands = operands.drop_front(numDims);
|
||||||
|
unsigned start = 0;
|
||||||
|
for (llvm::APInt groupSize : group) {
|
||||||
|
if (start != 0)
|
||||||
|
p << ", ";
|
||||||
|
|
||||||
|
unsigned size = groupSize.getZExtValue();
|
||||||
|
if (size == 1) {
|
||||||
|
p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
|
||||||
|
++start;
|
||||||
|
} else {
|
||||||
|
p << keyword << '(';
|
||||||
|
AffineMap submap = map.getSliceMap(start, size);
|
||||||
|
p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
|
||||||
|
p << ')';
|
||||||
|
start += size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, AffineParallelOp op) {
|
static void print(OpAsmPrinter &p, AffineParallelOp op) {
|
||||||
p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
|
p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
|
||||||
p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(),
|
printMinMaxBound(p, op.lowerBoundsMapAttr(), op.lowerBoundsGroupsAttr(),
|
||||||
op.getLowerBoundsOperands());
|
op.getLowerBoundsOperands(), "max");
|
||||||
p << ") to (";
|
p << ") to (";
|
||||||
p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(),
|
printMinMaxBound(p, op.upperBoundsMapAttr(), op.upperBoundsGroupsAttr(),
|
||||||
op.getUpperBoundsOperands());
|
op.getUpperBoundsOperands(), "min");
|
||||||
p << ')';
|
p << ')';
|
||||||
SmallVector<int64_t, 8> steps = op.getSteps();
|
SmallVector<int64_t, 8> steps = op.getSteps();
|
||||||
bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
|
bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
|
||||||
@ -2875,39 +2967,171 @@ static void print(OpAsmPrinter &p, AffineParallelOp op) {
|
|||||||
op->getAttrs(),
|
op->getAttrs(),
|
||||||
/*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(),
|
/*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(),
|
||||||
AffineParallelOp::getLowerBoundsMapAttrName(),
|
AffineParallelOp::getLowerBoundsMapAttrName(),
|
||||||
|
AffineParallelOp::getLowerBoundsGroupsAttrName(),
|
||||||
AffineParallelOp::getUpperBoundsMapAttrName(),
|
AffineParallelOp::getUpperBoundsMapAttrName(),
|
||||||
|
AffineParallelOp::getUpperBoundsGroupsAttrName(),
|
||||||
AffineParallelOp::getStepsAttrName()});
|
AffineParallelOp::getStepsAttrName()});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Given a list of lists of parsed operands, populates `uniqueOperands` with
|
||||||
|
/// unique operands. Also populates `replacements with affine expressions of
|
||||||
|
/// `kind` that can be used to update affine maps previously accepting a
|
||||||
|
/// `operands` to accept `uniqueOperands` instead.
|
||||||
|
static void deduplicateAndResolveOperands(
|
||||||
|
OpAsmParser &parser,
|
||||||
|
ArrayRef<SmallVector<OpAsmParser::OperandType>> operands,
|
||||||
|
SmallVectorImpl<Value> &uniqueOperands,
|
||||||
|
SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
|
||||||
|
assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
|
||||||
|
"expected operands to be dim or symbol expression");
|
||||||
|
|
||||||
|
Type indexType = parser.getBuilder().getIndexType();
|
||||||
|
for (const auto &list : operands) {
|
||||||
|
SmallVector<Value> valueOperands;
|
||||||
|
parser.resolveOperands(list, indexType, valueOperands);
|
||||||
|
for (Value operand : valueOperands) {
|
||||||
|
unsigned pos = std::distance(uniqueOperands.begin(),
|
||||||
|
llvm::find(uniqueOperands, operand));
|
||||||
|
if (pos == uniqueOperands.size())
|
||||||
|
uniqueOperands.push_back(operand);
|
||||||
|
replacements.push_back(
|
||||||
|
kind == AffineExprKind::DimId
|
||||||
|
? getAffineDimExpr(pos, parser.getBuilder().getContext())
|
||||||
|
: getAffineSymbolExpr(pos, parser.getBuilder().getContext()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
enum class MinMaxKind { Min, Max };
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/// Parses an affine map that can contain a min/max for groups of its results,
|
||||||
|
/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
|
||||||
|
/// `result` attributes with the map (flat list of expressions) and the grouping
|
||||||
|
/// (list of integers that specify how many expressions to put into each
|
||||||
|
/// min/max) attributes. Deduplicates repeated operands.
|
||||||
|
///
|
||||||
|
/// parallel-bound ::= `(` parallel-group-list `)`
|
||||||
|
/// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
|
||||||
|
/// parallel-group ::= simple-group | min-max-group
|
||||||
|
/// simple-group ::= expr-of-ssa-ids
|
||||||
|
/// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
|
||||||
|
/// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
|
||||||
|
///
|
||||||
|
/// Examples:
|
||||||
|
/// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
|
||||||
|
/// (%0, max(%1 - 2 * %2))
|
||||||
|
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
|
||||||
|
OperationState &result,
|
||||||
|
MinMaxKind kind) {
|
||||||
|
constexpr llvm::StringLiteral tmpAttrName = "__pseudo_bound_map";
|
||||||
|
|
||||||
|
StringRef mapName = kind == MinMaxKind::Min
|
||||||
|
? AffineParallelOp::getUpperBoundsMapAttrName()
|
||||||
|
: AffineParallelOp::getLowerBoundsMapAttrName();
|
||||||
|
StringRef groupsName = kind == MinMaxKind::Min
|
||||||
|
? AffineParallelOp::getUpperBoundsGroupsAttrName()
|
||||||
|
: AffineParallelOp::getLowerBoundsGroupsAttrName();
|
||||||
|
|
||||||
|
if (failed(parser.parseLParen()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalRParen())) {
|
||||||
|
result.addAttribute(
|
||||||
|
mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
|
||||||
|
result.addAttribute(groupsName, parser.getBuilder().getI32VectorAttr({}));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<AffineExpr> flatExprs;
|
||||||
|
SmallVector<SmallVector<OpAsmParser::OperandType>> flatDimOperands;
|
||||||
|
SmallVector<SmallVector<OpAsmParser::OperandType>> flatSymOperands;
|
||||||
|
SmallVector<int32_t> numMapsPerGroup;
|
||||||
|
SmallVector<OpAsmParser::OperandType> mapOperands;
|
||||||
|
do {
|
||||||
|
if (succeeded(parser.parseOptionalKeyword(
|
||||||
|
kind == MinMaxKind::Min ? "min" : "max"))) {
|
||||||
|
mapOperands.clear();
|
||||||
|
AffineMapAttr map;
|
||||||
|
if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrName,
|
||||||
|
result.attributes,
|
||||||
|
OpAsmParser::Delimiter::Paren)))
|
||||||
|
return failure();
|
||||||
|
result.attributes.erase(tmpAttrName);
|
||||||
|
llvm::append_range(flatExprs, map.getValue().getResults());
|
||||||
|
auto operandsRef = llvm::makeArrayRef(mapOperands);
|
||||||
|
auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
|
||||||
|
SmallVector<OpAsmParser::OperandType> dims(dimsRef.begin(),
|
||||||
|
dimsRef.end());
|
||||||
|
auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
|
||||||
|
SmallVector<OpAsmParser::OperandType> syms(symsRef.begin(),
|
||||||
|
symsRef.end());
|
||||||
|
flatDimOperands.append(map.getValue().getNumResults(), dims);
|
||||||
|
flatSymOperands.append(map.getValue().getNumResults(), syms);
|
||||||
|
numMapsPerGroup.push_back(map.getValue().getNumResults());
|
||||||
|
} else {
|
||||||
|
if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
|
||||||
|
flatSymOperands.emplace_back(),
|
||||||
|
flatExprs.emplace_back())))
|
||||||
|
return failure();
|
||||||
|
numMapsPerGroup.push_back(1);
|
||||||
|
}
|
||||||
|
} while (succeeded(parser.parseOptionalComma()));
|
||||||
|
|
||||||
|
if (failed(parser.parseRParen()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
unsigned totalNumDims = 0;
|
||||||
|
unsigned totalNumSyms = 0;
|
||||||
|
for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
|
||||||
|
unsigned numDims = flatDimOperands[i].size();
|
||||||
|
unsigned numSyms = flatSymOperands[i].size();
|
||||||
|
flatExprs[i] = flatExprs[i]
|
||||||
|
.shiftDims(numDims, totalNumDims)
|
||||||
|
.shiftSymbols(numSyms, totalNumSyms);
|
||||||
|
totalNumDims += numDims;
|
||||||
|
totalNumSyms += numSyms;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deduplicate map operands.
|
||||||
|
SmallVector<Value> dimOperands, symOperands;
|
||||||
|
SmallVector<AffineExpr> dimRplacements, symRepacements;
|
||||||
|
deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
|
||||||
|
dimRplacements, AffineExprKind::DimId);
|
||||||
|
deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
|
||||||
|
symRepacements, AffineExprKind::SymbolId);
|
||||||
|
|
||||||
|
result.operands.append(dimOperands.begin(), dimOperands.end());
|
||||||
|
result.operands.append(symOperands.begin(), symOperands.end());
|
||||||
|
|
||||||
|
Builder &builder = parser.getBuilder();
|
||||||
|
auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
|
||||||
|
parser.getBuilder().getContext());
|
||||||
|
flatMap = flatMap.replaceDimsAndSymbols(
|
||||||
|
dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
|
||||||
|
|
||||||
|
result.addAttribute(mapName, AffineMapAttr::get(flatMap));
|
||||||
|
result.addAttribute(groupsName, builder.getI32VectorAttr(numMapsPerGroup));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)`
|
// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
|
||||||
// `to` `(` map-of-ssa-ids `)` steps? region attr-dict?
|
// `to` parallel-bound steps? region attr-dict?
|
||||||
// steps ::= `steps` `(` integer-literals `)`
|
// steps ::= `steps` `(` integer-literals `)`
|
||||||
//
|
//
|
||||||
static ParseResult parseAffineParallelOp(OpAsmParser &parser,
|
static ParseResult parseAffineParallelOp(OpAsmParser &parser,
|
||||||
OperationState &result) {
|
OperationState &result) {
|
||||||
auto &builder = parser.getBuilder();
|
auto &builder = parser.getBuilder();
|
||||||
auto indexType = builder.getIndexType();
|
auto indexType = builder.getIndexType();
|
||||||
AffineMapAttr lowerBoundsAttr, upperBoundsAttr;
|
|
||||||
SmallVector<OpAsmParser::OperandType, 4> ivs;
|
SmallVector<OpAsmParser::OperandType, 4> ivs;
|
||||||
SmallVector<OpAsmParser::OperandType, 4> lowerBoundsMapOperands;
|
|
||||||
SmallVector<OpAsmParser::OperandType, 4> upperBoundsMapOperands;
|
|
||||||
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
|
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
|
||||||
OpAsmParser::Delimiter::Paren) ||
|
OpAsmParser::Delimiter::Paren) ||
|
||||||
parser.parseEqual() ||
|
parser.parseEqual() ||
|
||||||
parser.parseAffineMapOfSSAIds(
|
parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
|
||||||
lowerBoundsMapOperands, lowerBoundsAttr,
|
|
||||||
AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes,
|
|
||||||
OpAsmParser::Delimiter::Paren) ||
|
|
||||||
parser.resolveOperands(lowerBoundsMapOperands, indexType,
|
|
||||||
result.operands) ||
|
|
||||||
parser.parseKeyword("to") ||
|
parser.parseKeyword("to") ||
|
||||||
parser.parseAffineMapOfSSAIds(
|
parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
|
||||||
upperBoundsMapOperands, upperBoundsAttr,
|
|
||||||
AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes,
|
|
||||||
OpAsmParser::Delimiter::Paren) ||
|
|
||||||
parser.resolveOperands(upperBoundsMapOperands, indexType,
|
|
||||||
result.operands))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
AffineMapAttr stepsMapAttr;
|
AffineMapAttr stepsMapAttr;
|
||||||
|
@ -21,6 +21,10 @@
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
void mlir::normalizeAffineParallel(AffineParallelOp op) {
|
void mlir::normalizeAffineParallel(AffineParallelOp op) {
|
||||||
|
// Loops with min/max in bounds are not normalized at the moment.
|
||||||
|
if (op.hasMinMaxBounds())
|
||||||
|
return;
|
||||||
|
|
||||||
AffineMap lbMap = op.lowerBoundsMap();
|
AffineMap lbMap = op.lowerBoundsMap();
|
||||||
SmallVector<int64_t, 8> steps = op.getSteps();
|
SmallVector<int64_t, 8> steps = op.getSteps();
|
||||||
// No need to do any work if the parallel op is already normalized.
|
// No need to do any work if the parallel op is already normalized.
|
||||||
@ -34,7 +38,9 @@ void mlir::normalizeAffineParallel(AffineParallelOp op) {
|
|||||||
if (isAlreadyNormalized)
|
if (isAlreadyNormalized)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
AffineValueMap ranges = op.getRangesValueMap();
|
AffineValueMap ranges;
|
||||||
|
AffineValueMap::difference(op.getUpperBoundsValueMap(),
|
||||||
|
op.getLowerBoundsValueMap(), &ranges);
|
||||||
auto builder = OpBuilder::atBlockBegin(op.getBody());
|
auto builder = OpBuilder::atBlockBegin(op.getBody());
|
||||||
auto zeroExpr = builder.getAffineConstantExpr(0);
|
auto zeroExpr = builder.getAffineConstantExpr(0);
|
||||||
SmallVector<AffineExpr, 8> lbExprs;
|
SmallVector<AffineExpr, 8> lbExprs;
|
||||||
|
@ -145,47 +145,21 @@ mlir::affineParallelize(AffineForOp forOp,
|
|||||||
|
|
||||||
Location loc = forOp.getLoc();
|
Location loc = forOp.getLoc();
|
||||||
OpBuilder outsideBuilder(forOp);
|
OpBuilder outsideBuilder(forOp);
|
||||||
|
|
||||||
// If a loop has a 'max' in the lower bound, emit it outside the parallel loop
|
|
||||||
// as it does not have implicit 'max' behavior.
|
|
||||||
AffineMap lowerBoundMap = forOp.getLowerBoundMap();
|
AffineMap lowerBoundMap = forOp.getLowerBoundMap();
|
||||||
ValueRange lowerBoundOperands = forOp.getLowerBoundOperands();
|
ValueRange lowerBoundOperands = forOp.getLowerBoundOperands();
|
||||||
AffineMap upperBoundMap = forOp.getUpperBoundMap();
|
AffineMap upperBoundMap = forOp.getUpperBoundMap();
|
||||||
ValueRange upperBoundOperands = forOp.getUpperBoundOperands();
|
ValueRange upperBoundOperands = forOp.getUpperBoundOperands();
|
||||||
|
|
||||||
bool needsMax = lowerBoundMap.getNumResults() > 1;
|
|
||||||
bool needsMin = upperBoundMap.getNumResults() > 1;
|
|
||||||
AffineMap identityMap;
|
|
||||||
if (needsMax || needsMin) {
|
|
||||||
if (forOp->getParentOp() &&
|
|
||||||
!forOp->getParentOp()->hasTrait<OpTrait::AffineScope>())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
identityMap = AffineMap::getMultiDimIdentityMap(1, loc->getContext());
|
|
||||||
}
|
|
||||||
if (needsMax) {
|
|
||||||
auto maxOp = outsideBuilder.create<AffineMaxOp>(loc, lowerBoundMap,
|
|
||||||
lowerBoundOperands);
|
|
||||||
lowerBoundMap = identityMap;
|
|
||||||
lowerBoundOperands = maxOp->getResults();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Same for the upper bound.
|
|
||||||
if (needsMin) {
|
|
||||||
auto minOp = outsideBuilder.create<AffineMinOp>(loc, upperBoundMap,
|
|
||||||
upperBoundOperands);
|
|
||||||
upperBoundMap = identityMap;
|
|
||||||
upperBoundOperands = minOp->getResults();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creating empty 1-D affine.parallel op.
|
// Creating empty 1-D affine.parallel op.
|
||||||
auto reducedValues = llvm::to_vector<4>(llvm::map_range(
|
auto reducedValues = llvm::to_vector<4>(llvm::map_range(
|
||||||
parallelReductions, [](const LoopReduction &red) { return red.value; }));
|
parallelReductions, [](const LoopReduction &red) { return red.value; }));
|
||||||
auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
|
auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
|
||||||
parallelReductions, [](const LoopReduction &red) { return red.kind; }));
|
parallelReductions, [](const LoopReduction &red) { return red.kind; }));
|
||||||
AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
|
AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
|
||||||
loc, ValueRange(reducedValues).getTypes(), reductionKinds, lowerBoundMap,
|
loc, ValueRange(reducedValues).getTypes(), reductionKinds,
|
||||||
lowerBoundOperands, upperBoundMap, upperBoundOperands);
|
llvm::makeArrayRef(lowerBoundMap), lowerBoundOperands,
|
||||||
|
llvm::makeArrayRef(upperBoundMap), upperBoundOperands,
|
||||||
|
llvm::makeArrayRef(forOp.getStep()));
|
||||||
// Steal the body of the old affine for op.
|
// Steal the body of the old affine for op.
|
||||||
newPloop.region().takeBody(forOp.region());
|
newPloop.region().takeBody(forOp.region());
|
||||||
Operation *yieldOp = &newPloop.getBody()->back();
|
Operation *yieldOp = &newPloop.getBody()->back();
|
||||||
|
@ -494,6 +494,11 @@ AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const {
|
|||||||
return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
|
return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AffineMap AffineMap::getSliceMap(unsigned start, unsigned length) const {
|
||||||
|
return AffineMap::get(getNumDims(), getNumSymbols(),
|
||||||
|
getResults().slice(start, length), getContext());
|
||||||
|
}
|
||||||
|
|
||||||
AffineMap AffineMap::getMajorSubMap(unsigned numResults) const {
|
AffineMap AffineMap::getMajorSubMap(unsigned numResults) const {
|
||||||
if (numResults == 0)
|
if (numResults == 0)
|
||||||
return AffineMap();
|
return AffineMap();
|
||||||
|
@ -469,6 +469,7 @@ private:
|
|||||||
/// The following are hooks of `OpAsmPrinter` that are not necessary for
|
/// The following are hooks of `OpAsmPrinter` that are not necessary for
|
||||||
/// determining potential aliases.
|
/// determining potential aliases.
|
||||||
void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
|
void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
|
||||||
|
void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
|
||||||
void printNewline() override {}
|
void printNewline() override {}
|
||||||
void printOperand(Value) override {}
|
void printOperand(Value) override {}
|
||||||
void printOperand(Value, raw_ostream &os) override {
|
void printOperand(Value, raw_ostream &os) override {
|
||||||
@ -2351,6 +2352,11 @@ public:
|
|||||||
void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
|
void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
|
||||||
ValueRange operands) override;
|
ValueRange operands) override;
|
||||||
|
|
||||||
|
/// Print the given affine expression with the symbol and dimension operands
|
||||||
|
/// printed inline with the expression.
|
||||||
|
void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
|
||||||
|
ValueRange symOperands) override;
|
||||||
|
|
||||||
/// Print the given string as a symbol reference.
|
/// Print the given string as a symbol reference.
|
||||||
void printSymbolName(StringRef symbolRef) override {
|
void printSymbolName(StringRef symbolRef) override {
|
||||||
::printSymbolReference(symbolRef, os);
|
::printSymbolReference(symbolRef, os);
|
||||||
@ -2590,6 +2596,19 @@ void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
|
||||||
|
ValueRange dimOperands,
|
||||||
|
ValueRange symOperands) {
|
||||||
|
auto printValueName = [&](unsigned pos, bool isSymbol) {
|
||||||
|
if (!isSymbol)
|
||||||
|
return printValueID(dimOperands[pos]);
|
||||||
|
os << "symbol(";
|
||||||
|
printValueID(symOperands[pos]);
|
||||||
|
os << ')';
|
||||||
|
};
|
||||||
|
printAffineExpr(expr, printValueName);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// print and dump methods
|
// print and dump methods
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -55,6 +55,7 @@ public:
|
|||||||
IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols);
|
IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols);
|
||||||
ParseResult parseAffineMapOfSSAIds(AffineMap &map,
|
ParseResult parseAffineMapOfSSAIds(AffineMap &map,
|
||||||
OpAsmParser::Delimiter delimiter);
|
OpAsmParser::Delimiter delimiter);
|
||||||
|
ParseResult parseAffineExprOfSSAIds(AffineExpr &expr);
|
||||||
void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds,
|
void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds,
|
||||||
unsigned &numDims);
|
unsigned &numDims);
|
||||||
|
|
||||||
@ -579,6 +580,12 @@ AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parse an AffineExpr where the dim and symbol identifiers are SSA ids.
|
||||||
|
ParseResult AffineParser::parseAffineExprOfSSAIds(AffineExpr &expr) {
|
||||||
|
expr = parseAffineExpr();
|
||||||
|
return success(expr != nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
/// Parse the range and sizes affine map definition inline.
|
/// Parse the range and sizes affine map definition inline.
|
||||||
///
|
///
|
||||||
/// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
|
/// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
|
||||||
@ -724,3 +731,12 @@ Parser::parseAffineMapOfSSAIds(AffineMap &map,
|
|||||||
return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
|
return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
|
||||||
.parseAffineMapOfSSAIds(map, delimiter);
|
.parseAffineMapOfSSAIds(map, delimiter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parse an AffineExpr of SSA ids. The callback `parseElement` is used to parse
|
||||||
|
/// SSA value uses encountered while parsing.
|
||||||
|
ParseResult
|
||||||
|
Parser::parseAffineExprOfSSAIds(AffineExpr &expr,
|
||||||
|
function_ref<ParseResult(bool)> parseElement) {
|
||||||
|
return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
|
||||||
|
.parseAffineExprOfSSAIds(expr);
|
||||||
|
}
|
||||||
|
@ -1513,6 +1513,25 @@ public:
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parse an AffineExpr of SSA ids.
|
||||||
|
ParseResult
|
||||||
|
parseAffineExprOfSSAIds(SmallVectorImpl<OperandType> &dimOperands,
|
||||||
|
SmallVectorImpl<OperandType> &symbOperands,
|
||||||
|
AffineExpr &expr) override {
|
||||||
|
auto parseElement = [&](bool isSymbol) -> ParseResult {
|
||||||
|
OperandType operand;
|
||||||
|
if (parseOperand(operand))
|
||||||
|
return failure();
|
||||||
|
if (isSymbol)
|
||||||
|
symbOperands.push_back(operand);
|
||||||
|
else
|
||||||
|
dimOperands.push_back(operand);
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
return parser.parseAffineExprOfSSAIds(expr, parseElement);
|
||||||
|
}
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Region Parsing
|
// Region Parsing
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
@ -268,6 +268,11 @@ public:
|
|||||||
function_ref<ParseResult(bool)> parseElement,
|
function_ref<ParseResult(bool)> parseElement,
|
||||||
OpAsmParser::Delimiter delimiter);
|
OpAsmParser::Delimiter delimiter);
|
||||||
|
|
||||||
|
/// Parse an AffineExpr where dim and symbol identifiers are SSA ids.
|
||||||
|
ParseResult
|
||||||
|
parseAffineExprOfSSAIds(AffineExpr &expr,
|
||||||
|
function_ref<ParseResult(bool)> parseElement);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// The Parser is subclassed and reinstantiated. Do not add additional
|
/// The Parser is subclassed and reinstantiated. Do not add additional
|
||||||
/// non-trivial state here, add it to the ParserState class.
|
/// non-trivial state here, add it to the ParserState class.
|
||||||
|
@ -740,8 +740,8 @@ func @affine_parallel_simple(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) ->
|
|||||||
}
|
}
|
||||||
// CHECK-LABEL: func @affine_parallel_simple
|
// CHECK-LABEL: func @affine_parallel_simple
|
||||||
// CHECK: %[[LOWER_1:.*]] = constant 0 : index
|
// CHECK: %[[LOWER_1:.*]] = constant 0 : index
|
||||||
// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index
|
|
||||||
// CHECK-NEXT: %[[UPPER_1:.*]] = constant 2 : index
|
// CHECK-NEXT: %[[UPPER_1:.*]] = constant 2 : index
|
||||||
|
// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index
|
||||||
// CHECK-NEXT: %[[UPPER_2:.*]] = constant 2 : index
|
// CHECK-NEXT: %[[UPPER_2:.*]] = constant 2 : index
|
||||||
// CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index
|
// CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index
|
||||||
// CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index
|
// CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index
|
||||||
@ -800,8 +800,8 @@ func @affine_parallel_with_reductions(%arg0: memref<3x3xf32>, %arg1: memref<3x3x
|
|||||||
}
|
}
|
||||||
// CHECK-LABEL: func @affine_parallel_with_reductions
|
// CHECK-LABEL: func @affine_parallel_with_reductions
|
||||||
// CHECK: %[[LOWER_1:.*]] = constant 0 : index
|
// CHECK: %[[LOWER_1:.*]] = constant 0 : index
|
||||||
// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index
|
|
||||||
// CHECK-NEXT: %[[UPPER_1:.*]] = constant 2 : index
|
// CHECK-NEXT: %[[UPPER_1:.*]] = constant 2 : index
|
||||||
|
// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index
|
||||||
// CHECK-NEXT: %[[UPPER_2:.*]] = constant 2 : index
|
// CHECK-NEXT: %[[UPPER_2:.*]] = constant 2 : index
|
||||||
// CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index
|
// CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index
|
||||||
// CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index
|
// CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index
|
||||||
@ -841,8 +841,8 @@ func @affine_parallel_with_reductions_f64(%arg0: memref<3x3xf64>, %arg1: memref<
|
|||||||
}
|
}
|
||||||
// CHECK-LABEL: @affine_parallel_with_reductions_f64
|
// CHECK-LABEL: @affine_parallel_with_reductions_f64
|
||||||
// CHECK: %[[LOWER_1:.*]] = constant 0 : index
|
// CHECK: %[[LOWER_1:.*]] = constant 0 : index
|
||||||
// CHECK: %[[LOWER_2:.*]] = constant 0 : index
|
|
||||||
// CHECK: %[[UPPER_1:.*]] = constant 2 : index
|
// CHECK: %[[UPPER_1:.*]] = constant 2 : index
|
||||||
|
// CHECK: %[[LOWER_2:.*]] = constant 0 : index
|
||||||
// CHECK: %[[UPPER_2:.*]] = constant 2 : index
|
// CHECK: %[[UPPER_2:.*]] = constant 2 : index
|
||||||
// CHECK: %[[STEP_1:.*]] = constant 1 : index
|
// CHECK: %[[STEP_1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[STEP_2:.*]] = constant 1 : index
|
// CHECK: %[[STEP_2:.*]] = constant 1 : index
|
||||||
@ -880,8 +880,8 @@ func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: memref<
|
|||||||
}
|
}
|
||||||
// CHECK-LABEL: @affine_parallel_with_reductions_i64
|
// CHECK-LABEL: @affine_parallel_with_reductions_i64
|
||||||
// CHECK: %[[LOWER_1:.*]] = constant 0 : index
|
// CHECK: %[[LOWER_1:.*]] = constant 0 : index
|
||||||
// CHECK: %[[LOWER_2:.*]] = constant 0 : index
|
|
||||||
// CHECK: %[[UPPER_1:.*]] = constant 2 : index
|
// CHECK: %[[UPPER_1:.*]] = constant 2 : index
|
||||||
|
// CHECK: %[[LOWER_2:.*]] = constant 0 : index
|
||||||
// CHECK: %[[UPPER_2:.*]] = constant 2 : index
|
// CHECK: %[[UPPER_2:.*]] = constant 2 : index
|
||||||
// CHECK: %[[STEP_1:.*]] = constant 1 : index
|
// CHECK: %[[STEP_1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[STEP_2:.*]] = constant 1 : index
|
// CHECK: %[[STEP_2:.*]] = constant 1 : index
|
||||||
|
@ -197,7 +197,7 @@ func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
|
func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||||
// expected-error@+1 {{region argument count and num results of upper bounds, lower bounds, and steps must all match}}
|
// expected-error@+1 {{the number of region arguments (1) and the number of map groups for lower (2) and upper bound (2), and the number of steps (2) must all match}}
|
||||||
affine.parallel (%i) = (0, 0) to (100, 100) step (10, 10) {
|
affine.parallel (%i) = (0, 0) to (100, 100) step (10, 10) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -205,7 +205,7 @@ func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
|
func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||||
// expected-error@+1 {{region argument count and num results of upper bounds, lower bounds, and steps must all match}}
|
// expected-error@+1 {{the number of region arguments (2) and the number of map groups for lower (1) and upper bound (2), and the number of steps (2) must all match}}
|
||||||
affine.parallel (%i, %j) = (0) to (100, 100) step (10, 10) {
|
affine.parallel (%i, %j) = (0) to (100, 100) step (10, 10) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -213,7 +213,7 @@ func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
|
func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||||
// expected-error@+1 {{region argument count and num results of upper bounds, lower bounds, and steps must all match}}
|
// expected-error@+1 {{the number of region arguments (2) and the number of map groups for lower (2) and upper bound (1), and the number of steps (2) must all match}}
|
||||||
affine.parallel (%i, %j) = (0, 0) to (100) step (10, 10) {
|
affine.parallel (%i, %j) = (0, 0) to (100) step (10, 10) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -221,7 +221,7 @@ func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
|
func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||||
// expected-error@+1 {{region argument count and num results of upper bounds, lower bounds, and steps must all match}}
|
// expected-error@+1 {{the number of region arguments (2) and the number of map groups for lower (2) and upper bound (2), and the number of steps (1) must all match}}
|
||||||
affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10) {
|
affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -169,6 +169,21 @@ func @parallel(%A : memref<100x100xf32>, %N : index) {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @parallel_min_max
|
||||||
|
// CHECK: %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index, %[[D:.*]]: index
|
||||||
|
func @parallel_min_max(%a: index, %b: index, %c: index, %d: index) {
|
||||||
|
// CHECK: affine.parallel (%{{.*}}, %{{.*}}, %{{.*}}) =
|
||||||
|
// CHECK: (max(%[[A]], %[[B]])
|
||||||
|
// CHECK: to (%[[C]], min(%[[C]], %[[D]]), %[[B]])
|
||||||
|
affine.parallel (%i, %j, %k) = (max(%a, %b), %b, max(%a, %c))
|
||||||
|
to (%c, min(%c, %d), %b) {
|
||||||
|
affine.yield
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @affine_if
|
// CHECK-LABEL: func @affine_if
|
||||||
func @affine_if() -> f32 {
|
func @affine_if() -> f32 {
|
||||||
// CHECK: %[[ZERO:.*]] = constant {{.*}} : f32
|
// CHECK: %[[ZERO:.*]] = constant {{.*}} : f32
|
||||||
|
@ -120,9 +120,7 @@ func @non_affine_load() {
|
|||||||
// CHECK-LABEL: for_with_minmax
|
// CHECK-LABEL: for_with_minmax
|
||||||
func @for_with_minmax(%m: memref<?xf32>, %lb0: index, %lb1: index,
|
func @for_with_minmax(%m: memref<?xf32>, %lb0: index, %lb1: index,
|
||||||
%ub0: index, %ub1: index) {
|
%ub0: index, %ub1: index) {
|
||||||
// CHECK: %[[lb:.*]] = affine.max
|
// CHECK: affine.parallel (%{{.*}}) = (max(%{{.*}}, %{{.*}})) to (min(%{{.*}}, %{{.*}}))
|
||||||
// CHECK: %[[ub:.*]] = affine.min
|
|
||||||
// CHECK: affine.parallel (%{{.*}}) = (%[[lb]]) to (%[[ub]])
|
|
||||||
affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %lb1)
|
affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %lb1)
|
||||||
to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) {
|
to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) {
|
||||||
affine.load %m[%i] : memref<?xf32>
|
affine.load %m[%i] : memref<?xf32>
|
||||||
@ -133,12 +131,9 @@ func @for_with_minmax(%m: memref<?xf32>, %lb0: index, %lb1: index,
|
|||||||
// CHECK-LABEL: nested_for_with_minmax
|
// CHECK-LABEL: nested_for_with_minmax
|
||||||
func @nested_for_with_minmax(%m: memref<?xf32>, %lb0: index,
|
func @nested_for_with_minmax(%m: memref<?xf32>, %lb0: index,
|
||||||
%ub0: index, %ub1: index) {
|
%ub0: index, %ub1: index) {
|
||||||
// CHECK: affine.parallel
|
// CHECK: affine.parallel (%[[I:.*]]) =
|
||||||
affine.for %j = 0 to 10 {
|
affine.for %j = 0 to 10 {
|
||||||
// Cannot parallelize the inner loop because we would need to compute
|
// CHECK: affine.parallel (%{{.*}}) = (max(%{{.*}}, %[[I]])) to (min(%{{.*}}, %{{.*}}))
|
||||||
// affine.max for its lower bound inside the loop, and that is not (yet)
|
|
||||||
// considered as a valid affine dimension.
|
|
||||||
// CHECK: affine.for
|
|
||||||
affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %j)
|
affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %j)
|
||||||
to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) {
|
to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) {
|
||||||
affine.load %m[%i] : memref<?xf32>
|
affine.load %m[%i] : memref<?xf32>
|
||||||
@ -236,3 +231,20 @@ func @use_in_backward_slice() {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// REDUCE-LABEL: @nested_min_max
|
||||||
|
// CHECK-LABEL: @nested_min_max
|
||||||
|
// CHECK: (%{{.*}}, %[[LB0:.*]]: index, %[[UB0:.*]]: index, %[[UB1:.*]]: index)
|
||||||
|
func @nested_min_max(%m: memref<?xf32>, %lb0: index,
|
||||||
|
%ub0: index, %ub1: index) {
|
||||||
|
// CHECK: affine.parallel (%[[J:.*]]) =
|
||||||
|
affine.for %j = 0 to 10 {
|
||||||
|
// CHECK: affine.parallel (%{{.*}}) = (max(%[[LB0]], %[[J]]))
|
||||||
|
// CHECK: to (min(%[[UB0]], %[[UB1]]))
|
||||||
|
affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %j)
|
||||||
|
to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) {
|
||||||
|
affine.load %m[%i] : memref<?xf32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user