From 3609599af69c9c091b75d0caefadfb5a0479c913 Mon Sep 17 00:00:00 2001 From: Tatiana Shpeisman Date: Thu, 28 Jun 2018 17:02:32 -0700 Subject: [PATCH] Introduce IR and parser support for ML functions. Representing function arguments is still TODO. Supporting instructions other than return is also TODO. PiperOrigin-RevId: 202570934 --- mlir/include/mlir/IR/Function.h | 2 +- mlir/include/mlir/IR/MLFunction.h | 53 ++++++++++++++++++++++++ mlir/include/mlir/IR/MLStatements.h | 58 ++++++++++++++++++++++++++ mlir/lib/IR/AsmPrinter.cpp | 21 ++++++++++ mlir/lib/IR/Function.cpp | 9 +++++ mlir/lib/IR/MLStatements.cpp | 22 ++++++++++ mlir/lib/Parser/Parser.cpp | 63 ++++++++++++++++++++++++++++- mlir/test/IR/parser-errors.mlir | 10 +++++ mlir/test/IR/parser.mlir | 6 +++ 9 files changed, 241 insertions(+), 3 deletions(-) create mode 100644 mlir/include/mlir/IR/MLFunction.h create mode 100644 mlir/include/mlir/IR/MLStatements.h create mode 100644 mlir/lib/IR/MLStatements.cpp diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index c94493798877..3655c9cc8b85 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -33,7 +33,7 @@ namespace mlir { class Function { public: enum class Kind { - ExtFunc, CFGFunc + ExtFunc, CFGFunc, MLFunc }; Kind getKind() const { return kind; } diff --git a/mlir/include/mlir/IR/MLFunction.h b/mlir/include/mlir/IR/MLFunction.h new file mode 100644 index 000000000000..8e45407ee326 --- /dev/null +++ b/mlir/include/mlir/IR/MLFunction.h @@ -0,0 +1,53 @@ +//===- MLFunction.h - MLIR MLFunction Class -------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines MLFunction class +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_MLFUNCTION_H_ +#define MLIR_IR_MLFUNCTION_H_ + +#include "mlir/IR/Function.h" +#include "mlir/IR/MLStatements.h" +#include + +namespace mlir { + +// MLFunction is defined as a sequence of statements that may +// include nested affine for loops, conditionals and instructions. +class MLFunction : public Function { +public: + MLFunction(StringRef name, FunctionType *type); + + // FIXME: wrong representation and API, leaks memory etc + std::vector stmtList; + + // TODO: add function arguments and return values once + // SSA values are implemented + + // Methods for support type inquiry through isa, cast, and dyn_cast + static bool classof(const Function *func) { + return func->getKind() == Kind::MLFunc; + } + + void print(raw_ostream &os) const; +}; + +} // end namespace mlir + +#endif // MLIR_IR_MLFUNCTION_H_ diff --git a/mlir/include/mlir/IR/MLStatements.h b/mlir/include/mlir/IR/MLStatements.h new file mode 100644 index 000000000000..b5658900e955 --- /dev/null +++ b/mlir/include/mlir/IR/MLStatements.h @@ -0,0 +1,58 @@ +//===- MLStatements.h - MLIR ML Statement Classes ------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines the classes for MLFunction statements. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_MLSTATEMENTS_H +#define MLIR_IR_MLSTATEMENTS_H + +#include "mlir/Support/LLVM.h" + +namespace mlir { + class MLFunction; + +/// ML function consists of ML statements - for statement, if statement +/// or operation. +class MLStatement { +public: + enum class Kind { + For, + If, + Operation + }; + + Kind getKind() const { return kind; } + + /// Returns the function that this MLStatement is part of. + MLFunction *getFunction() const { return function; } + + void print(raw_ostream &os) const; + void dump() const; + +protected: + MLStatement(Kind kind, MLFunction *function) + : kind(kind), function(function) {} + +private: + Kind kind; + MLFunction *function; +}; + +} //end namespace mlir +#endif // MLIR_IR_STATEMENTS_H diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 08fe8380748f..03871fc52b96 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -21,6 +21,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/CFGFunction.h" +#include "mlir/IR/MLFunction.h" #include "mlir/IR/Module.h" #include "mlir/IR/Types.h" #include "mlir/Support/STLExtras.h" @@ -151,10 +152,18 @@ void BasicBlock::dump() const { print(llvm::errs()); } +void MLStatement::print(raw_ostream &os) const { + //TODO +} + +void MLStatement::dump() const { + print(llvm::errs()); +} void Function::print(raw_ostream &os) const { switch (getKind()) { case Kind::ExtFunc: return cast(this)->print(os); case Kind::CFGFunc: return cast(this)->print(os); + case Kind::MLFunc: return cast(this)->print(os); } } @@ -167,6 +176,18 @@ void CFGFunction::print(raw_ostream &os) const { state.print(); } +void MLFunction::print(raw_ostream &os) const { + os << "mlfunc "; + // FIXME: should print argument names rather than just signature + printFunctionSignature(this, os); + os << " {\n"; + + for (auto *stmt : stmtList) + stmt->print(os); + os << " return\n"; + os << "}\n\n"; +} + void Module::print(raw_ostream &os) const { for (auto *fn : functionList) fn->print(os); diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 833c17323c61..3af2a61e7d1f 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -16,6 +16,7 @@ // ============================================================================= #include "mlir/IR/CFGFunction.h" +#include "mlir/IR/MLFunction.h" #include "llvm/ADT/StringRef.h" using namespace mlir; @@ -38,3 +39,11 @@ ExtFunction::ExtFunction(StringRef name, FunctionType *type) CFGFunction::CFGFunction(StringRef name, FunctionType *type) : Function(name, type, Kind::CFGFunc) { } + +//===----------------------------------------------------------------------===// +// MLFunction implementation. +//===----------------------------------------------------------------------===// + +MLFunction::MLFunction(StringRef name, FunctionType *type) + : Function(name, type, Kind::MLFunc) { +} diff --git a/mlir/lib/IR/MLStatements.cpp b/mlir/lib/IR/MLStatements.cpp new file mode 100644 index 000000000000..26fa7080508e --- /dev/null +++ b/mlir/lib/IR/MLStatements.cpp @@ -0,0 +1,22 @@ +//===- MLStatements.cpp - MLIR MLStatement Instruction Classes ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/MLFunction.h" +#include "mlir/IR/MLStatements.h" +using namespace mlir; + +// TODO: classes derived from MLStatement diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 0d2497298813..df952f95ea9f 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Module.h" #include "mlir/IR/CFGFunction.h" +#include "mlir/IR/MLFunction.h" #include "mlir/IR/Types.h" #include "llvm/Support/SourceMgr.h" using namespace mlir; @@ -134,9 +135,11 @@ private: ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type); ParseResult parseExtFunc(); ParseResult parseCFGFunc(); + ParseResult parseMLFunc(); ParseResult parseBasicBlock(CFGFunctionParserState &functionState); TerminatorInst *parseTerminator(BasicBlock *currentBB, CFGFunctionParserState &functionState); + MLStatement *parseMLStatement(MLFunction *currentFunction); }; } // end anonymous namespace @@ -532,7 +535,6 @@ ParseResult Parser::parseFunctionSignature(StringRef &name, return ParseSuccess; } - /// External function declarations. /// /// ext-func ::= `extfunc` function-signature @@ -707,6 +709,59 @@ TerminatorInst *Parser::parseTerminator(BasicBlock *currentBB, } } +/// ML function declarations. +/// +/// ml-func ::= `mlfunc` ml-func-signature `{` ml-stmt* ml-return-stmt `}` +/// +ParseResult Parser::parseMLFunc() { + consumeToken(Token::kw_mlfunc); + + StringRef name; + FunctionType *type = nullptr; + + // FIXME: Parse ML function signature (args + types) + // by passing pointer to SmallVector into parseFunctionSignature + if (parseFunctionSignature(name, type)) + return ParseFailure; + + if (!consumeIf(Token::l_brace)) + return emitError("expected '{' in ML function"); + + // Okay, the ML function signature was parsed correctly, create the function. + auto function = new MLFunction(name, type); + + // Make sure we have at least one statement. + if (curToken.is(Token::r_brace)) + return emitError("ML function must end with return statement"); + + // Parse the list of instructions. + while (!consumeIf(Token::kw_return)) { + auto *stmt = parseMLStatement(function); + if (!stmt) + return ParseFailure; + function->stmtList.push_back(stmt); + } + + // TODO: parse return statement operands + if (!consumeIf(Token::r_brace)) + emitError("expected '}' in ML function"); + + module->functionList.push_back(function); + + return ParseSuccess; +} + +/// Parse an MLStatement +/// TODO +/// +MLStatement *Parser::parseMLStatement(MLFunction *currentFunction) { + switch (curToken.getKind()) { + default: + return (emitError("expected ML statement"), nullptr); + + // TODO: add parsing of ML statements + } +} //===----------------------------------------------------------------------===// // Top-level entity parsing. @@ -741,7 +796,11 @@ Module *Parser::parseModule() { if (parseAffineMapDef()) return nullptr; break; - // TODO: mlfunc, affine entity declarations, etc. + case Token::kw_mlfunc: + if (parseMLFunc()) return nullptr; + break; + + // TODO: affine entity declarations, etc. } } } diff --git a/mlir/test/IR/parser-errors.mlir b/mlir/test/IR/parser-errors.mlir index d742b6b0cfe3..e60bf78877d8 100644 --- a/mlir/test/IR/parser-errors.mlir +++ b/mlir/test/IR/parser-errors.mlir @@ -47,3 +47,13 @@ bb41: bb42: ; expected-error {{expected terminator}} return } + +; ----- + +mlfunc @foo() +mlfunc @bar() ; expected-error {{expected '{' in ML function}} + +; ----- + +mlfunc @no_return() { +} ; expected-error {{ML function must end with return statement}} diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index e307a3a45c43..b7c28f778acb 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -48,3 +48,9 @@ bb2: ; CHECK: bb2: bb4: ; CHECK: bb3: return ; CHECK: return } ; CHECK: } + +; CHECK-LABEL: mlfunc @simpleMLF() { +mlfunc @simpleMLF() { + return ; CHECK: return +} ; CHECK: } +