diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td index a833e9c8220a..133af893e4ef 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -82,6 +82,8 @@ def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>; def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>; def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>; def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>; +def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 13>; +def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 14>; def AtomicRMWKindAttr : I64EnumAttr< "AtomicRMWKind", "", @@ -89,7 +91,7 @@ def AtomicRMWKindAttr : I64EnumAttr< ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU, ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI, - ATOMIC_RMW_KIND_ANDI]> { + ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> { let cppNamespace = "::mlir::arith"; } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index d39c5b605112..ae8a6ef350ce 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2523,6 +2523,10 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, return builder.create(loc, lhs, rhs); case AtomicRMWKind::minimumf: return builder.create(loc, lhs, rhs); + case AtomicRMWKind::maxnumf: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::minnumf: + return builder.create(loc, lhs, rhs); case AtomicRMWKind::maxs: return builder.create(loc, lhs, rhs); case AtomicRMWKind::mins: diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp index b3beaada2539..faba12f5bf82 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" namespace mlir { namespace memref { @@ -126,8 +127,10 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase { target.addLegalDialect(); target.addDynamicallyLegalOp( [](memref::AtomicRMWOp op) { - return op.getKind() != arith::AtomicRMWKind::maximumf && - op.getKind() != arith::AtomicRMWKind::minimumf; + constexpr std::array shouldBeExpandedKinds = { + arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf, + arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf}; + return !llvm::is_contained(shouldBeExpandedKinds, op.getKind()); }); target.addDynamicallyLegalOp([](memref::ReshapeOp op) { return !cast(op.getShape().getType()).hasStaticShape(); diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir index 6c98cf978505..f958a92b751a 100644 --- a/mlir/test/Dialect/MemRef/expand-ops.mlir +++ b/mlir/test/Dialect/MemRef/expand-ops.mlir @@ -3,9 +3,11 @@ // CHECK-LABEL: func @atomic_rmw_to_generic // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index) func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { - %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 - %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 - return %x : f32 + %a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + %b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + %c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + %d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + return %a : f32 } // CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { // CHECK: ^bb0([[CUR_VAL:%.*]]: f32): @@ -17,6 +19,16 @@ func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 // CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32 // CHECK: memref.atomic_yield [[MINIMUM]] : f32 // CHECK: } +// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { +// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): +// CHECK: [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32 +// CHECK: memref.atomic_yield [[MAXNUM]] : f32 +// CHECK: } +// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { +// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): +// CHECK: [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32 +// CHECK: memref.atomic_yield [[MINNUM]] : f32 +// CHECK: } // CHECK: return [[RESULT]] : f32 // -----