[MLIR] LLVM dialect: Add llvm.atomicrmw

Summary:
This op is the counterpart to LLVM's atomicrmw instruction. Note that
volatile and syncscope attributes are not yet supported.

This will be useful for upcoming parallel versions of `affine.for` and generally
for reduction-like semantics.

Differential Revision: https://reviews.llvm.org/D72741
This commit is contained in:
Frank Laub 2020-01-17 21:09:53 +01:00 committed by Alex Zinenko
parent 37e2560d3d
commit 60a0c612df
6 changed files with 292 additions and 0 deletions

View File

@ -723,4 +723,56 @@ def LLVM_Prefetch : LLVM_ZeroResultOp<"intr.prefetch">,
}];
}
def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0>;
def AtomicBinOpAdd : I64EnumAttrCase<"add", 1>;
def AtomicBinOpSub : I64EnumAttrCase<"sub", 2>;
def AtomicBinOpAnd : I64EnumAttrCase<"_and", 3>;
def AtomicBinOpNand : I64EnumAttrCase<"nand", 4>;
def AtomicBinOpOr : I64EnumAttrCase<"_or", 5>;
def AtomicBinOpXor : I64EnumAttrCase<"_xor", 6>;
def AtomicBinOpMax : I64EnumAttrCase<"max", 7>;
def AtomicBinOpMin : I64EnumAttrCase<"min", 8>;
def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9>;
def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10>;
def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11>;
def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12>;
def AtomicBinOp : I64EnumAttr<
"AtomicBinOp",
"llvm.atomicrmw binary operations",
[AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd,
AtomicBinOpNand, AtomicBinOpOr, AtomicBinOpXor, AtomicBinOpMax,
AtomicBinOpMin, AtomicBinOpUMax, AtomicBinOpUMin, AtomicBinOpFAdd,
AtomicBinOpFSub]> {
let cppNamespace = "::mlir::LLVM";
}
def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0>;
def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1>;
def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2>;
def AtomicOrderingAcquire : I64EnumAttrCase<"acquire", 4>;
def AtomicOrderingRelease : I64EnumAttrCase<"release", 5>;
def AtomicOrderingAcquireRelease : I64EnumAttrCase<"acq_rel", 6>;
def AtomicOrderingSequentiallyConsistent : I64EnumAttrCase<"seq_cst", 7>;
def AtomicOrdering : I64EnumAttr<
"AtomicOrdering",
"Atomic ordering for LLVM's memory model",
[AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic,
AtomicOrderingAcquire, AtomicOrderingRelease, AtomicOrderingAcquireRelease,
AtomicOrderingSequentiallyConsistent]> {
let cppNamespace = "::mlir::LLVM";
}
def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">,
Arguments<(ins AtomicBinOp:$bin_op, LLVM_Type:$ptr, LLVM_Type:$val,
AtomicOrdering:$ordering)>,
Results<(outs LLVM_Type:$res)> {
let llvmBuilder = [{
$res = builder.CreateAtomicRMW(getLLVMAtomicBinOp($bin_op), $ptr, $val,
getLLVMAtomicOrdering($ordering));
}];
let parser = [{ return parseAtomicRMWOp(parser, result); }];
let printer = [{ printAtomicRMWOp(p, *this); }];
let verifier = "return ::verify(*this);";
}
#endif // LLVMIR_OPS

View File

@ -1368,6 +1368,107 @@ static LogicalResult verify(LLVM::NullOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// Printer, parser and verifier for LLVM::AtomicRMWOp.
//===----------------------------------------------------------------------===//
static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
p << op.getOperationName() << " ";
p << '"' << stringifyAtomicBinOp(op.bin_op()) << "\" ";
p << '"' << stringifyAtomicOrdering(op.ordering()) << "\" ";
p << op.ptr() << ", " << op.val();
p.printOptionalAttrDict(op.getAttrs(), {"bin_op", "ordering"});
p << " : (" << op.ptr().getType() << ", " << op.val().getType() << ") -> "
<< op.res().getType();
}
// <operation> ::= `llvm.atomicrmw` string-literal string-literal
// ssa-use `,` ssa-use attribute-dict? `:` type
static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
OperationState &result) {
Type type;
StringAttr binOp, ordering;
llvm::SMLoc binOpLoc, orderingLoc, trailingTypeLoc;
OpAsmParser::OperandType ptr, val;
if (parser.getCurrentLocation(&binOpLoc) ||
parser.parseAttribute(binOp, "bin_op", result.attributes) ||
parser.getCurrentLocation(&orderingLoc) ||
parser.parseAttribute(ordering, "ordering", result.attributes) ||
parser.parseOperand(ptr) || parser.parseComma() ||
parser.parseOperand(val) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
return failure();
// Extract the result type from the trailing function type.
auto funcType = type.dyn_cast<FunctionType>();
if (!funcType || funcType.getNumInputs() != 2 ||
funcType.getNumResults() != 1)
return parser.emitError(
trailingTypeLoc,
"expected trailing function type with two arguments and one result");
if (parser.resolveOperand(ptr, funcType.getInput(0), result.operands) ||
parser.resolveOperand(val, funcType.getInput(1), result.operands))
return failure();
// Replace the string attribute `bin_op` with an integer attribute.
auto binOpKind = symbolizeAtomicBinOp(binOp.getValue());
if (!binOpKind) {
return parser.emitError(binOpLoc)
<< "'" << binOp.getValue()
<< "' is an incorrect value of the 'bin_op' attribute";
}
auto binOpValue = static_cast<int64_t>(binOpKind.getValue());
auto binOpAttr = parser.getBuilder().getI64IntegerAttr(binOpValue);
result.attributes[0].second = binOpAttr;
// Replace the string attribute `ordering` with an integer attribute.
auto orderingKind = symbolizeAtomicOrdering(ordering.getValue());
if (!orderingKind) {
return parser.emitError(orderingLoc)
<< "'" << ordering.getValue()
<< "' is an incorrect value of the 'ordering' attribute";
}
auto orderingValue = static_cast<int64_t>(orderingKind.getValue());
auto orderingAttr = parser.getBuilder().getI64IntegerAttr(orderingValue);
result.attributes[1].second = orderingAttr;
result.addTypes(funcType.getResults());
return success();
}
static LogicalResult verify(AtomicRMWOp op) {
auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
if (!ptrType.isPointerTy())
return op.emitOpError("expected LLVM IR pointer type for operand #0");
auto valType = op.val().getType().cast<LLVM::LLVMType>();
if (valType != ptrType.getPointerElementTy())
return op.emitOpError("expected LLVM IR element type for operand #0 to "
"match type for operand #1");
auto resType = op.res().getType().cast<LLVM::LLVMType>();
if (resType != valType)
return op.emitOpError(
"expected LLVM IR result type to match type for operand #1");
if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
if (!valType.getUnderlyingType()->isFloatingPointTy())
return op.emitOpError("expected LLVM IR floating point type");
} else if (op.bin_op() == AtomicBinOp::xchg) {
if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
!valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
!valType.getUnderlyingType()->isHalfTy() && !valType.isFloatTy() &&
!valType.isDoubleTy())
return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
} else {
if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
!valType.isIntegerTy(32) && !valType.isIntegerTy(64))
return op.emitOpError("expected LLVM IR integer type");
}
return success();
}
//===----------------------------------------------------------------------===//
// LLVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

View File

@ -211,6 +211,58 @@ static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) {
llvm_unreachable("incorrect comparison predicate");
}
static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
switch (op) {
case LLVM::AtomicBinOp::xchg:
return llvm::AtomicRMWInst::BinOp::Xchg;
case LLVM::AtomicBinOp::add:
return llvm::AtomicRMWInst::BinOp::Add;
case LLVM::AtomicBinOp::sub:
return llvm::AtomicRMWInst::BinOp::Sub;
case LLVM::AtomicBinOp::_and:
return llvm::AtomicRMWInst::BinOp::And;
case LLVM::AtomicBinOp::nand:
return llvm::AtomicRMWInst::BinOp::Nand;
case LLVM::AtomicBinOp::_or:
return llvm::AtomicRMWInst::BinOp::Or;
case LLVM::AtomicBinOp::_xor:
return llvm::AtomicRMWInst::BinOp::Xor;
case LLVM::AtomicBinOp::max:
return llvm::AtomicRMWInst::BinOp::Max;
case LLVM::AtomicBinOp::min:
return llvm::AtomicRMWInst::BinOp::Min;
case LLVM::AtomicBinOp::umax:
return llvm::AtomicRMWInst::BinOp::UMax;
case LLVM::AtomicBinOp::umin:
return llvm::AtomicRMWInst::BinOp::UMin;
case LLVM::AtomicBinOp::fadd:
return llvm::AtomicRMWInst::BinOp::FAdd;
case LLVM::AtomicBinOp::fsub:
return llvm::AtomicRMWInst::BinOp::FSub;
}
llvm_unreachable("incorrect atomic binary operator");
}
static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
switch (ordering) {
case LLVM::AtomicOrdering::not_atomic:
return llvm::AtomicOrdering::NotAtomic;
case LLVM::AtomicOrdering::unordered:
return llvm::AtomicOrdering::Unordered;
case LLVM::AtomicOrdering::monotonic:
return llvm::AtomicOrdering::Monotonic;
case LLVM::AtomicOrdering::acquire:
return llvm::AtomicOrdering::Acquire;
case LLVM::AtomicOrdering::release:
return llvm::AtomicOrdering::Release;
case LLVM::AtomicOrdering::acq_rel:
return llvm::AtomicOrdering::AcquireRelease;
case LLVM::AtomicOrdering::seq_cst:
return llvm::AtomicOrdering::SequentiallyConsistent;
}
llvm_unreachable("incorrect atomic ordering");
}
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`. LLVM IR Builder does not have a generic interface so
/// this has to be a long chain of `if`s calling different functions with a

View File

@ -393,4 +393,51 @@ llvm.func @recursive_type(%a : !llvm<"%a = type { %a* }">) ->
llvm.return %a : !llvm<"%a = type { %a* }">
}
// -----
// CHECK-LABEL: @atomicrmw_expected_ptr
func @atomicrmw_expected_ptr(%f32 : !llvm.float) {
// expected-error@+1 {{expected LLVM IR pointer type for operand #0}}
%0 = llvm.atomicrmw "fadd" "unordered" %f32, %f32 : (!llvm.float, !llvm.float) -> !llvm.float
llvm.return
}
// -----
// CHECK-LABEL: @atomicrmw_mismatched_operands
func @atomicrmw_mismatched_operands(%f32_ptr : !llvm<"float*">, %i32 : !llvm.i32) {
// expected-error@+1 {{expected LLVM IR element type for operand #0 to match type for operand #1}}
%0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %i32 : (!llvm<"float*">, !llvm.i32) -> !llvm.float
llvm.return
}
// -----
// CHECK-LABEL: @atomicrmw_mismatched_result
func @atomicrmw_mismatched_operands(%f32_ptr : !llvm<"float*">, %f32 : !llvm.float) {
// expected-error@+1 {{expected LLVM IR result type to match type for operand #1}}
%0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.i32
llvm.return
}
// -----
// CHECK-LABEL: @atomicrmw_expected_float
func @atomicrmw_expected_float(%i32_ptr : !llvm<"i32*">, %i32 : !llvm.i32) {
// expected-error@+1 {{expected LLVM IR floating point type}}
%0 = llvm.atomicrmw "fadd" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
llvm.return
}
// -----
// CHECK-LABEL: @atomicrmw_unexpected_xchg_type
func @atomicrmw_xchg_type(%i1_ptr : !llvm<"i1*">, %i1 : !llvm.i1) {
// expected-error@+1 {{unexpected LLVM IR type for 'xchg' bin_op}}
%0 = llvm.atomicrmw "xchg" "unordered" %i1_ptr, %i1 : (!llvm<"i1*">, !llvm.i1) -> !llvm.i1
llvm.return
}
// -----
// CHECK-LABEL: @atomicrmw_expected_int
func @atomicrmw_expected_int(%f32_ptr : !llvm<"float*">, %f32 : !llvm.float) {
// expected-error@+1 {{expected LLVM IR integer type}}
%0 = llvm.atomicrmw "max" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
llvm.return
}

View File

@ -218,3 +218,10 @@ func @null() {
%1 = llvm.mlir.null : !llvm<"{void(i32, void()*)*, i64}*">
llvm.return
}
// CHECK-LABEL: @atomics
func @atomics(%arg0 : !llvm<"float*">, %arg1 : !llvm.float) {
// CHECK: llvm.atomicrmw "fadd" "unordered" %{{.*}}, %{{.*}} : (!llvm<"float*">, !llvm.float) -> !llvm.float
%0 = llvm.atomicrmw "fadd" "unordered" %arg0, %arg1 : (!llvm<"float*">, !llvm.float) -> !llvm.float
llvm.return
}

View File

@ -1086,3 +1086,36 @@ llvm.func @elements_constant_3d_array() -> !llvm<"[2 x [2 x [2 x i32]]]"> {
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm<"[2 x [2 x [2 x i32]]]">
llvm.return %0 : !llvm<"[2 x [2 x [2 x i32]]]">
}
// CHECK-LABEL: @atomics
llvm.func @atomics(
%f32_ptr : !llvm<"float*">, %f32 : !llvm.float,
%i32_ptr : !llvm<"i32*">, %i32 : !llvm.i32) -> !llvm.float {
// CHECK: atomicrmw fadd float* %{{.*}}, float %{{.*}} unordered
%0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
// CHECK: atomicrmw fsub float* %{{.*}}, float %{{.*}} unordered
%1 = llvm.atomicrmw "fsub" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
// CHECK: atomicrmw xchg float* %{{.*}}, float %{{.*}} monotonic
%2 = llvm.atomicrmw "xchg" "monotonic" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
// CHECK: atomicrmw add i32* %{{.*}}, i32 %{{.*}} acquire
%3 = llvm.atomicrmw "add" "acquire" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw sub i32* %{{.*}}, i32 %{{.*}} release
%4 = llvm.atomicrmw "sub" "release" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw and i32* %{{.*}}, i32 %{{.*}} acq_rel
%5 = llvm.atomicrmw "_and" "acq_rel" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw nand i32* %{{.*}}, i32 %{{.*}} seq_cst
%6 = llvm.atomicrmw "nand" "seq_cst" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw or i32* %{{.*}}, i32 %{{.*}} unordered
%7 = llvm.atomicrmw "_or" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw xor i32* %{{.*}}, i32 %{{.*}} unordered
%8 = llvm.atomicrmw "_xor" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw max i32* %{{.*}}, i32 %{{.*}} unordered
%9 = llvm.atomicrmw "max" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw min i32* %{{.*}}, i32 %{{.*}} unordered
%10 = llvm.atomicrmw "min" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw umax i32* %{{.*}}, i32 %{{.*}} unordered
%11 = llvm.atomicrmw "umax" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
// CHECK: atomicrmw umin i32* %{{.*}}, i32 %{{.*}} unordered
%12 = llvm.atomicrmw "umin" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
llvm.return %0 : !llvm.float
}