[mlir] reorgnize Linalg TransformOps files. NFC

Mirror the separation between LinalgTransformOps and LinalgMatchOps in
headers. Create a separate pair of files for the extension.

Depends on D148017

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D148075
This commit is contained in:
Alex Zinenko 2023-04-12 08:10:24 +00:00
parent dcfdb963d4
commit 135d29c3f5
9 changed files with 131 additions and 71 deletions

View File

@ -0,0 +1,15 @@
//===- DialectExtension.h - Linalg transform dialect extension --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
namespace mlir {
class DialectRegistry;
namespace linalg {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace linalg
} // namespace mlir

View File

@ -0,0 +1,48 @@
//===- LinalgMatchOps.h - Linalg transform matcher ops ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGMATCHOPS_H
#define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGMATCHOPS_H
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
namespace mlir {
namespace transform {
namespace detail {
LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op,
Value structuredOpHandle);
} // namespace detail
template <typename OpTy>
class StructuredOpPredicateOpTrait
: public OpTrait::TraitBase<OpTy, StructuredOpPredicateOpTrait> {
public:
static LogicalResult verifyTrait(Operation *op) {
static_assert(
OpTy::template hasTrait<SingleOpMatcherOpTrait>(),
"StructuredOpPredicateOpTrait requires SingleOpMatcherOpTrait");
return detail::verifyStructuredOpPredicateOpTrait(
op, cast<OpTy>(op).getOperandHandle());
}
};
} // namespace transform
} // namespace mlir
//===----------------------------------------------------------------------===//
// Linalg Matcher Operations
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h.inc"
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGMATCHOPS_H

View File

@ -11,7 +11,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
@ -56,30 +55,7 @@ DiagnosedSilenceableFailure tileToForallOpImpl(
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
namespace detail {
LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op,
Value structuredOpHandle);
} // namespace detail
template <typename OpTy>
class StructuredOpPredicateOpTrait
: public OpTrait::TraitBase<OpTy, StructuredOpPredicateOpTrait> {
public:
static LogicalResult verifyTrait(Operation *op) {
static_assert(
OpTy::template hasTrait<SingleOpMatcherOpTrait>(),
"StructuredOpPredicateOpTrait requires SingleOpMatcherOpTrait");
return detail::verifyStructuredOpPredicateOpTrait(
op, cast<OpTy>(op).getOperandHandle());
}
};
} // namespace transform
namespace linalg {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace linalg
} // namespace mlir
//===----------------------------------------------------------------------===//
@ -91,7 +67,4 @@ void registerTransformDialectExtension(DialectRegistry &registry);
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h.inc"
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H

View File

@ -42,7 +42,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"

View File

@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRLinalgTransformOps
DialectExtension.cpp
LinalgMatchOps.cpp
LinalgTransformOps.cpp

View File

@ -0,0 +1,59 @@
//===- DialectExtension.cpp - Linalg transform dialect extension ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;
namespace {
/// Registers new ops and declares PDL as dependent dialect since the
/// additional ops are using PDL types for operands and results.
class LinalgTransformDialectExtension
: public transform::TransformDialectExtension<
LinalgTransformDialectExtension> {
public:
using Base::Base;
void init() {
declareDependentDialect<pdl::PDLDialect>();
declareDependentDialect<linalg::LinalgDialect>();
declareGeneratedDialect<affine::AffineDialect>();
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<scf::SCFDialect>();
declareGeneratedDialect<vector::VectorDialect>();
declareGeneratedDialect<gpu::GPUDialect>();
declareGeneratedDialect<tensor::TensorDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
>();
}
};
} // namespace
void mlir::linalg::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<LinalgTransformDialectExtension>();
}

View File

@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/FunctionImplementation.h"

View File

@ -9,6 +9,7 @@
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@ -3113,46 +3114,7 @@ DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
return diag;
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
namespace {
/// Registers new ops and declares PDL as dependent dialect since the
/// additional ops are using PDL types for operands and results.
class LinalgTransformDialectExtension
: public transform::TransformDialectExtension<
LinalgTransformDialectExtension> {
public:
using Base::Base;
void init() {
declareDependentDialect<pdl::PDLDialect>();
declareDependentDialect<LinalgDialect>();
declareGeneratedDialect<affine::AffineDialect>();
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<scf::SCFDialect>();
declareGeneratedDialect<vector::VectorDialect>();
declareGeneratedDialect<gpu::GPUDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
>();
}
};
} // namespace
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
void mlir::linalg::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<LinalgTransformDialectExtension>();
}

View File

@ -8787,9 +8787,9 @@ cc_library(
srcs = glob([
"lib/Dialect/Linalg/TransformOps/*.cpp",
]),
hdrs = [
"include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h",
],
hdrs = glob([
"include/mlir/Dialect/Linalg/TransformOps/*.h",
]),
includes = ["include"],
deps = [
":AffineDialect",
@ -8807,6 +8807,7 @@ cc_library(
":LinalgTransforms",
":LinalgUtils",
":PDLDialect",
":SCFDialect",
":SCFTransforms",
":Support",
":TensorDialect",
@ -8815,6 +8816,7 @@ cc_library(
":TransformDialect",
":TransformDialectUtils",
":TransformUtils",
":VectorDialect",
":VectorTransforms",
"//llvm:Support",
],