mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-01 22:53:29 +00:00
[mlir] Add maxnumf
and minnumf
to AtomicRMWKind
(#66442)
This commit adds the mentioned kinds of `AtomicRMWKind` as well as code generation for them.
This commit is contained in:
parent
52b33ff760
commit
01e80a0f41
@ -82,6 +82,8 @@ def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
|
||||
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
|
||||
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
|
||||
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
|
||||
def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 13>;
|
||||
def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 14>;
|
||||
|
||||
def AtomicRMWKindAttr : I64EnumAttr<
|
||||
"AtomicRMWKind", "",
|
||||
@ -89,7 +91,7 @@ def AtomicRMWKindAttr : I64EnumAttr<
|
||||
ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
|
||||
ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
|
||||
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
|
||||
ATOMIC_RMW_KIND_ANDI]> {
|
||||
ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
|
||||
let cppNamespace = "::mlir::arith";
|
||||
}
|
||||
|
||||
|
@ -2523,6 +2523,10 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
|
||||
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::minimumf:
|
||||
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::maxnumf:
|
||||
return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::minnumf:
|
||||
return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::maxs:
|
||||
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::mins:
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace memref {
|
||||
@ -126,8 +127,10 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
|
||||
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
|
||||
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
|
||||
[](memref::AtomicRMWOp op) {
|
||||
return op.getKind() != arith::AtomicRMWKind::maximumf &&
|
||||
op.getKind() != arith::AtomicRMWKind::minimumf;
|
||||
constexpr std::array shouldBeExpandedKinds = {
|
||||
arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
|
||||
arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
|
||||
return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
|
||||
});
|
||||
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
|
||||
return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
|
||||
|
@ -3,9 +3,11 @@
|
||||
// CHECK-LABEL: func @atomic_rmw_to_generic
|
||||
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
|
||||
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
|
||||
%x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
%y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
return %x : f32
|
||||
%a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
%b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
%c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
%d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
return %a : f32
|
||||
}
|
||||
// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
|
||||
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
|
||||
@ -17,6 +19,16 @@ func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32
|
||||
// CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
|
||||
// CHECK: memref.atomic_yield [[MINIMUM]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
|
||||
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
|
||||
// CHECK: [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32
|
||||
// CHECK: memref.atomic_yield [[MAXNUM]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
|
||||
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
|
||||
// CHECK: [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32
|
||||
// CHECK: memref.atomic_yield [[MINNUM]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: return [[RESULT]] : f32
|
||||
|
||||
// -----
|
||||
|
Loading…
x
Reference in New Issue
Block a user