mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-25 13:05:04 +00:00
Start a Linalg dialect
This CL starts implementing a Linalg dialect with the objective of supporting optimizing compilation of loops and library calls for a subset of common linear algebra operations. This CL starts by simply adding a linalg.range type and an operation with the proper roundtripping test. -- PiperOrigin-RevId: 244189468
This commit is contained in:
parent
05dfb1c7e0
commit
8370cc7492
@ -23,9 +23,9 @@
|
||||
namespace linalg {
|
||||
|
||||
enum LinalgTypes {
|
||||
Range = mlir::Type::FIRST_LINALG_TYPE,
|
||||
Range = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
|
||||
View,
|
||||
LAST_USED_LINALG_TYPE = View,
|
||||
FIRST_PRIVATE_EXPERIMENTAL_0_TYPE = View,
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
53
mlir/include/mlir/Linalg/LinalgOps.h
Normal file
53
mlir/include/mlir/Linalg/LinalgOps.h
Normal file
@ -0,0 +1,53 @@
|
||||
//===- LinalgOps.h - Linalg Operations --------------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef MLIR_LINALG_LINALGOPS_H_
|
||||
#define MLIR_LINALG_LINALGOPS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// A RangeOp is used to create a value of RangeType from 3 values of type index
|
||||
/// that represent the min, max and step values of the range.
|
||||
class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
|
||||
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Hooks to customize the behavior of this op.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
static llvm::StringRef getOperationName() { return "linalg.range"; }
|
||||
static void build(Builder *b, OperationState *result, Value *min, Value *max,
|
||||
Value *step);
|
||||
LogicalResult verify();
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific functionality.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
Value *min() { return getOperand(0); }
|
||||
Value *max() { return getOperand(1); }
|
||||
Value *step() { return getOperand(2); }
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_LINALG_LINALGOPS_H_
|
59
mlir/include/mlir/Linalg/LinalgTypes.h
Normal file
59
mlir/include/mlir/Linalg/LinalgTypes.h
Normal file
@ -0,0 +1,59 @@
|
||||
//===- LinalgTypes.h - Linalg Types ---------------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef MLIR_LINALG_LINALGTYPES_H_
|
||||
#define MLIR_LINALG_LINALGTYPES_H_
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
|
||||
enum LinalgTypes {
|
||||
Range = Type::FIRST_LINALG_TYPE,
|
||||
LAST_USED_LINALG_TYPE = Range,
|
||||
};
|
||||
|
||||
class LinalgDialect : public Dialect {
|
||||
public:
|
||||
explicit LinalgDialect(MLIRContext *context);
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
Type parseType(llvm::StringRef spec, Location loc) const override;
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
void printType(Type type, llvm::raw_ostream &os) const override;
|
||||
};
|
||||
|
||||
/// A RangeType represents a minimal range abstraction (min, max, step).
|
||||
class RangeType : public Type::TypeBase<RangeType, Type> {
|
||||
public:
|
||||
// Used for generic hooks in TypeBase.
|
||||
using Base::Base;
|
||||
/// Construction hook.
|
||||
static RangeType get(MLIRContext *context) {
|
||||
/// Custom, uniq'ed construction in the MLIRContext.
|
||||
return Base::get(context, LinalgTypes::Range);
|
||||
}
|
||||
/// Used to implement llvm-style cast.
|
||||
static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_LINALG_LINALGTYPES_H_
|
@ -6,6 +6,7 @@ add_subdirectory(ExecutionEngine)
|
||||
add_subdirectory(FxpMathOps)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(LLVMIR)
|
||||
add_subdirectory(Linalg)
|
||||
add_subdirectory(Parser)
|
||||
add_subdirectory(Pass)
|
||||
add_subdirectory(Quantization)
|
||||
|
8
mlir/lib/Linalg/CMakeLists.txt
Normal file
8
mlir/lib/Linalg/CMakeLists.txt
Normal file
@ -0,0 +1,8 @@
|
||||
add_llvm_library(MLIRLinalg
|
||||
LinalgOps.cpp
|
||||
LinalgRegistration.cpp
|
||||
LinalgTypes.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Linalg
|
||||
)
|
67
mlir/lib/Linalg/LinalgOps.cpp
Normal file
67
mlir/lib/Linalg/LinalgOps.cpp
Normal file
@ -0,0 +1,67 @@
|
||||
//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a the Linalg operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Linalg/LinalgOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Linalg/LinalgTypes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
|
||||
Value *max, Value *step) {
|
||||
result->addOperands({min, max, step});
|
||||
result->addTypes({RangeType::get(b->getContext())});
|
||||
}
|
||||
|
||||
// Verification is simply that a RangeOp takes 3 index ssa-value.
|
||||
mlir::LogicalResult mlir::RangeOp::verify() {
|
||||
if (!min() || !min()->getType().isa<IndexType>())
|
||||
return emitOpError("first operand should be of type index");
|
||||
if (!max() || !max()->getType().isa<IndexType>())
|
||||
return emitOpError("second operand should be of type index");
|
||||
if (!step() || !step()->getType().isa<IndexType>())
|
||||
return emitOpError("third operand should be of type index");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// A RangeOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.range %0:%1:%2 : !linalg.range
|
||||
// ```
|
||||
void mlir::RangeOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
|
||||
<< " : " << getType();
|
||||
}
|
||||
|
||||
bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
|
||||
RangeType type;
|
||||
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||
return parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
|
||||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
|
||||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
|
||||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
|
||||
parser->addTypeToList(type, result->types);
|
||||
}
|
24
mlir/lib/Linalg/LinalgRegistration.cpp
Normal file
24
mlir/lib/Linalg/LinalgRegistration.cpp
Normal file
@ -0,0 +1,24 @@
|
||||
//===- LinalgRegistration.cpp - Register the linalg dialect statically ----===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/Linalg/LinalgOps.h"
|
||||
#include "mlir/Linalg/LinalgTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// Static initialization for LinalgOps dialect registration.
|
||||
static DialectRegistration<LinalgDialect> LinalgOps;
|
53
mlir/lib/Linalg/LinalgTypes.cpp
Normal file
53
mlir/lib/Linalg/LinalgTypes.cpp
Normal file
@ -0,0 +1,53 @@
|
||||
//===- Dialect.cpp - Implementation of the linalg dialect and types -------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the Linalg dialect types and dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Linalg/LinalgTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Linalg/LinalgOps.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
|
||||
: Dialect("linalg", context) {
|
||||
addTypes<RangeType>();
|
||||
addOperations<RangeOp>();
|
||||
}
|
||||
|
||||
Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const {
|
||||
MLIRContext *context = getContext();
|
||||
if (spec == "range")
|
||||
return RangeType::get(getContext());
|
||||
return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
|
||||
}
|
||||
|
||||
/// RangeType prints as just "range".
|
||||
static void print(RangeType rt, raw_ostream &os) { os << "range"; }
|
||||
|
||||
void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
|
||||
switch (type.getKind()) {
|
||||
default:
|
||||
llvm_unreachable("Unhandled Linalg type");
|
||||
case LinalgTypes::Range:
|
||||
print(type.cast<RangeType>(), os);
|
||||
break;
|
||||
}
|
||||
}
|
8
mlir/test/Linalg/roundtrip.mlir
Normal file
8
mlir/test/Linalg/roundtrip.mlir
Normal file
@ -0,0 +1,8 @@
|
||||
// RUN: mlir-opt %s -verify | mlir-opt | FileCheck %s
|
||||
|
||||
func @range(%arg0: index, %arg1: index, %arg2: index) {
|
||||
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) {
|
||||
// CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
@ -3,6 +3,7 @@ set(LIBS
|
||||
MLIRAnalysis
|
||||
MLIREDSC
|
||||
MLIRFxpMathOps
|
||||
MLIRLinalg
|
||||
MLIRLLVMIR
|
||||
MLIRParser
|
||||
MLIRPass
|
||||
|
Loading…
x
Reference in New Issue
Block a user