Introduce IR and parser support for ML functions.

Representing function arguments is still TODO.
Supporting instructions other than return is also TODO.

PiperOrigin-RevId: 202570934
This commit is contained in:
Tatiana Shpeisman 2018-06-28 17:02:32 -07:00 committed by jpienaar
parent 8901448f14
commit 3609599af6
9 changed files with 241 additions and 3 deletions

View File

@ -33,7 +33,7 @@ namespace mlir {
class Function {
public:
enum class Kind {
ExtFunc, CFGFunc
ExtFunc, CFGFunc, MLFunc
};
Kind getKind() const { return kind; }

View File

@ -0,0 +1,53 @@
//===- MLFunction.h - MLIR MLFunction Class -------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines MLFunction class
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_MLFUNCTION_H_
#define MLIR_IR_MLFUNCTION_H_
#include "mlir/IR/Function.h"
#include "mlir/IR/MLStatements.h"
#include <vector>
namespace mlir {
// MLFunction is defined as a sequence of statements that may
// include nested affine for loops, conditionals and instructions.
class MLFunction : public Function {
public:
MLFunction(StringRef name, FunctionType *type);
// FIXME: wrong representation and API, leaks memory etc
std::vector<MLStatement*> stmtList;
// TODO: add function arguments and return values once
// SSA values are implemented
// Methods for support type inquiry through isa, cast, and dyn_cast
static bool classof(const Function *func) {
return func->getKind() == Kind::MLFunc;
}
void print(raw_ostream &os) const;
};
} // end namespace mlir
#endif // MLIR_IR_MLFUNCTION_H_

View File

@ -0,0 +1,58 @@
//===- MLStatements.h - MLIR ML Statement Classes ------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the classes for MLFunction statements.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_MLSTATEMENTS_H
#define MLIR_IR_MLSTATEMENTS_H
#include "mlir/Support/LLVM.h"
namespace mlir {
class MLFunction;
/// ML function consists of ML statements - for statement, if statement
/// or operation.
class MLStatement {
public:
enum class Kind {
For,
If,
Operation
};
Kind getKind() const { return kind; }
/// Returns the function that this MLStatement is part of.
MLFunction *getFunction() const { return function; }
void print(raw_ostream &os) const;
void dump() const;
protected:
MLStatement(Kind kind, MLFunction *function)
: kind(kind), function(function) {}
private:
Kind kind;
MLFunction *function;
};
} //end namespace mlir
#endif // MLIR_IR_STATEMENTS_H

View File

@ -21,6 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
@ -151,10 +152,18 @@ void BasicBlock::dump() const {
print(llvm::errs());
}
void MLStatement::print(raw_ostream &os) const {
//TODO
}
void MLStatement::dump() const {
print(llvm::errs());
}
void Function::print(raw_ostream &os) const {
switch (getKind()) {
case Kind::ExtFunc: return cast<ExtFunction>(this)->print(os);
case Kind::CFGFunc: return cast<CFGFunction>(this)->print(os);
case Kind::MLFunc: return cast<MLFunction>(this)->print(os);
}
}
@ -167,6 +176,18 @@ void CFGFunction::print(raw_ostream &os) const {
state.print();
}
void MLFunction::print(raw_ostream &os) const {
os << "mlfunc ";
// FIXME: should print argument names rather than just signature
printFunctionSignature(this, os);
os << " {\n";
for (auto *stmt : stmtList)
stmt->print(os);
os << " return\n";
os << "}\n\n";
}
void Module::print(raw_ostream &os) const {
for (auto *fn : functionList)
fn->print(os);

View File

@ -16,6 +16,7 @@
// =============================================================================
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "llvm/ADT/StringRef.h"
using namespace mlir;
@ -38,3 +39,11 @@ ExtFunction::ExtFunction(StringRef name, FunctionType *type)
CFGFunction::CFGFunction(StringRef name, FunctionType *type)
: Function(name, type, Kind::CFGFunc) {
}
//===----------------------------------------------------------------------===//
// MLFunction implementation.
//===----------------------------------------------------------------------===//
MLFunction::MLFunction(StringRef name, FunctionType *type)
: Function(name, type, Kind::MLFunc) {
}

View File

@ -0,0 +1,22 @@
//===- MLStatements.cpp - MLIR MLStatement Instruction Classes ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLStatements.h"
using namespace mlir;
// TODO: classes derived from MLStatement

View File

@ -24,6 +24,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
@ -134,9 +135,11 @@ private:
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type);
ParseResult parseExtFunc();
ParseResult parseCFGFunc();
ParseResult parseMLFunc();
ParseResult parseBasicBlock(CFGFunctionParserState &functionState);
TerminatorInst *parseTerminator(BasicBlock *currentBB,
CFGFunctionParserState &functionState);
MLStatement *parseMLStatement(MLFunction *currentFunction);
};
} // end anonymous namespace
@ -532,7 +535,6 @@ ParseResult Parser::parseFunctionSignature(StringRef &name,
return ParseSuccess;
}
/// External function declarations.
///
/// ext-func ::= `extfunc` function-signature
@ -707,6 +709,59 @@ TerminatorInst *Parser::parseTerminator(BasicBlock *currentBB,
}
}
/// ML function declarations.
///
/// ml-func ::= `mlfunc` ml-func-signature `{` ml-stmt* ml-return-stmt `}`
///
ParseResult Parser::parseMLFunc() {
consumeToken(Token::kw_mlfunc);
StringRef name;
FunctionType *type = nullptr;
// FIXME: Parse ML function signature (args + types)
// by passing pointer to SmallVector<identifier> into parseFunctionSignature
if (parseFunctionSignature(name, type))
return ParseFailure;
if (!consumeIf(Token::l_brace))
return emitError("expected '{' in ML function");
// Okay, the ML function signature was parsed correctly, create the function.
auto function = new MLFunction(name, type);
// Make sure we have at least one statement.
if (curToken.is(Token::r_brace))
return emitError("ML function must end with return statement");
// Parse the list of instructions.
while (!consumeIf(Token::kw_return)) {
auto *stmt = parseMLStatement(function);
if (!stmt)
return ParseFailure;
function->stmtList.push_back(stmt);
}
// TODO: parse return statement operands
if (!consumeIf(Token::r_brace))
emitError("expected '}' in ML function");
module->functionList.push_back(function);
return ParseSuccess;
}
/// Parse an MLStatement
/// TODO
///
MLStatement *Parser::parseMLStatement(MLFunction *currentFunction) {
switch (curToken.getKind()) {
default:
return (emitError("expected ML statement"), nullptr);
// TODO: add parsing of ML statements
}
}
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
@ -741,7 +796,11 @@ Module *Parser::parseModule() {
if (parseAffineMapDef()) return nullptr;
break;
// TODO: mlfunc, affine entity declarations, etc.
case Token::kw_mlfunc:
if (parseMLFunc()) return nullptr;
break;
// TODO: affine entity declarations, etc.
}
}
}

View File

@ -47,3 +47,13 @@ bb41:
bb42: ; expected-error {{expected terminator}}
return
}
; -----
mlfunc @foo()
mlfunc @bar() ; expected-error {{expected '{' in ML function}}
; -----
mlfunc @no_return() {
} ; expected-error {{ML function must end with return statement}}

View File

@ -48,3 +48,9 @@ bb2: ; CHECK: bb2:
bb4: ; CHECK: bb3:
return ; CHECK: return
} ; CHECK: }
; CHECK-LABEL: mlfunc @simpleMLF() {
mlfunc @simpleMLF() {
return ; CHECK: return
} ; CHECK: }