[mlir][nvvm] Implement mbarrier.init

NVIDIA GPUs provide split arrive/wait barriers (mbarrier) that can synchronize a subgroup of threads in a CTA. They are particularly important for Hopper GPUs, where they allow tracking engines like TMA. For more details, see:
https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier

This initial implementation sets the foundation for future enhancements and additions.

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D151334
This commit is contained in:
Guray Ozen 2023-06-16 10:03:30 +02:00
parent da7892f729
commit 58950d4add
2 changed files with 40 additions and 0 deletions

View File

@ -19,6 +19,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def LLVM_i8Ptr_global : LLVM_IntPtrBase<8, 1>;
def LLVM_i8Ptr_shared : LLVM_IntPtrBase<8, 3>;
def LLVM_i64ptr_any : LLVM_IntPtrBase<64>;
def LLVM_i64ptr_shared : LLVM_IntPtrBase<64, 3>;
//===----------------------------------------------------------------------===//
// NVVM dialect definitions
@ -173,6 +175,28 @@ def NVVM_ReduxOp :
}];
}
//===----------------------------------------------------------------------===//
// NVVM Split arrive/wait barrier
//===----------------------------------------------------------------------===//
/// mbarrier.init instruction with generic pointer type.
///
/// Operands:
///   $addr  - pointer to the barrier object; typed as LLVM_i64ptr_any, i.e. a
///            64-bit integer pointer with no address space specified.
///   $count - 32-bit integer count the barrier is initialized with.
///
/// The op has no results. When translating to LLVM IR it emits a call to the
/// llvm.nvvm.mbarrier.init intrinsic, forwarding $addr and $count unchanged.
///
/// Textual form: `nvvm.mbarrier.init <addr>, <count> attr-dict : <types>`.
def NVVM_MBarrierInitOp : NVVM_Op<"mbarrier.init">,
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count)> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
}];
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
}
/// mbarrier.init instruction with shared pointer type.
///
/// Operands:
///   $addr  - pointer to the barrier object; typed as LLVM_i64ptr_shared, i.e.
///            a 64-bit integer pointer in address space 3.
///   $count - 32-bit integer count the barrier is initialized with.
///
/// The op has no results. When translating to LLVM IR it emits a call to the
/// llvm.nvvm.mbarrier.init.shared intrinsic, forwarding $addr and $count
/// unchanged.
///
/// Textual form:
///   `nvvm.mbarrier.init.shared <addr>, <count> attr-dict : <types>`.
def NVVM_MBarrierInitSharedOp : NVVM_Op<"mbarrier.init.shared">,
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count)> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
}];
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
}
//===----------------------------------------------------------------------===//
// NVVM synchronization op definitions
//===----------------------------------------------------------------------===//

View File

@ -337,3 +337,19 @@ llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
// expected-error@below {{attribute attached to unexpected op}}
func.func private @expected_llvm_func() attributes { nvvm.kernel }
// -----
// Exercises the generic-pointer form of nvvm.mbarrier.init: the barrier is
// passed as a plain !llvm.ptr (no address space) and the count is taken from
// the ntid.x special-register read below.
llvm.func private @mbarrier_init_generic(%barrier: !llvm.ptr) {
%count = nvvm.read.ptx.sreg.ntid.x : i32
// CHECK: nvvm.mbarrier.init %{{.*}}, %{{.*}} : !llvm.ptr, i32
nvvm.mbarrier.init %barrier, %count : !llvm.ptr, i32
llvm.return
}
// Exercises the shared-pointer form, nvvm.mbarrier.init.shared: the barrier is
// passed as !llvm.ptr<3> (address space 3) and the count is taken from the
// ntid.x special-register read below.
llvm.func private @mbarrier_init_shared(%barrier: !llvm.ptr<3>) {
%count = nvvm.read.ptx.sreg.ntid.x : i32
// CHECK: nvvm.mbarrier.init.shared %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32
nvvm.mbarrier.init.shared %barrier, %count : !llvm.ptr<3>, i32
llvm.return
}