Basic representation and parsing of if and for statements. Loop headers and if statement conditions are not yet supported.

PiperOrigin-RevId: 203211526
This commit is contained in:
Tatiana Shpeisman 2018-07-03 17:51:28 -07:00 committed by jpienaar
parent 2057b454dc
commit 177ce7215c
10 changed files with 389 additions and 103 deletions

View File

@ -23,7 +23,7 @@
#define MLIR_IR_MLFUNCTION_H_
#include "mlir/IR/Function.h"
#include "mlir/IR/MLStatements.h"
#include "mlir/IR/Statements.h"
#include <vector>
namespace mlir {
@ -35,7 +35,7 @@ public:
MLFunction(StringRef name, FunctionType *type);
// FIXME: wrong representation and API, leaks memory etc
std::vector<MLStatement*> stmtList;
std::vector<Statement*> stmtList;
// TODO: add function arguments and return values once
// SSA values are implemented

View File

@ -1,58 +0,0 @@
//===- MLStatements.h - MLIR ML Statement Classes ------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the classes for MLFunction statements.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_MLSTATEMENTS_H
#define MLIR_IR_MLSTATEMENTS_H
#include "mlir/Support/LLVM.h"
namespace mlir {
class MLFunction;
/// ML function consists of ML statements - for statement, if statement
/// or operation.
class MLStatement {
public:
enum class Kind {
For,
If,
Operation
};
Kind getKind() const { return kind; }
/// Returns the function that this MLStatement is part of.
MLFunction *getFunction() const { return function; }
void print(raw_ostream &os) const;
void dump() const;
protected:
MLStatement(Kind kind, MLFunction *function)
: kind(kind), function(function) {}
private:
Kind kind;
MLFunction *function;
};
} //end namespace mlir
#endif // MLIR_IR_STATEMENTS_H

View File

@ -0,0 +1,130 @@
//===- Statements.h - MLIR ML Statement Classes ------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the classes for MLFunction statements.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_STATEMENTS_H
#define MLIR_IR_STATEMENTS_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/PointerUnion.h"
#include <vector>
namespace mlir {
class MLFunction;
class NodeStmt;
class ElseClause;
typedef PointerUnion<MLFunction *, NodeStmt *> ParentType;
/// Statement is a basic unit of execution within an ML function.
/// Statements can be nested within each other, effectively forming a tree.
class Statement {
public:
enum class Kind {
Operation,
For,
If,
Else
};
Kind getKind() const { return kind; }
/// Returns the parent of this statement. The parent of a nested statement
/// is the closest surrounding for or if statement. The parent of
/// a top-level statement is the function that contains the statement.
ParentType getParent() const { return parent; }
/// Returns the function that this statement is part of.
MLFunction *getFunction() const;
void print(raw_ostream &os) const;
void dump() const;
protected:
Statement(Kind kind, ParentType parent) : kind(kind), parent(parent) {}
private:
Kind kind;
ParentType parent;
};
/// Node statement represents a statement that may contain other statements.
class NodeStmt : public Statement {
public:
// FIXME: wrong representation and API, leaks memory etc
std::vector<Statement*> children;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Statement *stmt) {
return stmt->getKind() != Kind::Operation;
}
protected:
NodeStmt(Kind kind, ParentType parent) : Statement(kind, parent) {}
};
/// For statement represents an affine loop nest.
class ForStmt : public NodeStmt {
public:
explicit ForStmt(ParentType parent) : NodeStmt(Kind::For, parent) {}
// TODO: represent loop variable, bounds and step
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Statement *stmt) {
return stmt->getKind() == Kind::For;
}
};
/// If statement restricts execution to a subset of the loop iteration space.
class IfStmt : public NodeStmt {
public:
explicit IfStmt(ParentType parent) : NodeStmt(Kind::If, parent) {}
// TODO: Represent condition
// FIXME: most likely wrong representation since it's wrong everywhere else
std::vector<ElseClause *> elseClauses;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Statement *stmt) {
return stmt->getKind() == Kind::If;
}
};
/// Else clause reprsents else or else-if clause of an if statement
class ElseClause : public NodeStmt {
public:
explicit ElseClause(IfStmt *ifStmt, int clauseNum);
// TODO: Represent optional condition
// Returns ordinal number of this clause in the list of clauses.
int getClauseNumber() const { return clauseNum;}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Statement *stmt) {
return stmt->getKind() == Kind::Else;
}
private:
int clauseNum;
};
} //end namespace mlir
#endif // MLIR_IR_STATEMENTS_H

View File

@ -151,9 +151,96 @@ void CFGFunctionState::print(const ReturnInst *inst) {
}
//===----------------------------------------------------------------------===//
// print and dump methods
// ML Function printing
//===----------------------------------------------------------------------===//
namespace {
class MLFunctionState {
public:
MLFunctionState(const MLFunction *function, raw_ostream &os);
const MLFunction *getFunction() const { return function; }
void print();
void print(const Statement *stmt);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
void print(const ElseClause *stmt, bool isLast);
private:
// Print statements nested within this node statement.
void printNestedStatements(const NodeStmt *stmt);
const MLFunction *function;
raw_ostream &os;
int numSpaces;
};
} // end anonymous namespace
MLFunctionState::MLFunctionState(const MLFunction *function, raw_ostream &os)
: function(function), os(os), numSpaces(2) {}
void MLFunctionState::print() {
os << "mlfunc ";
// FIXME: should print argument names rather than just signature
printFunctionSignature(function, os);
os << " {\n";
for (auto *stmt : function->stmtList)
print(stmt);
os << " return\n";
os << "}\n\n";
}
void MLFunctionState::print(const Statement *stmt) {
os.indent(numSpaces);
switch (stmt->getKind()) {
case Statement::Kind::Operation: // TODO
assert(0 && "Operation statement is not yet implemented");
case Statement::Kind::For:
return print(cast<ForStmt>(stmt));
case Statement::Kind::If:
return print(cast<IfStmt>(stmt));
case Statement::Kind::Else:
return print(cast<ElseClause>(stmt));
}
}
void MLFunctionState::printNestedStatements(const NodeStmt *stmt) {
os << "{\n";
numSpaces += 2;
for (auto * nestedStmt : stmt->children)
print(nestedStmt);
numSpaces -= 2;
os.indent(numSpaces) << "}";
}
void MLFunctionState::print(const ForStmt *stmt) {
os << "for ";
printNestedStatements(stmt);
os << "\n";
}
void MLFunctionState::print(const IfStmt *stmt) {
os << "if ";
printNestedStatements(stmt);
int numClauses = stmt->elseClauses.size();
for (auto e : stmt->elseClauses)
print(e, e->getClauseNumber() == numClauses - 1);
os << "\n";
}
void MLFunctionState::print(const ElseClause *stmt, bool isLast) {
if (!isLast)
os << " if";
os << " else ";;
printNestedStatements(stmt);
}
//===----------------------------------------------------------------------===//
// print and dump methods
//===----------------------------------------------------------------------===//
void Instruction::print(raw_ostream &os) const {
CFGFunctionState state(getFunction(), os);
@ -182,11 +269,12 @@ void BasicBlock::dump() const {
print(llvm::errs());
}
void MLStatement::print(raw_ostream &os) const {
//TODO
void Statement::print(raw_ostream &os) const {
MLFunctionState state(getFunction(), os);
state.print(this);
}
void MLStatement::dump() const {
void Statement::dump() const {
print(llvm::errs());
}
void Function::print(raw_ostream &os) const {
@ -207,15 +295,8 @@ void CFGFunction::print(raw_ostream &os) const {
}
void MLFunction::print(raw_ostream &os) const {
os << "mlfunc ";
// FIXME: should print argument names rather than just signature
printFunctionSignature(this, os);
os << " {\n";
for (auto *stmt : stmtList)
stmt->print(os);
os << " return\n";
os << "}\n\n";
MLFunctionState state(this, os);
state.print();
}
void Module::print(raw_ostream &os) const {

View File

@ -1,22 +0,0 @@
//===- MLStatements.cpp - MLIR MLStatement Instruction Classes ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLStatements.h"
using namespace mlir;
// TODO: classes derived from MLStatement

View File

@ -0,0 +1,40 @@
//===- Statements.cpp - MLIR Statement Instruction Classes ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Statements.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Statement
//===----------------------------------------------------------------------===//
MLFunction *Statement::getFunction() const {
ParentType p = parent;
while (!p.is<MLFunction *>())
p = p.get<NodeStmt *>()->getParent();
return p.get<MLFunction *>();
}
//===----------------------------------------------------------------------===//
// ElseClause
//===----------------------------------------------------------------------===//
ElseClause::ElseClause(IfStmt *ifStmt, int clauseNum)
: NodeStmt(Kind::Else, ifStmt), clauseNum(clauseNum) {
ifStmt->elseClauses.push_back(this);
}

View File

@ -149,11 +149,14 @@ private:
ParseResult parseCFGFunc();
ParseResult parseMLFunc();
ParseResult parseBasicBlock(CFGFunctionParserState &functionState);
MLStatement *parseMLStatement(MLFunction *currentFunction);
Statement *parseStatement(ParentType parent);
OperationInst *parseCFGOperation(CFGFunctionParserState &functionState);
TerminatorInst *parseTerminator(CFGFunctionParserState &functionState);
ForStmt *parseForStmt(ParentType parent);
IfStmt *parseIfStmt(ParentType parent);
ParseResult parseNestedStatements(NodeStmt *parent);
};
} // end anonymous namespace
@ -976,7 +979,7 @@ ParseResult Parser::parseMLFunc() {
// Parse the list of instructions.
while (!consumeIf(Token::kw_return)) {
auto *stmt = parseMLStatement(function);
auto *stmt = parseStatement(function);
if (!stmt)
return ParseFailure;
function->stmtList.push_back(stmt);
@ -991,18 +994,98 @@ ParseResult Parser::parseMLFunc() {
return ParseSuccess;
}
/// Parse an MLStatement
/// TODO
/// Statement.
///
MLStatement *Parser::parseMLStatement(MLFunction *currentFunction) {
/// ml-stmt ::= instruction | ml-for-stmt | ml-if-stmt
/// TODO: fix terminology in MLSpec document. ML functions
/// contain operation statements, not instructions.
///
Statement * Parser::parseStatement(ParentType parent) {
switch (curToken.getKind()) {
default:
return (emitError("expected ML statement"), nullptr);
//TODO: parse OperationStmt
return (emitError("expected statement"), nullptr);
// TODO: add parsing of ML statements
case Token::kw_for:
return parseForStmt(parent);
case Token::kw_if:
return parseIfStmt(parent);
}
}
/// For statement.
///
/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
/// (`step` integer-literal)? `{` ml-stmt* `}`
///
ForStmt * Parser::parseForStmt(ParentType parent) {
consumeToken(Token::kw_for);
//TODO: parse loop header
ForStmt *stmt = new ForStmt(parent);
if (parseNestedStatements(stmt)) {
delete stmt;
return nullptr;
}
return stmt;
}
/// If statement.
///
/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
/// | ml-if-head `else` `if` ml-if-cond `{` ml-stmt* `}`
/// ml-if-stmt ::= ml-if-head
/// | ml-if-head `else` `{` ml-stmt* `}`
///
IfStmt * Parser::parseIfStmt(PointerUnion<MLFunction *, NodeStmt *> parent) {
consumeToken(Token::kw_if);
//TODO: parse condition
IfStmt *stmt = new IfStmt(parent);
if (parseNestedStatements(stmt)) {
delete stmt;
return nullptr;
}
int clauseNum = 0;
while (consumeIf(Token::kw_else)) {
if (consumeIf(Token::kw_if)) {
//TODO: parse condition
}
ElseClause * clause = new ElseClause(stmt, clauseNum);
++clauseNum;
if (parseNestedStatements(clause)) {
delete clause;
return nullptr;
}
}
return stmt;
}
///
/// Parse `{` ml-stmt* `}`
///
ParseResult Parser::parseNestedStatements(NodeStmt *parent) {
if (!consumeIf(Token::l_brace))
return emitError("expected '{' before statement list");
if (consumeIf(Token::r_brace)) {
// TODO: parse OperationStmt
return ParseSuccess;
}
while (!consumeIf(Token::r_brace)) {
auto *stmt = parseStatement(parent);
if (!stmt)
return ParseFailure;
parent->children.push_back(stmt);
}
return ParseSuccess;
}
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
//===----------------------------------------------------------------------===//

View File

@ -89,10 +89,13 @@ TOK_KEYWORD(affineint)
TOK_KEYWORD(bf16)
TOK_KEYWORD(br)
TOK_KEYWORD(cfgfunc)
TOK_KEYWORD(else)
TOK_KEYWORD(extfunc)
TOK_KEYWORD(f16)
TOK_KEYWORD(f32)
TOK_KEYWORD(f64)
TOK_KEYWORD(for)
TOK_KEYWORD(if)
TOK_KEYWORD(memref)
TOK_KEYWORD(mlfunc)
TOK_KEYWORD(return)

View File

@ -80,4 +80,14 @@ bb40:
extfunc @illegaltype(i0) ; expected-error {{invalid integer width}}
; -----
mlfunc @incomplete_for() {
for
} ; expected-error {{expected '{' before statement list}}
; -----
mlfunc @non_statement() {
asd ; expected-error {{expected statement}}
}

View File

@ -56,6 +56,25 @@ bb4: ; CHECK: bb3:
; CHECK-LABEL: mlfunc @simpleMLF() {
mlfunc @simpleMLF() {
return ; CHECK: return
return ; CHECK: return
} ; CHECK: }
; CHECK-LABEL: mlfunc @loops() {
mlfunc @loops() {
for { ; CHECK: for {
for { ; CHECK: for {
} ; CHECK: }
} ; CHECK: }
return ; CHECK: return
} ; CHECK: }
; CHECK-LABEL: mlfunc @ifstmt() {
mlfunc @ifstmt() {
for { ; CHECK for {
if { ; CHECK if {
} else if { ; CHECK } else if {
} else { ; CHECK } else {
} ; CHECK }
} ; CHECK }
return ; CHECK return
} ; CHECK }