[MLIR] Add std.atomic_rmw op

Summary:
The RFC for this op is here: https://llvm.discourse.group/t/rfc-add-std-atomic-rmw-op/489

The std.atomic_rmw op provides a way to support read-modify-write
sequences with data race freedom. It is intended to be used in the lowering
of an upcoming affine.atomic_rmw op which can be used for reductions.

A lowering to LLVM is provided with 2 paths:
- Simple patterns: llvm.atomicrmw
- Everything else: llvm.cmpxchg

Differential Revision: https://reviews.llvm.org/D74401
This commit is contained in:
Frank Laub 2020-02-24 16:49:52 -08:00
parent 4e45ef4d77
commit fe210a1ff2
6 changed files with 347 additions and 9 deletions

View File

@ -218,6 +218,69 @@ def AndOp : IntArithmeticOp<"and", [Commutative]> {
let hasFolder = 1;
}
// Enum attribute cases for the combining kinds accepted by std.atomic_rmw's
// `kind` attribute. The trailing integer is the encoding stored in the IR.
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>;
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>;
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
// Attribute wrapping all cases above. TableGen emits a C++ `AtomicRMWKind`
// enum plus helpers (e.g. `stringifyAtomicRMWKind`, used by the verifier).
def AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI]> {
// Generated enum lives directly in the mlir namespace.
let cppNamespace = "::mlir";
}
// Standard-dialect atomic read-modify-write op. The traits enforce that the
// result type equals the `value` operand's type and that `value` matches the
// memref's element type; the index-count-vs-rank invariant is enforced by the
// custom C++ verifier.
def AtomicRMWOp : Std_Op<"atomic_rmw", [
AllTypesMatch<["value", "result"]>,
TypesMatchWith<"value type matches element type of memref",
"memref", "value",
"$_self.cast<MemRefType>().getElementType()">
]> {
let summary = "atomic read-modify-write operation";
let description = [{
The "atomic_rmw" operation provides a way to perform a read-modify-write
sequence that is free from data races. The kind enumeration specifies the
modification to perform. The value operand represents the new value to be
applied during the modification. The memref operand represents the buffer
that the read and write will be performed against, as accessed by the
specified indices. The arity of the indices is the rank of the memref. The
result represents the latest value that was stored.
Example:
```mlir
%x = atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
```
}];
// `kind` selects the combining function; `indices` must have one entry per
// memref dimension. Element types are restricted to signless ints / floats.
let arguments = (ins
AtomicRMWKindAttr:$kind,
AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value,
MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
Variadic<Index>:$indices);
let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
// Declarative assembly: kind, value, memref[indices], then a functional type.
let assemblyFormat = [{
$kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,`
type($memref) `)` `->` type($result)
}];
// Convenience accessor for the memref operand's statically-known type.
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return memref().getType().cast<MemRefType>();
}
}];
}
def BranchOp : Std_Op<"br", [Terminator]> {
let summary = "branch operation";
let description = [{

View File

@ -1143,7 +1143,8 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
}
};
template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
template <typename SourceOp, unsigned OpCount>
struct OpCountValidator {
static_assert(
std::is_base_of<
typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
@ -1151,12 +1152,14 @@ template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
"wrong operand count");
};
template <typename SourceOp> struct OpCountValidator<SourceOp, 1> {
template <typename SourceOp>
struct OpCountValidator<SourceOp, 1> {
static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
"expected a single operand");
};
template <typename SourceOp, unsigned OpCount> void ValidateOpCount() {
template <typename SourceOp, unsigned OpCount>
void ValidateOpCount() {
OpCountValidator<SourceOp, OpCount>();
}
@ -1524,11 +1527,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
if (strides[index] == MemRefType::getDynamicStrideOrOffset())
// Identity layout map is enforced in the match function, so we compute:
// `runningStride *= sizes[index + 1]`
runningStride =
runningStride
? rewriter.create<LLVM::MulOp>(loc, runningStride,
sizes[index + 1])
: createIndexConstant(rewriter, loc, 1);
runningStride = runningStride
? rewriter.create<LLVM::MulOp>(loc, runningStride,
sizes[index + 1])
: createIndexConstant(rewriter, loc, 1);
else
runningStride = createIndexConstant(rewriter, loc, strides[index]);
strideValues[index] = runningStride;
@ -2537,6 +2539,170 @@ struct AssumeAlignmentOpLowering
} // namespace
/// Try to match the kind of a std.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
///
/// Returns the corresponding llvm.atomicrmw binary operation when the kind
/// maps directly onto one, and llvm::None for the kinds that have no direct
/// equivalent and must be lowered via a llvm.cmpxchg loop instead.
static Optional<LLVM::AtomicBinOp> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
  // Note: no `default:` label — the switch is intentionally exhaustive so
  // that adding a new AtomicRMWKind triggers a missing-case warning here,
  // and the llvm_unreachable below stays reachable for invalid encodings.
  switch (atomicOp.kind()) {
  case AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  // No single llvm.atomicrmw form; handled by AtomicCmpXchgOpLowering.
  case AtomicRMWKind::maxf:
  case AtomicRMWKind::minf:
  case AtomicRMWKind::mulf:
  case AtomicRMWKind::muli:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}
namespace {
/// Lowers a std.atomic_rmw whose kind maps directly onto an LLVM atomic
/// binary operation to a single llvm.atomicrmw instruction.
struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
  using Base::Base;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto rmwOp = cast<AtomicRMWOp>(op);
    // Kinds without a direct llvm.atomicrmw form are handled by the
    // cmpxchg-loop pattern instead; bail out for those.
    auto binOp = matchSimpleAtomicOp(rmwOp);
    if (!binOp)
      return matchFailure();
    OperandAdaptor<AtomicRMWOp> transformed(operands);
    // Resolve the address of the addressed memref element.
    auto elemPtr =
        getDataPtr(op->getLoc(), rmwOp.getMemRefType(), transformed.memref(),
                   transformed.indices(), rewriter, getModule());
    // Emit the atomicrmw with acquire-release ordering and forward its
    // result (the previously stored value) to the op's users.
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        op, transformed.value().getType(), *binOp, elemPtr,
        transformed.value(), LLVM::AtomicOrdering::acq_rel);
    return matchSuccess();
  }
};
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
/// +---------------------------------+
/// | <code before the AtomicRMWOp> |
/// | <compute initial %loaded> |
/// | br loop(%loaded) |
/// +---------------------------------+
/// |
/// -------| |
/// | v v
/// | +--------------------------------+
/// | | loop(%loaded): |
/// | | <body contents> |
/// | | %pair = cmpxchg |
/// | | %ok = %pair[0] |
/// | | %new = %pair[1] |
/// | | cond_br %ok, end, loop(%new) |
/// | +--------------------------------+
/// | | |
/// |----------- |
/// v
/// +--------------------------------+
/// | end: |
/// | <code after the AtomicRMWOp> |
/// +--------------------------------+
///
// Lowers the std.atomic_rmw kinds that have no direct llvm.atomicrmw form
// (here: floating-point min/max) to the cmpxchg retry loop pictured in the
// diagram above.
struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
using Base::Base;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto atomicOp = cast<AtomicRMWOp>(op);
// Simple kinds are covered by AtomicRMWOpLowering; only proceed when
// matchSimpleAtomicOp rejected the op.
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (maybeKind)
return matchFailure();
// Choose the float comparison used to pick the value to store back:
// ogt selects the maximum, olt the minimum.
LLVM::FCmpPredicate predicate;
switch (atomicOp.kind()) {
case AtomicRMWKind::maxf:
predicate = LLVM::FCmpPredicate::ogt;
break;
case AtomicRMWKind::minf:
predicate = LLVM::FCmpPredicate::olt;
break;
default:
return matchFailure();
}
OperandAdaptor<AtomicRMWOp> adaptor(operands);
auto loc = op->getLoc();
auto valueType = adaptor.value().getType().cast<LLVM::LLVMType>();
// Split the block into initial, loop, and ending parts.
auto *initBlock = rewriter.getInsertionBlock();
auto initPosition = rewriter.getInsertionPoint();
auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
// The loop block carries the most-recently-loaded value as an argument.
auto loopArgument = loopBlock->addArgument(valueType);
auto loopPosition = rewriter.getInsertionPoint();
auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = atomicOp.getMemRefType();
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter, getModule());
auto init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
std::array<Value, 1> brRegionOperands{init};
std::array<ValueRange, 1> brOperands{brRegionOperands};
rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{}, loopBlock, brOperands);
// Prepare the body of the loop block.
rewriter.setInsertionPointToStart(loopBlock);
auto predicateI64 =
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate));
auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
auto lhs = loopArgument;
auto rhs = adaptor.value();
// fcmp + select picks between the loaded value and the op's value operand
// according to the predicate chosen above.
auto cmp =
rewriter.create<LLVM::FCmpOp>(loc, boolType, predicateI64, lhs, rhs);
auto select = rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
// Prepare the epilog of the loop block.
rewriter.setInsertionPointToEnd(loopBlock);
// Append the cmpxchg op to the end of the loop block. Success uses
// acq_rel; failure only needs monotonic since we simply retry.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, pairType, dataPtr, loopArgument, select, successOrdering,
failureOrdering);
// Extract the %new_loaded and %ok values from the pair.
auto newLoaded = rewriter.create<LLVM::ExtractValueOp>(
loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
auto ok = rewriter.create<LLVM::ExtractValueOp>(
loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
// Conditionally branch to the end or back to the loop depending on %ok;
// on retry, %new_loaded becomes the loop block's argument.
std::array<Value, 1> condBrProperOperands{ok};
std::array<Block *, 2> condBrDestinations{endBlock, loopBlock};
std::array<Value, 1> condBrRegionOperands{newLoaded};
std::array<ValueRange, 2> condBrOperands{ArrayRef<Value>{},
condBrRegionOperands};
rewriter.create<LLVM::CondBrOp>(loc, condBrProperOperands,
condBrDestinations, condBrOperands);
// The 'result' of the atomic_rmw op is the newly loaded value.
rewriter.replaceOp(op, {newLoaded});
return matchSuccess();
}
};
} // namespace
static void ensureDistinctSuccessors(Block &bb) {
auto *terminator = bb.getTerminator();
@ -2594,6 +2760,8 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
AddFOpLowering,
AddIOpLowering,
AndOpLowering,
AtomicCmpXchgOpLowering,
AtomicRMWOpLowering,
BranchOpLowering,
CallIndirectOpLowering,
CallOpLowering,

View File

@ -135,7 +135,8 @@ static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
}
/// A custom cast operation verifier.
template <typename T> static LogicalResult verifyCastOp(T op) {
template <typename T>
static LogicalResult verifyCastOp(T op) {
auto opType = op.getOperand().getType();
auto resType = op.getType();
if (!T::areCastCompatible(opType, resType))
@ -2614,6 +2615,41 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
return false;
}
//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//
/// Verifies a std.atomic_rmw op: the subscript count must equal the memref's
/// rank, and the value type must agree with the arithmetic flavor of the
/// requested kind (float kinds need a floating-point value, integer kinds an
/// integer value; "assign" accepts any supported element type).
static LogicalResult verify(AtomicRMWOp op) {
  // Operands are: value, memref, then one index per memref dimension.
  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
    return op.emitOpError(
        "expects the number of subscripts to be equal to memref rank");
  // Exhaustive switch (no `default:`) so that adding a new AtomicRMWKind
  // produces a missing-case warning here instead of silently verifying.
  switch (op.kind()) {
  case AtomicRMWKind::addf:
  case AtomicRMWKind::maxf:
  case AtomicRMWKind::minf:
  case AtomicRMWKind::mulf:
    if (!op.value().getType().isa<FloatType>())
      return op.emitOpError()
             << "with kind '" << stringifyAtomicRMWKind(op.kind())
             << "' expects a floating-point type";
    break;
  case AtomicRMWKind::addi:
  case AtomicRMWKind::maxs:
  case AtomicRMWKind::maxu:
  case AtomicRMWKind::mins:
  case AtomicRMWKind::minu:
  case AtomicRMWKind::muli:
    if (!op.value().getType().isa<IntegerType>())
      return op.emitOpError()
             << "with kind '" << stringifyAtomicRMWKind(op.kind())
             << "' expects an integer type";
    break;
  case AtomicRMWKind::assign:
    // Plain store: works for any element type, nothing further to check.
    break;
  }
  return success();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

View File

@ -858,6 +858,46 @@ module {
// -----
// CHECK-LABEL: func @atomic_rmw
// Each "simple" atomic_rmw kind must lower to a single llvm.atomicrmw with
// the matching binary operation and acq_rel ordering.
func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
atomic_rmw "assign" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel
atomic_rmw "addi" %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel
atomic_rmw "maxs" %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel
atomic_rmw "mins" %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel
atomic_rmw "maxu" %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel
atomic_rmw "minu" %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
atomic_rmw "addf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
return
}
// -----
// CHECK-LABEL: func @cmpxchg
// "maxf" has no llvm.atomicrmw form: expect a load followed by a cmpxchg
// retry loop whose candidate value is picked via fcmp "ogt" + select.
func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 {
%x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: %[[init:.*]] = llvm.load %{{.*}} : !llvm<"float*">
// CHECK-NEXT: llvm.br ^bb1(%[[init]] : !llvm.float)
// CHECK-NEXT: ^bb1(%[[loaded:.*]]: !llvm.float):
// CHECK-NEXT: %[[cmp:.*]] = llvm.fcmp "ogt" %[[loaded]], %{{.*}} : !llvm.float
// CHECK-NEXT: %[[max:.*]] = llvm.select %[[cmp]], %[[loaded]], %{{.*}} : !llvm.i1, !llvm.float
// CHECK-NEXT: %[[pair:.*]] = llvm.cmpxchg %{{.*}}, %[[loaded]], %[[max]] acq_rel monotonic : !llvm.float
// CHECK-NEXT: %[[new:.*]] = llvm.extractvalue %[[pair]][0] : !llvm<"{ float, i1 }">
// CHECK-NEXT: %[[ok:.*]] = llvm.extractvalue %[[pair]][1] : !llvm<"{ float, i1 }">
// CHECK-NEXT: llvm.cond_br %[[ok]], ^bb2, ^bb1(%[[new]] : !llvm.float)
// CHECK-NEXT: ^bb2:
return %x : f32
// CHECK-NEXT: llvm.return %[[new]]
}
// -----
// CHECK-LABEL: func @assume_alignment
func @assume_alignment(%0 : memref<4x4xf16>) {
// CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">

View File

@ -741,6 +741,13 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
return
}
// CHECK-LABEL: func @atomic_rmw
// Round-trip parse/print of a valid std.atomic_rmw op.
func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
// CHECK: %{{.*}} = atomic_rmw "addf" %{{.*}}, %{{.*}}[%{{.*}}]
%x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<10xf32>) -> f32
return
}
// CHECK-LABEL: func @assume_alignment
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
func @assume_alignment(%0: memref<4x4xf16>) {

View File

@ -1039,6 +1039,30 @@ func @invalid_memref_cast() {
// -----
func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
// One subscript on a rank-2 memref must be rejected by the verifier.
// expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
%x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
return
}
// -----
func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) {
// A float kind ("addf") applied to an integer value must be rejected.
// expected-error@+1 {{expects a floating-point type}}
%x = atomic_rmw "addf" %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32
return
}
// -----
func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
// An integer kind ("addi") applied to a float value must be rejected.
// expected-error@+1 {{expects an integer type}}
%x = atomic_rmw "addi" %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
return
}
// -----
// alignment is not power of 2.
func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{alignment must be power of 2}}