[mlir][spirv] NFC: Move GLSL canonicalization pass to Transforms/

This is a pass that can be used by downstream consumers directly
to avoid the boilerplate to wrap around the `populate*Patterns`.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D121222
This commit is contained in:
Lei Zhang 2022-03-08 13:45:19 -05:00
parent 79d08e398c
commit 86fe16b67d
9 changed files with 52 additions and 48 deletions

View File

@ -23,8 +23,11 @@
namespace mlir {
namespace spirv {
void populateSPIRVGLSLCanonicalizationPatterns(
mlir::RewritePatternSet &results);
/// Populates patterns to run canoncalization that involves GLSL ops.
///
/// These patterns cannot be run in default canonicalization because GLSL ops
/// aren't always available. So they should be involed specifically when needed.
void populateSPIRVGLSLCanonicalizationPatterns(RewritePatternSet &results);
} // namespace spirv
} // namespace mlir

View File

@ -24,6 +24,11 @@ class ModuleOp;
// Passes
//===----------------------------------------------------------------------===//
/// Creates a pass to run canoncalization patterns that involve GLSL ops.
/// These patterns cannot be run in default canonicalization because GLSL ops
/// aren't always available. So they should be involed specifically when needed.
std::unique_ptr<OperationPass<>> createCanonicalizeGLSLPass();
/// Creates a module pass that converts composite types used by objects in the
/// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage
/// classes with layout information.

View File

@ -17,6 +17,11 @@ def SPIRVCompositeTypeLayout
let constructor = "mlir::spirv::createDecorateSPIRVCompositeTypeLayoutPass()";
}
def SPIRVCanonicalizeGLSL : Pass<"spirv-canonicalize-glsl", ""> {
let summary = "Run canonicalization involving GLSL ops";
let constructor = "mlir::spirv::createCanonicalizeGLSLPass()";
}
def SPIRVLowerABIAttributes : Pass<"spirv-lower-abi-attrs", "spirv::ModuleOp"> {
let summary = "Decorate SPIR-V composite type with layout info";
let constructor = "mlir::spirv::createLowerABIAttributesPass()";

View File

@ -1,4 +1,5 @@
set(LLVM_OPTIONAL_SOURCES
CanonicalizeGLSLPass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
@ -19,6 +20,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
)
add_mlir_dialect_library(MLIRSPIRVTransforms
CanonicalizeGLSLPass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp

View File

@ -0,0 +1,34 @@
//===- CanonicalizeGLSLPass.cpp - GLSL Related Canonicalization Pass ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
class CanonicalizeGLSLPass final
: public SPIRVCanonicalizeGLSLBase<CanonicalizeGLSLPass> {
public:
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<>> spirv::createCanonicalizeGLSLPass() {
return std::make_unique<CanonicalizeGLSLPass>();
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -test-spirv-glsl-canonicalization -split-input-file -verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt -split-input-file -spirv-canonicalize-glsl %s | FileCheck %s
// CHECK: func @clamp_fordlessthan(%[[INPUT:.*]]: f32)
func @clamp_fordlessthan(%input: f32) -> f32 {

View File

@ -2,7 +2,6 @@
add_mlir_library(MLIRSPIRVTestPasses
TestAvailability.cpp
TestEntryPointAbi.cpp
TestGLSLCanonicalization.cpp
TestModuleCombiner.cpp
EXCLUDE_FROM_LIBMLIR

View File

@ -1,42 +0,0 @@
//===- TestGLSLCanonicalization.cpp - Pass to test GLSL-specific pattterns ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
class TestGLSLCanonicalizationPass
: public PassWrapper<TestGLSLCanonicalizationPass,
OperationPass<mlir::ModuleOp>> {
public:
TestGLSLCanonicalizationPass() = default;
TestGLSLCanonicalizationPass(const TestGLSLCanonicalizationPass &) {}
void runOnOperation() override;
StringRef getArgument() const final {
return "test-spirv-glsl-canonicalization";
}
StringRef getDescription() const final {
return "Tests SPIR-V canonicalization patterns for GLSL extension.";
}
};
} // namespace
void TestGLSLCanonicalizationPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
namespace mlir {
void registerTestSpirvGLSLCanonicalizationPass() {
PassRegistration<TestGLSLCanonicalizationPass>();
}
} // namespace mlir

View File

@ -49,7 +49,6 @@ void registerTestPrintInvalidPass();
void registerTestPrintNestingPass();
void registerTestReducer();
void registerTestSpirvEntryPointABIPass();
void registerTestSpirvGLSLCanonicalizationPass();
void registerTestSpirvModuleCombinerPass();
void registerTestTraitsPass();
void registerTosaTestQuantUtilAPIPass();
@ -137,7 +136,6 @@ void registerTestPasses() {
registerTestPrintNestingPass();
registerTestReducer();
registerTestSpirvEntryPointABIPass();
registerTestSpirvGLSLCanonicalizationPass();
registerTestSpirvModuleCombinerPass();
registerTestTraitsPass();
registerVectorizerTestPass();