[mlir] Add maxnumf and minnumf to AtomicRMWKind (#66442)

This commit adds the mentioned kinds of `AtomicRMWKind`
as well as code generation for them.
This commit is contained in:
Daniil Dudkin 2023-09-15 22:41:51 +03:00 committed by GitHub
parent 52b33ff760
commit 01e80a0f41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 6 deletions

View File

@ -82,6 +82,8 @@ def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 13>;
def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 14>;
def AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
@ -89,7 +91,7 @@ def AtomicRMWKindAttr : I64EnumAttr<
ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
ATOMIC_RMW_KIND_ANDI]> {
ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
let cppNamespace = "::mlir::arith";
}

View File

@ -2523,6 +2523,10 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
case AtomicRMWKind::minimumf:
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
case AtomicRMWKind::maxnumf:
return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
case AtomicRMWKind::minnumf:
return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
case AtomicRMWKind::maxs:
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
case AtomicRMWKind::mins:

View File

@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace memref {
@ -126,8 +127,10 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
[](memref::AtomicRMWOp op) {
return op.getKind() != arith::AtomicRMWKind::maximumf &&
op.getKind() != arith::AtomicRMWKind::minimumf;
constexpr std::array shouldBeExpandedKinds = {
arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
});
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();

View File

@ -3,9 +3,11 @@
// CHECK-LABEL: func @atomic_rmw_to_generic
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
%x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
%y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
%a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
%b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
%c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
%d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %a : f32
}
// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
@ -17,6 +19,16 @@ func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32
// CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
// CHECK: memref.atomic_yield [[MINIMUM]] : f32
// CHECK: }
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK: [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32
// CHECK: memref.atomic_yield [[MAXNUM]] : f32
// CHECK: }
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK: [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32
// CHECK: memref.atomic_yield [[MINNUM]] : f32
// CHECK: }
// CHECK: return [[RESULT]] : f32
// -----