llvm-capstone/mlir/lib/IR/Statement.cpp
Uday Bondhugula 041817a45e Introduce loop body skewing / loop pipelining / loop shifting utility.
- loopBodySkew shifts statements of a loop body by stmt-wise delays, and is
  typically meant to be used to:
  - allow overlap of non-blocking start/wait until completion operations with
    other computation
  - allow shifting of statements (for better register
    reuse/locality/parallelism)
  - software pipelining (when applied to the innermost loop)
- an additional argument specifies whether to unroll the prologue and epilogue.
- add method to check SSA dominance preservation.
- add a fake loop pipeline pass to test this utility.

Sample inputs/outputs are below. While at it, this change also fixes/adds the following:

- fix minor bug in getAddMulPureAffineExpr
- add additional builder methods for common affine map cases
- fix const_operand_iterator's for ForStmt, etc. Since there is no such thing
  as a 'const MLValue', the iterators shouldn't return const MLValue's;
  returning MLValue is const-correct.

Sample input/output examples:

1) Simplest case: shift the second statement by one.

Input:

for %i = 0 to 7 {
  %y = "foo"(%i) : (affineint) -> affineint
  %x = "bar"(%i) : (affineint) -> affineint
}

Output:

#map0 = (d0) -> (d0 - 1)
mlfunc @loop_nest_simple1() {
  %c8 = constant 8 : affineint
  %c0 = constant 0 : affineint
  %0 = "foo"(%c0) : (affineint) -> affineint
  for %i0 = 1 to 7 {
    %1 = "foo"(%i0) : (affineint) -> affineint
    %2 = affine_apply #map0(%i0)
    %3 = "bar"(%2) : (affineint) -> affineint
  }
  %4 = affine_apply #map0(%c8)
  %5 = "bar"(%4) : (affineint) -> affineint
  return
}

2) DMA overlap: shift dma.wait and compute by one.

Input
  for %i = 0 to 7 {
    %pingpong = affine_apply (d0) -> (d0 mod 2) (%i)
    "dma.enqueue"(%pingpong) : (affineint) -> affineint
    %pongping = affine_apply (d0) -> (d0 mod 2) (%i)
    "dma.wait"(%pongping) : (affineint) -> affineint
    "compute1"(%pongping) : (affineint) -> affineint
  }

Output

#map0 = (d0) -> (d0 mod 2)
#map1 = (d0) -> (d0 - 1)
#map2 = ()[s0] -> (s0 + 7)
mlfunc @loop_nest_dma() {
  %c8 = constant 8 : affineint
  %c0 = constant 0 : affineint
  %0 = affine_apply #map0(%c0)
  %1 = "dma.enqueue"(%0) : (affineint) -> affineint
  for %i0 = 1 to 7 {
    %2 = affine_apply #map0(%i0)
    %3 = "dma.enqueue"(%2) : (affineint) -> affineint
    %4 = affine_apply #map1(%i0)
    %5 = affine_apply #map0(%4)
    %6 = "dma.wait"(%5) : (affineint) -> affineint
    %7 = "compute1"(%5) : (affineint) -> affineint
  }
  %8 = affine_apply #map1(%c8)
  %9 = affine_apply #map0(%8)
  %10 = "dma.wait"(%9) : (affineint) -> affineint
  %11 = "compute1"(%9) : (affineint) -> affineint
  return
}

3) With arbitrary affine bound maps:

Shift the last two statements by two.

Input:

  for %i = %N to ()[s0] -> (s0 + 7)()[%N] {
    %y = "foo"(%i) : (affineint) -> affineint
    %x = "bar"(%i) : (affineint) -> affineint
    %z = "foo_bar"(%i) : (affineint) -> (affineint)
    "bar_foo"(%i) : (affineint) -> (affineint)
  }

Output

#map0 = ()[s0] -> (s0 + 1)
#map1 = ()[s0] -> (s0 + 2)
#map2 = ()[s0] -> (s0 + 7)
#map3 = (d0) -> (d0 - 2)
#map4 = ()[s0] -> (s0 + 8)
#map5 = ()[s0] -> (s0 + 9)

  for %i0 = %arg0 to #map0()[%arg0] {
    %0 = "foo"(%i0) : (affineint) -> affineint
    %1 = "bar"(%i0) : (affineint) -> affineint
  }
  for %i1 = #map1()[%arg0] to #map2()[%arg0] {
    %2 = "foo"(%i1) : (affineint) -> affineint
    %3 = "bar"(%i1) : (affineint) -> affineint
    %4 = affine_apply #map3(%i1)
    %5 = "foo_bar"(%4) : (affineint) -> affineint
    %6 = "bar_foo"(%4) : (affineint) -> affineint
  }
  for %i2 = #map4()[%arg0] to #map5()[%arg0] {
    %7 = affine_apply #map3(%i2)
    %8 = "foo_bar"(%7) : (affineint) -> affineint
    %9 = "bar_foo"(%7) : (affineint) -> affineint
  }

4) Shift the first statement by zero, the second by one, and the third by two.

  for %i = 0 to 7 {
    %y = "foo"(%i) : (affineint) -> affineint
    %x = "bar"(%i) : (affineint) -> affineint
    %z = "foobar"(%i) : (affineint) -> affineint
  }

#map0 = (d0) -> (d0 - 1)
#map1 = (d0) -> (d0 - 2)
#map2 = ()[s0] -> (s0 + 7)

  %c9 = constant 9 : affineint
  %c8 = constant 8 : affineint
  %c1 = constant 1 : affineint
  %c0 = constant 0 : affineint
  %0 = "foo"(%c0) : (affineint) -> affineint
  %1 = "foo"(%c1) : (affineint) -> affineint
  %2 = affine_apply #map0(%c1)
  %3 = "bar"(%2) : (affineint) -> affineint
  for %i0 = 2 to 7 {
    %4 = "foo"(%i0) : (affineint) -> affineint
    %5 = affine_apply #map0(%i0)
    %6 = "bar"(%5) : (affineint) -> affineint
    %7 = affine_apply #map1(%i0)
    %8 = "foobar"(%7) : (affineint) -> affineint
  }
  %9 = affine_apply #map0(%c8)
  %10 = "bar"(%9) : (affineint) -> affineint
  %11 = affine_apply #map1(%c8)
  %12 = "foobar"(%11) : (affineint) -> affineint
  %13 = affine_apply #map1(%c9)
  %14 = "foobar"(%13) : (affineint) -> affineint

5) SSA dominance would be violated, so no shifting is performed if a shift is
specified for the second statement.

  for %i = 0 to 7 {
    %x = "foo"(%i) : (affineint) -> affineint
    "bar"(%x) : (affineint) -> affineint
  }

PiperOrigin-RevId: 214975731
2019-03-29 13:21:26 -07:00

494 lines
17 KiB
C++

//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// StmtResult
//===------------------------------------------------------------------===//
/// Return the result number of this result.
unsigned StmtResult::getResultNumber() const {
// Results are always stored consecutively, so use pointer subtraction to
// figure out what number this is.
return this - &getOwner()->getStmtResults()[0];
}
//===----------------------------------------------------------------------===//
// Statement
//===------------------------------------------------------------------===//
// Statement has no virtual destructor; concrete statements are torn down via
// destroy(), which dispatches on the statement kind.
Statement::~Statement() {
  // A statement must already be unlinked from its block at this point.
  assert(!block && "statement destroyed but still in a block");
}
/// Destroy this statement or one of its subclasses.
void Statement::destroy() {
switch (this->getKind()) {
case Kind::Operation:
cast<OperationStmt>(this)->destroy();
break;
case Kind::For:
delete cast<ForStmt>(this);
break;
case Kind::If:
delete cast<IfStmt>(this);
break;
}
}
/// Return the context this statement is associated with.
///
/// Dispatches to the concrete subclass's getContext().
MLIRContext *Statement::getContext() const {
  // Work a bit to avoid calling findFunction() and getting its context.
  switch (getKind()) {
  case Kind::Operation:
    return cast<OperationStmt>(this)->getContext();
  case Kind::For:
    return cast<ForStmt>(this)->getContext();
  case Kind::If:
    return cast<IfStmt>(this)->getContext();
  }
  // Every Kind is handled above, so this is only reachable with a corrupt
  // kind field. Without an explicit return, control would fall off the end of
  // a non-void function (undefined behavior; also a -Wreturn-type warning).
  assert(false && "unknown statement kind");
  return nullptr;
}
/// Return the statement owning the block that contains this statement, or
/// null if this statement is not in a block.
Statement *Statement::getParentStmt() const {
  if (!block)
    return nullptr;
  return block->getContainingStmt();
}

/// Return the function enclosing this statement, or null if this statement
/// is not in a block.
MLFunction *Statement::findFunction() const {
  if (!block)
    return nullptr;
  return block->findFunction();
}

/// Return the value currently used by operand number `idx`.
MLValue *Statement::getOperand(unsigned idx) {
  auto &&operand = getStmtOperand(idx);
  return operand.get();
}

const MLValue *Statement::getOperand(unsigned idx) const {
  const auto &operand = getStmtOperand(idx);
  return operand.get();
}

/// Make operand number `idx` refer to `value`.
void Statement::setOperand(unsigned idx, MLValue *value) {
  auto &&operand = getStmtOperand(idx);
  operand.set(value);
}
/// Return the number of operands of this statement, dispatching to the
/// concrete subclass.
unsigned Statement::getNumOperands() const {
  switch (getKind()) {
  case Kind::Operation:
    return cast<OperationStmt>(this)->getNumOperands();
  case Kind::For:
    return cast<ForStmt>(this)->getNumOperands();
  case Kind::If:
    return cast<IfStmt>(this)->getNumOperands();
  }
  // Unreachable for a well-formed statement; avoids falling off the end of a
  // non-void function (undefined behavior) if the kind field is corrupt.
  assert(false && "unknown statement kind");
  return 0;
}
/// Return the operand storage of this statement, dispatching to the concrete
/// subclass.
MutableArrayRef<StmtOperand> Statement::getStmtOperands() {
  switch (getKind()) {
  case Kind::Operation:
    return cast<OperationStmt>(this)->getStmtOperands();
  case Kind::For:
    return cast<ForStmt>(this)->getStmtOperands();
  case Kind::If:
    return cast<IfStmt>(this)->getStmtOperands();
  }
  // Unreachable for a well-formed statement; avoids falling off the end of a
  // non-void function (undefined behavior) if the kind field is corrupt.
  assert(false && "unknown statement kind");
  return MutableArrayRef<StmtOperand>();
}
/// Emit a note about this statement, reporting up to any diagnostic
/// handlers that may be listening.
void Statement::emitNote(const Twine &message) const {
getContext()->emitDiagnostic(getLoc(), message,
MLIRContext::DiagnosticKind::Note);
}
/// Emit a warning about this statement, reporting up to any diagnostic
/// handlers that may be listening.
void Statement::emitWarning(const Twine &message) const {
getContext()->emitDiagnostic(getLoc(), message,
MLIRContext::DiagnosticKind::Warning);
}
/// Emit an error about fatal conditions with this statement, reporting up to
/// any diagnostic handlers that may be listening. NOTE: This may terminate
/// the containing application, only use when the IR is in an inconsistent
/// state.
void Statement::emitError(const Twine &message) const {
getContext()->emitDiagnostic(getLoc(), message,
MLIRContext::DiagnosticKind::Error);
}
//===----------------------------------------------------------------------===//
// ilist_traits for Statement
//===----------------------------------------------------------------------===//
/// Recover the StmtBlock that owns this statement list.
///
/// This traits object is a base of the iplist<Statement> stored inside a
/// StmtBlock (hence the static_cast to iplist below). The owning block is
/// found container_of-style: compute the byte offset of the list member within
/// StmtBlock, then subtract it from the list's own address.
StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
  // Byte offset of the statement list within StmtBlock. It is obtained by
  // applying the member pointer from getSublistAccess() to a null StmtBlock.
  // NOTE(review): forming a member reference through a null pointer is
  // formally undefined behavior in standard C++; this relies on compilers
  // treating it as an offsetof computation.
  size_t Offset(
      size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
  // Address of the list object this traits instance is part of.
  iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
  // Step back from the list member to the start of the enclosing StmtBlock.
  return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
                                       Offset);
}
/// ilist callback run when a statement is inserted into a block's statement
/// list. Records the owning block on the statement.
void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
  assert(!stmt->getBlock() && "already in a statement block!");
  StmtBlock *owner = getContainingBlock();
  stmt->block = owner;
}

/// ilist callback run when a statement is unlinked from a block's statement
/// list. Clears the statement's owning-block pointer.
void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
    Statement *stmt) {
  assert(stmt->block && "not already in a statement block!");
  stmt->block = nullptr;
}

/// ilist callback run when the statements in [first, last) are moved into
/// this list from `otherList`. Re-points each moved statement at its new
/// block.
void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
    ilist_traits<Statement> &otherList, stmt_iterator first,
    stmt_iterator last) {
  StmtBlock *destBlock = getContainingBlock();
  // A move within a single block leaves every statement's owner unchanged,
  // so only cross-block moves need updating.
  if (destBlock != otherList.getContainingBlock()) {
    while (first != last) {
      first->block = destBlock;
      ++first;
    }
  }
}
/// Remove this statement (and its descendants) from its StmtBlock and delete
/// all of them.
void Statement::eraseFromBlock() {
assert(getBlock() && "Statement has no block");
getBlock()->getStatements().erase(this);
}
//===----------------------------------------------------------------------===//
// OperationStmt
//===----------------------------------------------------------------------===//
/// Create a new OperationStmt with the specific fields.
///
/// The statement header and its trailing StmtOperand and StmtResult arrays
/// share a single malloc'd allocation; OperationStmt::destroy() releases it
/// with free().
OperationStmt *OperationStmt::create(Location *location, Identifier name,
                                     ArrayRef<MLValue *> operands,
                                     ArrayRef<Type *> resultTypes,
                                     ArrayRef<NamedAttribute> attributes,
                                     MLIRContext *context) {
  const auto operandCount = operands.size();
  const auto resultCount = resultTypes.size();
  void *storage = malloc(
      totalSizeToAlloc<StmtOperand, StmtResult>(operandCount, resultCount));

  // Construct the header first; it records how many trailing objects follow,
  // which the operand/result accessors below depend on.
  OperationStmt *newStmt = ::new (storage) OperationStmt(
      location, name, operandCount, resultCount, attributes, context);

  // Placement-construct every trailing operand, then every trailing result.
  auto operandSlots = newStmt->getStmtOperands();
  for (size_t idx = 0; idx < operandCount; ++idx)
    new (&operandSlots[idx]) StmtOperand(newStmt, operands[idx]);

  auto resultSlots = newStmt->getStmtResults();
  for (size_t idx = 0; idx < resultCount; ++idx)
    new (&resultSlots[idx]) StmtResult(resultTypes[idx], newStmt);

  return newStmt;
}
/// Construct the fixed-size header of an OperationStmt. The trailing operands
/// and results are constructed separately by create().
OperationStmt::OperationStmt(Location *location, Identifier name,
                             unsigned numOperands, unsigned numResults,
                             ArrayRef<NamedAttribute> attributes,
                             MLIRContext *context)
    : Operation(/*isInstruction=*/false, name, attributes, context),
      Statement(Kind::Operation, location),
      numOperands(numOperands),
      numResults(numResults) {}

OperationStmt::~OperationStmt() {
  // The trailing operands and results were placement-constructed by create(),
  // so their destructors have to be invoked by hand.
  auto operandSlots = getStmtOperands();
  for (size_t idx = 0, count = operandSlots.size(); idx != count; ++idx)
    operandSlots[idx].~StmtOperand();
  auto resultSlots = getStmtResults();
  for (size_t idx = 0, count = resultSlots.size(); idx != count; ++idx)
    resultSlots[idx].~StmtResult();
}

/// Run the destructor manually, then return the raw block obtained from
/// malloc() in create().
void OperationStmt::destroy() {
  this->~OperationStmt();
  free(this);
}
/// Return the context this operation is associated with.
MLIRContext *OperationStmt::getContext() const {
  // Any result or operand type gives constant-time access to the context.
  if (getNumResults() != 0)
    return getResult(0)->getType()->getContext();
  if (getNumOperands() != 0)
    return getOperand(0)->getType()->getContext();
  // With neither operands nor results, fall back to locating the enclosing
  // function. NOTE(review): findFunction() returns null for a statement that
  // is not in a block; that case is not handled here.
  return findFunction()->getContext();
}

/// Return true if this statement is a ReturnOp.
bool OperationStmt::isReturn() const { return is<ReturnOp>(); }
//===----------------------------------------------------------------------===//
// ForStmt
//===----------------------------------------------------------------------===//
/// Create a ForStmt with the given bounds and step.
///
/// The new statement's operand list is the lower bound operands followed by
/// the upper bound operands; getLowerBound()/getUpperBound() rely on that
/// layout. Each operand list must match its map's input count, and the step
/// must be positive.
ForStmt *ForStmt::create(Location *location, ArrayRef<MLValue *> lbOperands,
                         AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
                         AffineMap *ubMap, int64_t step, MLIRContext *context) {
  assert(lbOperands.size() == lbMap->getNumInputs() &&
         "lower bound operand count does not match the affine map");
  assert(ubOperands.size() == ubMap->getNumInputs() &&
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

  unsigned numOperands = lbOperands.size() + ubOperands.size();
  // The constructor reserves exactly numOperands slots, so appending below
  // does not reallocate the operand storage.
  ForStmt *stmt =
      new ForStmt(location, numOperands, lbMap, ubMap, step, context);
  // Construct each operand in place instead of moving in a temporary.
  for (auto *lbOperand : lbOperands)
    stmt->operands.emplace_back(stmt, lbOperand);
  for (auto *ubOperand : ubOperands)
    stmt->operands.emplace_back(stmt, ubOperand);
  return stmt;
}
/// Construct a ForStmt whose operand list is still empty; create() appends
/// the bound operands afterwards.
ForStmt::ForStmt(Location *location, unsigned numOperands, AffineMap *lbMap,
                 AffineMap *ubMap, int64_t step, MLIRContext *context)
    : Statement(Kind::For, location),
      MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
      StmtBlock(StmtBlockKind::For),
      lbMap(lbMap),
      ubMap(ubMap),
      step(step) {
  // Pre-size storage for the lower and upper bound operands.
  operands.reserve(numOperands);
}

/// Lower bound: operands [0, #lbMap inputs) applied to lbMap.
const AffineBound ForStmt::getLowerBound() const {
  unsigned lbEnd = lbMap->getNumInputs();
  return AffineBound(*this, 0, lbEnd, lbMap);
}

/// Upper bound: the remaining operands applied to ubMap.
const AffineBound ForStmt::getUpperBound() const {
  unsigned ubBegin = lbMap->getNumInputs();
  return AffineBound(*this, ubBegin, getNumOperands(), ubMap);
}
/// Replace the lower bound. Currently only supported when the loop has no
/// operands and the new bound takes none either.
void ForStmt::setLowerBound(ArrayRef<MLValue *> operands, AffineMap *map) {
  // TODO: handle the case when number of existing or new operands is non-zero.
  assert(getNumOperands() == 0 && operands.empty());
  lbMap = map;
}

/// Replace the upper bound. Currently only supported when the loop has no
/// operands and the new bound takes none either.
void ForStmt::setUpperBound(ArrayRef<MLValue *> operands, AffineMap *map) {
  // TODO: handle the case when number of existing or new operands is non-zero.
  assert(getNumOperands() == 0 && operands.empty());
  ubMap = map;
}

/// Swap in a new lower bound map that takes the same number of dims and
/// symbols as the current one, so the existing operands still apply.
void ForStmt::setLowerBoundMap(AffineMap *map) {
  assert(lbMap->getNumDims() == map->getNumDims() &&
         lbMap->getNumSymbols() == map->getNumSymbols());
  lbMap = map;
}

/// Swap in a new upper bound map that takes the same number of dims and
/// symbols as the current one, so the existing operands still apply.
void ForStmt::setUpperBoundMap(AffineMap *map) {
  assert(ubMap->getNumDims() == map->getNumDims() &&
         ubMap->getNumSymbols() == map->getNumSymbols());
  ubMap = map;
}
bool ForStmt::hasConstantLowerBound() const {
return lbMap->isSingleConstant();
}
bool ForStmt::hasConstantUpperBound() const {
return ubMap->isSingleConstant();
}
int64_t ForStmt::getConstantLowerBound() const {
return lbMap->getSingleConstantResult();
}
int64_t ForStmt::getConstantUpperBound() const {
return ubMap->getSingleConstantResult();
}
void ForStmt::setConstantLowerBound(int64_t value) {
setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}
void ForStmt::setConstantUpperBound(int64_t value) {
setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
/// The operands feeding the lower bound map: the first #lbMap-inputs entries.
ForStmt::operand_range ForStmt::getLowerBoundOperands() {
  auto split = operand_begin() + getLowerBoundMap()->getNumInputs();
  return {operand_begin(), split};
}

/// The operands feeding the upper bound map: everything after the lower
/// bound operands.
ForStmt::operand_range ForStmt::getUpperBoundOperands() {
  auto split = operand_begin() + getLowerBoundMap()->getNumInputs();
  return {split, operand_end()};
}
/// Return true if the lower and upper bound maps take the same number of dims
/// and symbols and are applied to exactly the same operand values.
bool ForStmt::matchingBoundOperandList() const {
  bool sameShape = lbMap->getNumDims() == ubMap->getNumDims() &&
                   lbMap->getNumSymbols() == ubMap->getNumSymbols();
  if (!sameShape)
    return false;

  // Operands are laid out as [lb operands..., ub operands...]; compare the
  // two halves position by position (pointer identity of the MLValues).
  unsigned numLbOperands = lbMap->getNumInputs();
  for (unsigned pos = 0; pos != numLbOperands; ++pos) {
    if (getOperand(pos) != getOperand(numLbOperands + pos))
      return false;
  }
  return true;
}
//===----------------------------------------------------------------------===//
// IfStmt
//===----------------------------------------------------------------------===//
/// Construct an IfStmt whose 'then' clause is bound to this statement and
/// which starts without an 'else' clause.
IfStmt::IfStmt(Location *location, unsigned numOperands, IntegerSet *set)
    : Statement(Kind::If, location),
      thenClause(this),
      elseClause(nullptr),
      set(set) {
  operands.reserve(numOperands);
}

IfStmt::~IfStmt() {
  // The 'else' clause, if one was created, is owned by this statement.
  // Deleting a null pointer is a no-op, so no separate check is needed.
  delete elseClause;
  // An IfStmt's IntegerSet 'set' should not be deleted since it is
  // allocated through MLIRContext's bump pointer allocator.
}
/// Create an IfStmt guarded by `set`, applied to `operands`.
///
/// The number of operands must equal the integer set's operand count.
IfStmt *IfStmt::create(Location *location, ArrayRef<MLValue *> operands,
                       IntegerSet *set) {
  unsigned numOperands = operands.size();
  assert(numOperands == set->getNumOperands() &&
         "operand count does not match the integer set operand count");
  IfStmt *stmt = new IfStmt(location, numOperands, set);
  // The constructor reserved room for every operand; construct each one in
  // place rather than moving in a temporary.
  for (auto *op : operands)
    stmt->operands.emplace_back(stmt, op);
  return stmt;
}
/// Return the condition: this statement's operands applied to its integer
/// set.
const AffineCondition IfStmt::getCondition() const {
  return AffineCondition(*this, set);
}

/// Return the context this if statement is associated with.
MLIRContext *IfStmt::getContext() const {
  // Prefer the constant-time route through the first operand's type. An if
  // statement with no operands is unusual but legal; in that case, fall back
  // to locating the enclosing function.
  if (!operands.empty())
    return getOperand(0)->getType()->getContext();
  return findFunction()->getContext();
}
//===----------------------------------------------------------------------===//
// Statement Cloning
//===----------------------------------------------------------------------===//
/// Create a deep copy of this statement, remapping any operands that use
/// values outside of the statement using the map that is provided (leaving
/// them alone if no entry is present). Replaces references to cloned
/// sub-statements to the corresponding statement that is copied, and adds
/// those mappings to the map.
///
/// \param operandMap  In/out map from original values to their replacements.
///                    While cloning, entries are added for every cloned
///                    operation result and every cloned for-loop induction
///                    variable, so the caller sees the full old->new mapping.
/// \param context     Passed through to the OperationStmt/ForStmt factories.
/// \returns the new statement. Presumably it is not yet attached to any block;
///          the recursive calls below insert nested clones via push_back.
Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
                            MLIRContext *context) const {
  // If the specified value is in operandMap, return the remapped value.
  // Otherwise return the value itself.
  auto remapOperand = [&](const MLValue *value) -> MLValue * {
    auto it = operandMap.find(value);
    // Unmapped values are reused directly; the const_cast strips the
    // constness of the lookup key so the value can serve as a new operand.
    return it != operandMap.end() ? it->second : const_cast<MLValue *>(value);
  };
  // Remap this statement's own flat operand list; all three statement kinds
  // below consume it.
  SmallVector<MLValue *, 8> operands;
  operands.reserve(getNumOperands());
  for (auto *opValue : getOperands())
    operands.push_back(remapOperand(opValue));
  if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
    // Operation: rebuild with the same name, result types and attributes.
    SmallVector<Type *, 8> resultTypes;
    resultTypes.reserve(opStmt->getNumResults());
    for (auto *result : opStmt->getResults())
      resultTypes.push_back(result->getType());
    auto *newOp =
        OperationStmt::create(getLoc(), opStmt->getName(), operands,
                              resultTypes, opStmt->getAttrs(), context);
    // Remember the mapping of any results.
    for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
      operandMap[opStmt->getResult(i)] = newOp->getResult(i);
    return newOp;
  }
  if (auto *forStmt = dyn_cast<ForStmt>(this)) {
    // For: the operand list is [lower bound operands..., upper bound
    // operands...]; split it back apart. The bound maps are shared with the
    // original, not copied.
    auto *lbMap = forStmt->getLowerBoundMap();
    auto *ubMap = forStmt->getUpperBoundMap();
    auto *newFor = ForStmt::create(
        getLoc(),
        ArrayRef<MLValue *>(operands).take_front(lbMap->getNumInputs()), lbMap,
        ArrayRef<MLValue *>(operands).take_back(ubMap->getNumInputs()), ubMap,
        forStmt->getStep(), context);
    // Remember the induction variable mapping. This must be recorded before
    // the body is cloned so that body uses of the induction variable are
    // redirected to the new loop.
    operandMap[forStmt] = newFor;
    // Recursively clone the body of the for loop.
    for (auto &subStmt : *forStmt)
      newFor->push_back(subStmt.clone(operandMap, context));
    return newFor;
  }
  // Otherwise, we must have an If statement. The integer set is reused as-is,
  // and no mapping is recorded for the if statement itself.
  auto *ifStmt = cast<IfStmt>(this);
  auto *newIf = IfStmt::create(getLoc(), operands, ifStmt->getIntegerSet());
  // Recursively clone the 'then' clause, then the 'else' clause if present.
  auto *resultThen = newIf->getThen();
  for (auto &childStmt : *ifStmt->getThen())
    resultThen->push_back(childStmt.clone(operandMap, context));
  if (ifStmt->hasElse()) {
    auto *resultElse = newIf->createElse();
    for (auto &childStmt : *ifStmt->getElse())
      resultElse->push_back(childStmt.clone(operandMap, context));
  }
  return newIf;
}