[mlir][mesh] Add lowering of process multi-index op (#77490)

* Rename mesh.process_index -> mesh.process_multi_index.
* Add mesh.process_linear_index op.
* Add lowering of mesh.process_multi_index into an expression using
mesh.process_linear_index, mesh.cluster_shape and
affine.delinearize_index.

This is useful to lower mesh ops and prepare them for further lowering
where the runtime may have only the linear index of a device/process.
For example in MPI we have a rank (linear index) in a communicator.
This commit is contained in:
Boian Petkantchin 2024-01-10 07:01:16 -08:00 committed by GitHub
parent fef2fc3400
commit 79aa776267
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 291 additions and 38 deletions

View File

@ -96,7 +96,8 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
let hasVerifier = 1;
}
def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Get the shape of the cluster.";
let arguments = (ins
FlatSymbolRefAttr:$mesh,
@ -209,11 +210,15 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Get the index of current device along specified mesh axis.";
def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Get the multi index of current device along specified mesh axes.";
let description = [{
It is used in the SPMD format of IR.
The `axes` mush be non-negative and less than the total number of mesh axes.
If the axes are empty then get the index along all axes.
}];
let arguments = (ins
FlatSymbolRefAttr:$mesh,
@ -232,6 +237,27 @@ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMeth
];
}
def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Get the linear index of the current device.";
let description = [{
Example:
```
%idx = mesh.process_linear_index on @mesh : index
```
if `@mesh` has shape `(10, 20, 30)`, a device with multi
index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
}];
let arguments = (ins FlatSymbolRefAttr:$mesh);
let results = (outs Index:$result);
let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
let builders = [
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>
];
}
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,26 @@
//===- Transforms.h - Mesh Transforms ---------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
namespace mlir {
class RewritePatternSet;
class SymbolTableCollection;
class DialectRegistry;
namespace mesh {
void processMultiIndexOpLoweringPopulatePatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
} // namespace mesh
} // namespace mlir
#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H

View File

@ -250,7 +250,8 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
ClusterOp mesh) {
build(odsBuilder, odsState,
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
mesh.getSymName(), MeshAxesAttr());
mesh.getSymName(),
MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
}
void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
@ -325,11 +326,11 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
}
//===----------------------------------------------------------------------===//
// mesh.process_index op
// mesh.process_multi_index op
//===----------------------------------------------------------------------===//
LogicalResult
ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
@ -348,20 +349,38 @@ ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
ClusterOp mesh) {
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
ClusterOp mesh) {
build(odsBuilder, odsState,
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
mesh.getSymName(), MeshAxesAttr());
}
void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef mesh, ArrayRef<MeshAxis> axes) {
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef mesh, ArrayRef<MeshAxis> axes) {
build(odsBuilder, odsState,
SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
//===----------------------------------------------------------------------===//
// mesh.process_linear_index op
//===----------------------------------------------------------------------===//
LogicalResult
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
}
return success();
}
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState, ClusterOp mesh) {
build(odsBuilder, odsState, mesh.getSymName());
}
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//

View File

@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
Simplifications.cpp
ShardingPropagation.cpp
Spmdization.cpp
Transforms.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
@ -11,6 +12,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
MLIRShardingInterface
LINK_LIBS PUBLIC
MLIRAffineDialect
MLIRArithDialect
MLIRControlFlowDialect
MLIRFuncDialect

View File

@ -1,4 +1,4 @@
//===- Patterns.cpp - Mesh Patterns -----------------------------*- C++ -*-===//
//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

View File

@ -206,8 +206,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
Value processIndexAlongAxis =
builder
.create<ProcessIndexOp>(mesh.getSymName(),
SmallVector<MeshAxis>({splitMeshAxis}))
.create<ProcessMultiIndexOp>(mesh.getSymName(),
SmallVector<MeshAxis>({splitMeshAxis}))
.getResult()[0];
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(

View File

@ -0,0 +1,84 @@
//===- Transforms.cpp ---------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <iterator>
#include <numeric>
namespace mlir::mesh {
namespace {
/// Lower `mesh.process_multi_index` into expression using
/// `mesh.process_linear_index` and `mesh.cluster_shape`.
struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
template <typename... OpRewritePatternArgs>
ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
OpRewritePatternArgs &&...opRewritePatternArgs)
: OpRewritePattern(
std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
symbolTableCollection(symbolTableCollection) {}
LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
PatternRewriter &rewriter) const override {
ClusterOp mesh =
symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
op.getOperation(), op.getMeshAttr());
if (!mesh) {
return failure();
}
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
builder.setInsertionPointAfter(op.getOperation());
Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
ValueRange meshShape = builder.create<ClusterShapeOp>(mesh).getResults();
SmallVector<Value> completeMultiIndex =
builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
.getMultiIndex();
SmallVector<Value> multiIndex;
ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
SmallVector<MeshAxis> opAxesIota;
if (opMeshAxes.empty()) {
opAxesIota.resize(mesh.getRank());
std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
opMeshAxes = opAxesIota;
}
llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
[&completeMultiIndex](MeshAxis meshAxis) {
return completeMultiIndex[meshAxis];
});
rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
return success();
}
private:
SymbolTableCollection &symbolTableCollection;
};
} // namespace
void processMultiIndexOpLoweringPopulatePatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
patterns.getContext());
}
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry) {
registry.insert<affine::AffineDialect, mesh::MeshDialect>();
}
} // namespace mlir::mesh

View File

@ -128,9 +128,9 @@ func.func @cluster_shape_invalid_mesh_name() -> (index) {
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
// expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
%0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
%0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index
return %0#0, %0#1 : index, index
}
@ -138,9 +138,9 @@ func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
%0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
%0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
@ -148,9 +148,9 @@ func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @process_index_wrong_number_of_results() -> (index, index) {
func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
%0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
%0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index
return %0#0, %0#1 : index, index
}
@ -158,18 +158,26 @@ func.func @process_index_wrong_number_of_results() -> (index, index) {
mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
%0:2 = mesh.process_index on @mesh0 : index, index
%0:2 = mesh.process_multi_index on @mesh0 : index, index
return %0#0, %0#1 : index, index
}
// -----
func.func @process_index_invalid_mesh_name() -> (index) {
func.func @process_multi_index_invalid_mesh_name() -> (index) {
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
%0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
return %0#0 : index
%0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index
return %0 : index
}
// -----
func.func @process_linear_index_invalid_mesh_name() -> (index) {
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
%0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index
return %0 : index
}
// -----

View File

@ -156,30 +156,37 @@ func.func @cluster_shape_empty_axes() -> (index, index, index) {
return %0#0, %0#1, %0#2 : index, index, index
}
// CHECK-LABEL: func @process_index
func.func @process_index() -> (index, index) {
// CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
%0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
// CHECK-LABEL: func @process_multi_index
func.func @process_multi_index() -> (index, index) {
// CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
%0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
return %0#0, %0#1 : index, index
}
// CHECK-LABEL: func @process_index_default_axes
func.func @process_index_default_axes() -> (index, index, index) {
// CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
%0:3 = mesh.process_index on @mesh0 : index, index, index
// CHECK-LABEL: func @process_multi_index_default_axes
func.func @process_multi_index_default_axes() -> (index, index, index) {
// CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
%0:3 = mesh.process_multi_index on @mesh0 : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// CHECK-LABEL: func @process_index_empty_axes
func.func @process_index_empty_axes() -> (index, index, index) {
// CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
%0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
// CHECK-LABEL: func @process_multi_index_empty_axes
func.func @process_multi_index_empty_axes() -> (index, index, index) {
// CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
%0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
// CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
%0 = mesh.process_linear_index on @mesh0 : index
// CHECK: return %[[RES]] : index
return %0 : index
}
// CHECK-LABEL: func @all_reduce
func.func @all_reduce(

View File

@ -0,0 +1,23 @@
// RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s
mesh.cluster @mesh2d(rank = 2)
// CHECK-LABEL: func.func @multi_index_2d_mesh
func.func @multi_index_2d_mesh() -> (index, index) {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0:2 = mesh.process_multi_index on @mesh2d : index, index
// CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
return %0#0, %0#1 : index, index
}
// CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis
func.func @multi_index_2d_mesh_single_inner_axis() -> index {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0 = mesh.process_multi_index on @mesh2d axes = [0] : index
// CHECK: return %[[MULTI_IDX]]#0 : index
return %0 : index
}

View File

@ -21,7 +21,7 @@ func.func @split_replicated_tensor_axis(
) -> tensor<3x14xf32> {
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
// CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d axes = [0] : index
// CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
// CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
// CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
@ -43,7 +43,7 @@ func.func @split_replicated_tensor_axis_dynamic(
) -> tensor<?x3x?xf32> {
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index
// CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d_dynamic axes = [0] : index
// CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index
// CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index

View File

@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMeshTest
TestProcessMultiIndexOpLowering.cpp
TestReshardingSpmdization.cpp
TestSimplifications.cpp

View File

@ -0,0 +1,55 @@
//===- TestProcessMultiIndexOpLowering.cpp --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
struct TestMultiIndexOpLoweringPass
: public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
void runOnOperation() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mesh::MeshDialect>();
mesh::processMultiIndexOpLoweringRegisterDialects(registry);
}
StringRef getArgument() const final {
return "test-mesh-process-multi-index-op-lowering";
}
StringRef getDescription() const final {
return "Test lowering of mesh.process_multi_index op.";
}
};
} // namespace
void TestMultiIndexOpLoweringPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTableCollection;
mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
symbolTableCollection);
LogicalResult status =
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
}
namespace mlir {
namespace test {
void registerTestMultiIndexOpLoweringPass() {
PassRegistration<TestMultiIndexOpLoweringPass>();
}
} // namespace test
} // namespace mlir

View File

@ -120,6 +120,7 @@ void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestMeshSimplificationsPass();
void registerTestMeshReshardingSpmdizationPass();
void registerTestMultiIndexOpLoweringPass();
void registerTestNextAccessPass();
void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
@ -240,6 +241,7 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
mlir::test::registerTestMultiIndexOpLoweringPass();
mlir::test::registerTestMeshSimplificationsPass();
mlir::test::registerTestMeshReshardingSpmdizationPass();
mlir::test::registerTestNextAccessPass();