mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-23 05:40:09 +00:00
[mlir][mesh] Add lowering of process multi-index op (#77490)
* Rename mesh.process_index -> mesh.process_multi_index. * Add mesh.process_linear_index op. * Add lowering of mesh.process_multi_index into an expression using mesh.process_linear_index, mesh.cluster_shape and affine.delinearize_index. This is useful to lower mesh ops and prepare them for further lowering where the runtime may have only the linear index of a device/process. For example in MPI we have a rank (linear index) in a communicator.
This commit is contained in:
parent
fef2fc3400
commit
79aa776267
@ -96,7 +96,8 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
|
||||
def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
|
||||
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
|
||||
let summary = "Get the shape of the cluster.";
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$mesh,
|
||||
@ -209,11 +210,15 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
|
||||
let summary = "Get the index of current device along specified mesh axis.";
|
||||
def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
|
||||
Pure,
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>
|
||||
]> {
|
||||
let summary = "Get the multi index of current device along specified mesh axes.";
|
||||
let description = [{
|
||||
It is used in the SPMD format of IR.
|
||||
The `axes` mush be non-negative and less than the total number of mesh axes.
|
||||
If the axes are empty then get the index along all axes.
|
||||
}];
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$mesh,
|
||||
@ -232,6 +237,27 @@ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMeth
|
||||
];
|
||||
}
|
||||
|
||||
def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
|
||||
Pure,
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>
|
||||
]> {
|
||||
let summary = "Get the linear index of the current device.";
|
||||
let description = [{
|
||||
Example:
|
||||
```
|
||||
%idx = mesh.process_linear_index on @mesh : index
|
||||
```
|
||||
if `@mesh` has shape `(10, 20, 30)`, a device with multi
|
||||
index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
|
||||
}];
|
||||
let arguments = (ins FlatSymbolRefAttr:$mesh);
|
||||
let results = (outs Index:$result);
|
||||
let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
|
||||
let builders = [
|
||||
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// collective communication ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
26
mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
Normal file
26
mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
Normal file
@ -0,0 +1,26 @@
|
||||
//===- Transforms.h - Mesh Transforms ---------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
|
||||
#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
|
||||
|
||||
namespace mlir {
|
||||
class RewritePatternSet;
|
||||
class SymbolTableCollection;
|
||||
class DialectRegistry;
|
||||
namespace mesh {
|
||||
|
||||
void processMultiIndexOpLoweringPopulatePatterns(
|
||||
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
|
||||
|
||||
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry ®istry);
|
||||
|
||||
} // namespace mesh
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
|
@ -250,7 +250,8 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
|
||||
ClusterOp mesh) {
|
||||
build(odsBuilder, odsState,
|
||||
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
|
||||
mesh.getSymName(), MeshAxesAttr());
|
||||
mesh.getSymName(),
|
||||
MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
|
||||
}
|
||||
|
||||
void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
|
||||
@ -325,11 +326,11 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// mesh.process_index op
|
||||
// mesh.process_multi_index op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
|
||||
if (failed(mesh)) {
|
||||
return failure();
|
||||
@ -348,20 +349,38 @@ ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
|
||||
ClusterOp mesh) {
|
||||
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
|
||||
ClusterOp mesh) {
|
||||
build(odsBuilder, odsState,
|
||||
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
|
||||
mesh.getSymName(), MeshAxesAttr());
|
||||
}
|
||||
|
||||
void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
|
||||
StringRef mesh, ArrayRef<MeshAxis> axes) {
|
||||
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
|
||||
StringRef mesh, ArrayRef<MeshAxis> axes) {
|
||||
build(odsBuilder, odsState,
|
||||
SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
|
||||
MeshAxesAttr::get(odsBuilder.getContext(), axes));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// mesh.process_linear_index op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
|
||||
if (failed(mesh)) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
|
||||
OperationState &odsState, ClusterOp mesh) {
|
||||
build(odsBuilder, odsState, mesh.getSymName());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// collective communication ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
|
||||
Simplifications.cpp
|
||||
ShardingPropagation.cpp
|
||||
Spmdization.cpp
|
||||
Transforms.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
|
||||
@ -11,6 +12,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
|
||||
MLIRShardingInterface
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAffineDialect
|
||||
MLIRArithDialect
|
||||
MLIRControlFlowDialect
|
||||
MLIRFuncDialect
|
||||
|
@ -1,4 +1,4 @@
|
||||
//===- Patterns.cpp - Mesh Patterns -----------------------------*- C++ -*-===//
|
||||
//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -206,8 +206,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
|
||||
|
||||
Value processIndexAlongAxis =
|
||||
builder
|
||||
.create<ProcessIndexOp>(mesh.getSymName(),
|
||||
SmallVector<MeshAxis>({splitMeshAxis}))
|
||||
.create<ProcessMultiIndexOp>(mesh.getSymName(),
|
||||
SmallVector<MeshAxis>({splitMeshAxis}))
|
||||
.getResult()[0];
|
||||
|
||||
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
|
||||
|
84
mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
Normal file
84
mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
Normal file
@ -0,0 +1,84 @@
|
||||
//===- Transforms.cpp ---------------------------------------------- C++ --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
|
||||
namespace mlir::mesh {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Lower `mesh.process_multi_index` into expression using
|
||||
/// `mesh.process_linear_index` and `mesh.cluster_shape`.
|
||||
struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
|
||||
template <typename... OpRewritePatternArgs>
|
||||
ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
|
||||
OpRewritePatternArgs &&...opRewritePatternArgs)
|
||||
: OpRewritePattern(
|
||||
std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
|
||||
symbolTableCollection(symbolTableCollection) {}
|
||||
|
||||
LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
ClusterOp mesh =
|
||||
symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
|
||||
op.getOperation(), op.getMeshAttr());
|
||||
if (!mesh) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
|
||||
builder.setInsertionPointAfter(op.getOperation());
|
||||
Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
|
||||
ValueRange meshShape = builder.create<ClusterShapeOp>(mesh).getResults();
|
||||
SmallVector<Value> completeMultiIndex =
|
||||
builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
|
||||
.getMultiIndex();
|
||||
SmallVector<Value> multiIndex;
|
||||
ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
|
||||
SmallVector<MeshAxis> opAxesIota;
|
||||
if (opMeshAxes.empty()) {
|
||||
opAxesIota.resize(mesh.getRank());
|
||||
std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
|
||||
opMeshAxes = opAxesIota;
|
||||
}
|
||||
llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
|
||||
[&completeMultiIndex](MeshAxis meshAxis) {
|
||||
return completeMultiIndex[meshAxis];
|
||||
});
|
||||
rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
SymbolTableCollection &symbolTableCollection;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void processMultiIndexOpLoweringPopulatePatterns(
|
||||
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
|
||||
patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry ®istry) {
|
||||
registry.insert<affine::AffineDialect, mesh::MeshDialect>();
|
||||
}
|
||||
|
||||
} // namespace mlir::mesh
|
@ -128,9 +128,9 @@ func.func @cluster_shape_invalid_mesh_name() -> (index) {
|
||||
|
||||
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
|
||||
|
||||
func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
|
||||
func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
|
||||
// expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
|
||||
%0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
|
||||
%0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index
|
||||
return %0#0, %0#1 : index, index
|
||||
}
|
||||
|
||||
@ -138,9 +138,9 @@ func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
|
||||
|
||||
mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
|
||||
|
||||
func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
|
||||
func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
|
||||
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
|
||||
%0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
|
||||
%0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index
|
||||
return %0#0, %0#1, %0#2 : index, index, index
|
||||
}
|
||||
|
||||
@ -148,9 +148,9 @@ func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
|
||||
|
||||
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
|
||||
|
||||
func.func @process_index_wrong_number_of_results() -> (index, index) {
|
||||
func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
|
||||
// expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
|
||||
%0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
|
||||
%0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index
|
||||
return %0#0, %0#1 : index, index
|
||||
}
|
||||
|
||||
@ -158,18 +158,26 @@ func.func @process_index_wrong_number_of_results() -> (index, index) {
|
||||
|
||||
mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
|
||||
|
||||
func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
|
||||
func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
|
||||
// expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
|
||||
%0:2 = mesh.process_index on @mesh0 : index, index
|
||||
%0:2 = mesh.process_multi_index on @mesh0 : index, index
|
||||
return %0#0, %0#1 : index, index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @process_index_invalid_mesh_name() -> (index) {
|
||||
func.func @process_multi_index_invalid_mesh_name() -> (index) {
|
||||
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
|
||||
%0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
|
||||
return %0#0 : index
|
||||
%0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index
|
||||
return %0 : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @process_linear_index_invalid_mesh_name() -> (index) {
|
||||
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
|
||||
%0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index
|
||||
return %0 : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -156,30 +156,37 @@ func.func @cluster_shape_empty_axes() -> (index, index, index) {
|
||||
return %0#0, %0#1, %0#2 : index, index, index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @process_index
|
||||
func.func @process_index() -> (index, index) {
|
||||
// CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
|
||||
%0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
|
||||
// CHECK-LABEL: func @process_multi_index
|
||||
func.func @process_multi_index() -> (index, index) {
|
||||
// CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
|
||||
%0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
|
||||
// CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
|
||||
return %0#0, %0#1 : index, index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @process_index_default_axes
|
||||
func.func @process_index_default_axes() -> (index, index, index) {
|
||||
// CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
|
||||
%0:3 = mesh.process_index on @mesh0 : index, index, index
|
||||
// CHECK-LABEL: func @process_multi_index_default_axes
|
||||
func.func @process_multi_index_default_axes() -> (index, index, index) {
|
||||
// CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
|
||||
%0:3 = mesh.process_multi_index on @mesh0 : index, index, index
|
||||
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
|
||||
return %0#0, %0#1, %0#2 : index, index, index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @process_index_empty_axes
|
||||
func.func @process_index_empty_axes() -> (index, index, index) {
|
||||
// CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
|
||||
%0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
|
||||
// CHECK-LABEL: func @process_multi_index_empty_axes
|
||||
func.func @process_multi_index_empty_axes() -> (index, index, index) {
|
||||
// CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
|
||||
%0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
|
||||
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
|
||||
return %0#0, %0#1, %0#2 : index, index, index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @process_linear_index
|
||||
func.func @process_linear_index() -> index {
|
||||
// CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
|
||||
%0 = mesh.process_linear_index on @mesh0 : index
|
||||
// CHECK: return %[[RES]] : index
|
||||
return %0 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @all_reduce
|
||||
func.func @all_reduce(
|
||||
|
23
mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
Normal file
23
mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
Normal file
@ -0,0 +1,23 @@
|
||||
// RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s
|
||||
|
||||
mesh.cluster @mesh2d(rank = 2)
|
||||
|
||||
// CHECK-LABEL: func.func @multi_index_2d_mesh
|
||||
func.func @multi_index_2d_mesh() -> (index, index) {
|
||||
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
|
||||
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
|
||||
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
|
||||
%0:2 = mesh.process_multi_index on @mesh2d : index, index
|
||||
// CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
|
||||
return %0#0, %0#1 : index, index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis
|
||||
func.func @multi_index_2d_mesh_single_inner_axis() -> index {
|
||||
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
|
||||
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
|
||||
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
|
||||
%0 = mesh.process_multi_index on @mesh2d axes = [0] : index
|
||||
// CHECK: return %[[MULTI_IDX]]#0 : index
|
||||
return %0 : index
|
||||
}
|
@ -21,7 +21,7 @@ func.func @split_replicated_tensor_axis(
|
||||
) -> tensor<3x14xf32> {
|
||||
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
|
||||
// CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d axes = [0] : index
|
||||
// CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
|
||||
// CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
|
||||
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
|
||||
// CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
|
||||
@ -43,7 +43,7 @@ func.func @split_replicated_tensor_axis_dynamic(
|
||||
) -> tensor<?x3x?xf32> {
|
||||
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index
|
||||
// CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d_dynamic axes = [0] : index
|
||||
// CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index
|
||||
// CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
|
||||
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
|
||||
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Exclude tests from libMLIR.so
|
||||
add_mlir_library(MLIRMeshTest
|
||||
TestProcessMultiIndexOpLowering.cpp
|
||||
TestReshardingSpmdization.cpp
|
||||
TestSimplifications.cpp
|
||||
|
||||
|
@ -0,0 +1,55 @@
|
||||
//===- TestProcessMultiIndexOpLowering.cpp --------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
|
||||
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct TestMultiIndexOpLoweringPass
|
||||
: public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
|
||||
|
||||
void runOnOperation() override;
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<mesh::MeshDialect>();
|
||||
mesh::processMultiIndexOpLoweringRegisterDialects(registry);
|
||||
}
|
||||
StringRef getArgument() const final {
|
||||
return "test-mesh-process-multi-index-op-lowering";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test lowering of mesh.process_multi_index op.";
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void TestMultiIndexOpLoweringPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
SymbolTableCollection symbolTableCollection;
|
||||
mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
|
||||
symbolTableCollection);
|
||||
LogicalResult status =
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestMultiIndexOpLoweringPass() {
|
||||
PassRegistration<TestMultiIndexOpLoweringPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
@ -120,6 +120,7 @@ void registerTestMemRefDependenceCheck();
|
||||
void registerTestMemRefStrideCalculation();
|
||||
void registerTestMeshSimplificationsPass();
|
||||
void registerTestMeshReshardingSpmdizationPass();
|
||||
void registerTestMultiIndexOpLoweringPass();
|
||||
void registerTestNextAccessPass();
|
||||
void registerTestOneToNTypeConversionPass();
|
||||
void registerTestOpaqueLoc();
|
||||
@ -240,6 +241,7 @@ void registerTestPasses() {
|
||||
mlir::test::registerTestMathPolynomialApproximationPass();
|
||||
mlir::test::registerTestMemRefDependenceCheck();
|
||||
mlir::test::registerTestMemRefStrideCalculation();
|
||||
mlir::test::registerTestMultiIndexOpLoweringPass();
|
||||
mlir::test::registerTestMeshSimplificationsPass();
|
||||
mlir::test::registerTestMeshReshardingSpmdizationPass();
|
||||
mlir::test::registerTestNextAccessPass();
|
||||
|
Loading…
Reference in New Issue
Block a user