From 8370cc74921f6fc84f5aaaf3f91bdf12890ed1f9 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 18 Apr 2019 08:25:54 -0700 Subject: [PATCH] Start a Linalg dialect This CL starts implementing a Linalg dialect with the objective of supporting optimizing compilation of loops and library calls for a subset of common linear algebra operations. This CL starts by simply adding a linalg.range type and an operation with the proper roundtripping test. -- PiperOrigin-RevId: 244189468 --- .../Linalg/Linalg1/include/linalg1/Types.h | 4 +- mlir/include/mlir/Linalg/LinalgOps.h | 53 +++++++++++++++ mlir/include/mlir/Linalg/LinalgTypes.h | 59 ++++++++++++++++ mlir/lib/CMakeLists.txt | 1 + mlir/lib/Linalg/CMakeLists.txt | 8 +++ mlir/lib/Linalg/LinalgOps.cpp | 67 +++++++++++++++++++ mlir/lib/Linalg/LinalgRegistration.cpp | 24 +++++++ mlir/lib/Linalg/LinalgTypes.cpp | 53 +++++++++++++++ mlir/test/Linalg/roundtrip.mlir | 8 +++ mlir/tools/mlir-opt/CMakeLists.txt | 1 + 10 files changed, 276 insertions(+), 2 deletions(-) create mode 100644 mlir/include/mlir/Linalg/LinalgOps.h create mode 100644 mlir/include/mlir/Linalg/LinalgTypes.h create mode 100644 mlir/lib/Linalg/CMakeLists.txt create mode 100644 mlir/lib/Linalg/LinalgOps.cpp create mode 100644 mlir/lib/Linalg/LinalgRegistration.cpp create mode 100644 mlir/lib/Linalg/LinalgTypes.cpp create mode 100644 mlir/test/Linalg/roundtrip.mlir diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Types.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Types.h index b2fa7fd26c89..5032e969c4b7 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Types.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Types.h @@ -23,9 +23,9 @@ namespace linalg { enum LinalgTypes { - Range = mlir::Type::FIRST_LINALG_TYPE, + Range = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE, View, - LAST_USED_LINALG_TYPE = View, + FIRST_PRIVATE_EXPERIMENTAL_0_TYPE = View, }; } // namespace linalg diff --git a/mlir/include/mlir/Linalg/LinalgOps.h b/mlir/include/mlir/Linalg/LinalgOps.h new file mode 100644 index 000000000000..7921822d8326 --- /dev/null +++ b/mlir/include/mlir/Linalg/LinalgOps.h @@ -0,0 +1,53 @@ +//===- LinalgOps.h - Linalg Operations --------------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef MLIR_LINALG_LINALGOPS_H_ +#define MLIR_LINALG_LINALGOPS_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { + +/// A RangeOp is used to create a value of RangeType from 3 values of type index +/// that represent the min, max and step values of the range. +class RangeOp : public Op::Impl, + OpTrait::OneResult, OpTrait::HasNoSideEffect> { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.range"; } + static void build(Builder *b, OperationState *result, Value *min, Value *max, + Value *step); + LogicalResult verify(); + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + Value *min() { return getOperand(0); } + Value *max() { return getOperand(1); } + Value *step() { return getOperand(2); } +}; + +} // namespace mlir + +#endif // MLIR_LINALG_LINALGOPS_H_ diff --git a/mlir/include/mlir/Linalg/LinalgTypes.h b/mlir/include/mlir/Linalg/LinalgTypes.h new file mode 100644 index 000000000000..2d2c74eb7af7 --- /dev/null +++ b/mlir/include/mlir/Linalg/LinalgTypes.h @@ -0,0 +1,59 @@ +//===- LinalgTypes.h - Linalg Types ---------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef MLIR_LINALG_LINALGTYPES_H_ +#define MLIR_LINALG_LINALGTYPES_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Types.h" + +namespace mlir { +class MLIRContext; + +enum LinalgTypes { + Range = Type::FIRST_LINALG_TYPE, + LAST_USED_LINALG_TYPE = Range, +}; + +class LinalgDialect : public Dialect { +public: + explicit LinalgDialect(MLIRContext *context); + + /// Parse a type registered to this dialect. + Type parseType(llvm::StringRef spec, Location loc) const override; + + /// Print a type registered to this dialect. + void printType(Type type, llvm::raw_ostream &os) const override; +}; + +/// A RangeType represents a minimal range abstraction (min, max, step). +class RangeType : public Type::TypeBase { +public: + // Used for generic hooks in TypeBase. + using Base::Base; + /// Construction hook. + static RangeType get(MLIRContext *context) { + /// Custom, uniq'ed construction in the MLIRContext. + return Base::get(context, LinalgTypes::Range); + } + /// Used to implement llvm-style cast. + static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; } +}; + +} // namespace mlir + +#endif // MLIR_LINALG_LINALGTYPES_H_ diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index 3d05cd1ab09f..920cf7955bc0 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -6,6 +6,7 @@ add_subdirectory(ExecutionEngine) add_subdirectory(FxpMathOps) add_subdirectory(IR) add_subdirectory(LLVMIR) +add_subdirectory(Linalg) add_subdirectory(Parser) add_subdirectory(Pass) add_subdirectory(Quantization) diff --git a/mlir/lib/Linalg/CMakeLists.txt b/mlir/lib/Linalg/CMakeLists.txt new file mode 100644 index 000000000000..b1df307ed7ab --- /dev/null +++ b/mlir/lib/Linalg/CMakeLists.txt @@ -0,0 +1,8 @@ +add_llvm_library(MLIRLinalg + LinalgOps.cpp + LinalgRegistration.cpp + LinalgTypes.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Linalg + ) diff --git a/mlir/lib/Linalg/LinalgOps.cpp b/mlir/lib/Linalg/LinalgOps.cpp new file mode 100644 index 000000000000..bba47fbb4d5a --- /dev/null +++ b/mlir/lib/Linalg/LinalgOps.cpp @@ -0,0 +1,67 @@ +//===- LinalgOps.cpp - Implementation of the linalg operations ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a the Linalg operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Linalg/LinalgOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Linalg/LinalgTypes.h" +#include "mlir/Support/LLVM.h" + +using namespace mlir; + +void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min, + Value *max, Value *step) { + result->addOperands({min, max, step}); + result->addTypes({RangeType::get(b->getContext())}); +} + +// Verification is simply that a RangeOp takes 3 index ssa-value. +mlir::LogicalResult mlir::RangeOp::verify() { + if (!min() || !min()->getType().isa()) + return emitOpError("first operand should be of type index"); + if (!max() || !max()->getType().isa()) + return emitOpError("second operand should be of type index"); + if (!step() || !step()->getType().isa()) + return emitOpError("third operand should be of type index"); + return mlir::success(); +} + +// A RangeOp prints as: +// +// ```{.mlir} +// linalg.range %0:%1:%2 : !linalg.range +// ``` +void mlir::RangeOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step() + << " : " << getType(); +} + +bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) { + SmallVector rangeInfo(3); + RangeType type; + auto affineIntTy = parser->getBuilder().getIndexType(); + return parser->parseOperand(rangeInfo[0]) || parser->parseColon() || + parser->parseOperand(rangeInfo[1]) || parser->parseColon() || + parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) || + parser->resolveOperands(rangeInfo, affineIntTy, result->operands) || + parser->addTypeToList(type, result->types); +} diff --git a/mlir/lib/Linalg/LinalgRegistration.cpp b/mlir/lib/Linalg/LinalgRegistration.cpp new file mode 100644 index 000000000000..3637037354fc --- /dev/null +++ b/mlir/lib/Linalg/LinalgRegistration.cpp @@ -0,0 +1,24 @@ +//===- LinalgRegistration.cpp - Register the linalg dialect statically ----===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Linalg/LinalgOps.h" +#include "mlir/Linalg/LinalgTypes.h" + +using namespace mlir; + +// Static initialization for LinalgOps dialect registration. +static DialectRegistration LinalgOps; diff --git a/mlir/lib/Linalg/LinalgTypes.cpp b/mlir/lib/Linalg/LinalgTypes.cpp new file mode 100644 index 000000000000..7aabfd2fc38a --- /dev/null +++ b/mlir/lib/Linalg/LinalgTypes.cpp @@ -0,0 +1,53 @@ +//===- Dialect.cpp - Implementation of the linalg dialect and types -------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the Linalg dialect types and dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Linalg/LinalgTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Linalg/LinalgOps.h" +#include "mlir/Support/LLVM.h" + +using namespace mlir; + +mlir::LinalgDialect::LinalgDialect(MLIRContext *context) + : Dialect("linalg", context) { + addTypes(); + addOperations(); +} + +Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const { + MLIRContext *context = getContext(); + if (spec == "range") + return RangeType::get(getContext()); + return (context->emitError(loc, "unknown Linalg type: " + spec), Type()); +} + +/// RangeType prints as just "range". +static void print(RangeType rt, raw_ostream &os) { os << "range"; } + +void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const { + switch (type.getKind()) { + default: + llvm_unreachable("Unhandled Linalg type"); + case LinalgTypes::Range: + print(type.cast(), os); + break; + } +} diff --git a/mlir/test/Linalg/roundtrip.mlir b/mlir/test/Linalg/roundtrip.mlir new file mode 100644 index 000000000000..f98558a779b4 --- /dev/null +++ b/mlir/test/Linalg/roundtrip.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-opt %s -verify | mlir-opt | FileCheck %s + +func @range(%arg0: index, %arg1: index, %arg2: index) { + %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range + return +} +// CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) { +// CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range \ No newline at end of file diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index a68b30de7c56..f6e9ac7e05b9 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -3,6 +3,7 @@ set(LIBS MLIRAnalysis MLIREDSC MLIRFxpMathOps + MLIRLinalg MLIRLLVMIR MLIRParser MLIRPass