Start a Linalg dialect

This CL starts implementing a Linalg dialect with the objective of supporting
    optimizing compilation of loops and library calls for a subset of common linear
    algebra operations.

    This CL starts by simply adding a linalg.range type and an operation with the
    proper roundtripping test.

--

PiperOrigin-RevId: 244189468
This commit is contained in:
Nicolas Vasilache 2019-04-18 08:25:54 -07:00 committed by Mehdi Amini
parent 05dfb1c7e0
commit 8370cc7492
10 changed files with 276 additions and 2 deletions

View File

@ -23,9 +23,9 @@
namespace linalg {
enum LinalgTypes {
Range = mlir::Type::FIRST_LINALG_TYPE,
Range = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
View,
LAST_USED_LINALG_TYPE = View,
FIRST_PRIVATE_EXPERIMENTAL_0_TYPE = View,
};
} // namespace linalg

View File

@ -0,0 +1,53 @@
//===- LinalgOps.h - Linalg Operations --------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_LINALG_LINALGOPS_H_
#define MLIR_LINALG_LINALGOPS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
/// A RangeOp is used to create a value of RangeType from 3 values of type index
/// that represent the min, max and step values of the range.
class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
using Op::Op;
//////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
//////////////////////////////////////////////////////////////////////////////
static llvm::StringRef getOperationName() { return "linalg.range"; }
static void build(Builder *b, OperationState *result, Value *min, Value *max,
Value *step);
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
Value *min() { return getOperand(0); }
Value *max() { return getOperand(1); }
Value *step() { return getOperand(2); }
};
} // namespace mlir
#endif // MLIR_LINALG_LINALGOPS_H_

View File

@ -0,0 +1,59 @@
//===- LinalgTypes.h - Linalg Types ---------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_LINALG_LINALGTYPES_H_
#define MLIR_LINALG_LINALGTYPES_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Types.h"
namespace mlir {
class MLIRContext;
enum LinalgTypes {
Range = Type::FIRST_LINALG_TYPE,
LAST_USED_LINALG_TYPE = Range,
};
class LinalgDialect : public Dialect {
public:
explicit LinalgDialect(MLIRContext *context);
/// Parse a type registered to this dialect.
Type parseType(llvm::StringRef spec, Location loc) const override;
/// Print a type registered to this dialect.
void printType(Type type, llvm::raw_ostream &os) const override;
};
/// A RangeType represents a minimal range abstraction (min, max, step).
class RangeType : public Type::TypeBase<RangeType, Type> {
public:
// Used for generic hooks in TypeBase.
using Base::Base;
/// Construction hook.
static RangeType get(MLIRContext *context) {
/// Custom, uniq'ed construction in the MLIRContext.
return Base::get(context, LinalgTypes::Range);
}
/// Used to implement llvm-style cast.
static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
};
} // namespace mlir
#endif // MLIR_LINALG_LINALGTYPES_H_

View File

@ -6,6 +6,7 @@ add_subdirectory(ExecutionEngine)
add_subdirectory(FxpMathOps)
add_subdirectory(IR)
add_subdirectory(LLVMIR)
add_subdirectory(Linalg)
add_subdirectory(Parser)
add_subdirectory(Pass)
add_subdirectory(Quantization)

View File

@ -0,0 +1,8 @@
add_llvm_library(MLIRLinalg
LinalgOps.cpp
LinalgRegistration.cpp
LinalgTypes.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Linalg
)

View File

@ -0,0 +1,67 @@
//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a the Linalg operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Linalg/LinalgOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Linalg/LinalgTypes.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
Value *max, Value *step) {
result->addOperands({min, max, step});
result->addTypes({RangeType::get(b->getContext())});
}
// Verification is simply that a RangeOp takes 3 index ssa-value.
mlir::LogicalResult mlir::RangeOp::verify() {
if (!min() || !min()->getType().isa<IndexType>())
return emitOpError("first operand should be of type index");
if (!max() || !max()->getType().isa<IndexType>())
return emitOpError("second operand should be of type index");
if (!step() || !step()->getType().isa<IndexType>())
return emitOpError("third operand should be of type index");
return mlir::success();
}
// A RangeOp prints as:
//
// ```{.mlir}
// linalg.range %0:%1:%2 : !linalg.range
// ```
void mlir::RangeOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
<< " : " << getType();
}
bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type, result->types);
}

View File

@ -0,0 +1,24 @@
//===- LinalgRegistration.cpp - Register the linalg dialect statically ----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/Linalg/LinalgOps.h"
#include "mlir/Linalg/LinalgTypes.h"
using namespace mlir;
// Static initialization for LinalgOps dialect registration.
static DialectRegistration<LinalgDialect> LinalgOps;

View File

@ -0,0 +1,53 @@
//===- Dialect.cpp - Implementation of the linalg dialect and types -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the Linalg dialect types and dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Linalg/LinalgTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Linalg/LinalgOps.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect("linalg", context) {
addTypes<RangeType>();
addOperations<RangeOp>();
}
Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const {
MLIRContext *context = getContext();
if (spec == "range")
return RangeType::get(getContext());
return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
}
/// RangeType prints as just "range".
static void print(RangeType rt, raw_ostream &os) { os << "range"; }
void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
switch (type.getKind()) {
default:
llvm_unreachable("Unhandled Linalg type");
case LinalgTypes::Range:
print(type.cast<RangeType>(), os);
break;
}
}

View File

@ -0,0 +1,8 @@
// RUN: mlir-opt %s -verify | mlir-opt | FileCheck %s
func @range(%arg0: index, %arg1: index, %arg2: index) {
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
return
}
// CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) {
// CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range

View File

@ -3,6 +3,7 @@ set(LIBS
MLIRAnalysis
MLIREDSC
MLIRFxpMathOps
MLIRLinalg
MLIRLLVMIR
MLIRParser
MLIRPass