[mlir] Add an AccessGroup attribute to load/store LLVM dialect ops and generate the access_group LLVM metadata.

This also includes LLVM dialect ops created from intrinsics.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D97944
This commit is contained in:
Arpith C. Jacob 2021-03-04 18:12:56 +01:00 committed by Alex Zinenko
parent d0eb25a643
commit 4e393350c5
9 changed files with 152 additions and 10 deletions

View File

@ -35,6 +35,7 @@ def LLVM_Dialect : Dialect {
static StringRef getLoopAttrName() { return "llvm.loop"; }
static StringRef getParallelAccessAttrName() { return "parallel_access"; }
static StringRef getLoopOptionsAttrName() { return "options"; }
static StringRef getAccessGroupsAttrName() { return "access_groups"; }
/// Verifies if the given string is a well-formed data layout descriptor.
/// Uses `reportError` to report errors.
@ -247,7 +248,8 @@ def LLVM_IntrPatterns {
// `llvm::Intrinsic` enum; one usually wants these to be related.
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
list<int> overloadedResults, list<int> overloadedOperands,
list<OpTrait> traits, int numResults>
list<OpTrait> traits, int numResults,
bit requiresAccessGroup = 0>
: LLVM_OpBase<dialect, opName, traits>,
Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> {
string resultPattern = !if(!gt(numResults, 1),
@ -264,19 +266,21 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
overloadedOperands>.lst), ", ") # [{
});
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
}] # !if(!gt(numResults, 0), "$res = ", "")
# [{builder.CreateCall(fn, operands);
}];
}] # [{auto *inst = builder.CreateCall(fn, operands);
}] # !if(!gt(requiresAccessGroup, 0),
"moduleTranslation.setAccessGroupsMetadata(op, inst);",
"(void) inst;")
# !if(!gt(numResults, 0), "$res = inst;", "");
}
// Base class for LLVM intrinsic operations, should not be used directly. Places
// the intrinsic into the LLVM dialect and prefixes its name with "intr.".
class LLVM_IntrOp<string mnem, list<int> overloadedResults,
list<int> overloadedOperands, list<OpTrait> traits,
int numResults>
int numResults, bit requiresAccessGroup = 0>
: LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
overloadedResults, overloadedOperands, traits,
numResults>;
numResults, requiresAccessGroup>;
// Base class for LLVM intrinsic operations returning no results. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.".

View File

@ -287,6 +287,10 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase {
inst->setMetadata(module->getMDKindID("nontemporal"), metadata);
}
}];
code setAccessGroupsMetadataCode = [{
moduleTranslation.setAccessGroupsMetadata(op, inst);
}];
}
// Memory-related operations.
@ -326,12 +330,13 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]>,
def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
let arguments = (ins LLVM_PointerTo<LLVM_LoadableType>:$addr,
OptionalAttr<SymbolRefArrayAttr>:$access_groups,
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
UnitAttr:$nontemporal);
let results = (outs LLVM_Type:$res);
string llvmBuilder = [{
auto *inst = builder.CreateLoad($addr, $volatile_);
}] # setAlignmentCode # setNonTemporalMetadataCode # [{
}] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode # [{
$res = inst;
}];
let builders = [
@ -346,16 +351,18 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
CArg<"bool", "false">:$isNonTemporal)>];
let parser = [{ return parseLoadOp(parser, result); }];
let printer = [{ printLoadOp(p, *this); }];
let verifier = [{ return ::verify(*this); }];
}
def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
let arguments = (ins LLVM_LoadableType:$value,
LLVM_PointerTo<LLVM_LoadableType>:$addr,
OptionalAttr<SymbolRefArrayAttr>:$access_groups,
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
UnitAttr:$nontemporal);
string llvmBuilder = [{
auto *inst = builder.CreateStore($value, $addr, $volatile_);
}] # setAlignmentCode # setNonTemporalMetadataCode;
}] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode;
let builders = [
OpBuilder<(ins "Value":$value, "Value":$addr,
CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
@ -363,6 +370,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
];
let parser = [{ return parseStoreOp(parser, result); }];
let printer = [{ printStoreOp(p, *this); }];
let verifier = [{ return ::verify(*this); }];
}
// Casts.

View File

@ -128,6 +128,9 @@ public:
"attempting to map loop options that was already mapped");
}
// Sets LLVM metadata for memory operations that are in a parallel loop.
void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst);
/// Converts the type from MLIR LLVM dialect to LLVM.
llvm::Type *convertType(Type type);

View File

@ -404,6 +404,34 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
// Builder, printer and parser for for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
static LogicalResult verifyAccessGroups(Operation *op) {
if (Attribute attribute =
op->getAttr(LLVMDialect::getAccessGroupsAttrName())) {
// The attribute is already verified to be a symbol ref array attribute via
// a constraint in the operation definition.
for (SymbolRefAttr accessGroupRef :
attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
StringRef metadataName = accessGroupRef.getRootReference();
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
op->getParentOp(), metadataName);
if (!metadataOp)
return op->emitOpError() << "expected '" << accessGroupRef
<< "' to reference a metadata op";
StringRef accessGroupName = accessGroupRef.getLeafReference();
Operation *accessGroupOp =
SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
if (!accessGroupOp)
return op->emitOpError() << "expected '" << accessGroupRef
<< "' to reference an access_group op";
}
}
return success();
}
static LogicalResult verify(LoadOp op) {
return verifyAccessGroups(op.getOperation());
}
void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
Value addr, unsigned alignment, bool isVolatile,
bool isNonTemporal) {
@ -462,6 +490,10 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
// Builder, printer and parser for LLVM::StoreOp.
//===----------------------------------------------------------------------===//
static LogicalResult verify(StoreOp op) {
return verifyAccessGroups(op.getOperation());
}
void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
Value addr, unsigned alignment, bool isVolatile,
bool isNonTemporal) {

View File

@ -656,6 +656,27 @@ LogicalResult ModuleTranslation::createAccessGroupMetadata() {
return success();
}
void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
llvm::Instruction *inst) {
auto accessGroups =
op->getAttrOfType<ArrayAttr>(LLVMDialect::getAccessGroupsAttrName());
if (accessGroups && !accessGroups.empty()) {
llvm::Module *module = inst->getModule();
SmallVector<llvm::Metadata *> metadatas;
for (SymbolRefAttr accessGroupRef :
accessGroups.getAsRange<SymbolRefAttr>())
metadatas.push_back(getAccessGroup(*op, accessGroupRef));
llvm::MDNode *unionMD = nullptr;
if (metadatas.size() == 1)
unionMD = llvm::cast<llvm::MDNode>(metadatas.front());
else if (metadatas.size() >= 2)
unionMD = llvm::MDNode::get(module->getContext(), metadatas);
inst->setMetadata(module->getMDKindID("llvm.access.group"), unionMD);
}
}
llvm::Type *ModuleTranslation::convertType(Type type) {
return typeTranslator.translateType(type);
}

View File

@ -796,3 +796,39 @@ module {
llvm.return
}
}
// -----
module {
llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
// expected-error@below {{attribute 'access_groups' failed to satisfy constraint: symbol ref array attribute}}
%0 = llvm.load %arg0 { "access_groups" = "test" } : !llvm.ptr<i32>
llvm.return
}
}
// -----
module {
llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
// expected-error@below {{expected '@func1' to reference a metadata op}}
%0 = llvm.load %arg0 { "access_groups" = [@func1] } : !llvm.ptr<i32>
llvm.return
}
llvm.func @func1() {
llvm.return
}
}
// -----
module {
llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
// expected-error@below {{expected '@metadata' to reference an access_group op}}
%0 = llvm.load %arg0 { "access_groups" = [@metadata] } : !llvm.ptr<i32>
llvm.return
}
llvm.metadata @metadata {
llvm.return
}
}

View File

@ -1483,6 +1483,7 @@ module {
llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
^bb4:
%3 = llvm.add %1, %arg2 : i32
// CHECK: = load i32, i32* %{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE:[0-9]+]]
%5 = llvm.load %4 { access_groups = [@metadata::@group1, @metadata::@group2] } : !llvm.ptr<i32>
// CHECK: br label {{.*}} !llvm.loop ![[LOOP_NODE]]
llvm.br ^bb3(%3 : i32) {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
@ -1504,3 +1505,4 @@ module {
// CHECK: ![[UNROLL_DISABLE_NODE]] = !{!"llvm.loop.unroll.disable", i1 true}
// CHECK: ![[LICM_DISABLE_NODE]] = !{!"llvm.licm.disable", i1 true}
// CHECK: ![[INTERLEAVE_NODE]] = !{!"llvm.loop.interleave.count", i32 1}
// CHECK: ![[ACCESS_GROUPS_NODE]] = !{![[GROUP_NODE1]], ![[GROUP_NODE2]]}

View File

@ -23,11 +23,33 @@
// It has no side effects.
// CHECK: [NoSideEffect]
// It has a result.
// CHECK: 1>
// CHECK: 1,
// It does not require an access group.
// CHECK: 0>
// CHECK: Arguments<(ins LLVM_Type, LLVM_Type
//---------------------------------------------------------------------------//
// This checks that we can define an op that takes in an access group metadata.
//
// RUN: cat %S/../../../llvm/include/llvm/IR/Intrinsics.td \
// RUN: | grep -v "llvm/IR/Intrinsics" \
// RUN: | mlir-tblgen -gen-llvmir-intrinsics -I %S/../../../llvm/include/ --llvmir-intrinsics-filter=ptrmask --llvmir-intrinsics-access-group-regexp=ptrmask \
// RUN: | FileCheck --check-prefix=GROUPS %s
// GROUPS-LABEL: def LLVM_ptrmask
// GROUPS: LLVM_IntrOp<"ptrmask
// It has no side effects.
// GROUPS: [NoSideEffect]
// It has a result.
// GROUPS: 1,
// It requires generation of an access group LLVM metadata.
// GROUPS: 1>
// It has an access group attribute.
// GROUPS: OptionalAttr<SymbolRefArrayAttr>:$access_groups
//---------------------------------------------------------------------------//
// This checks that the ODS we produce can be consumed by MLIR tablegen. We only
// make sure the entire process does not fail and produces some C++. The shape
// of this C++ code is tested by ODS tests.

View File

@ -17,6 +17,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
@ -37,6 +38,12 @@ static llvm::cl::opt<std::string>
"are planning to emit"),
llvm::cl::init("LLVM_IntrOp"), llvm::cl::cat(IntrinsicGenCat));
static llvm::cl::opt<std::string> accessGroupRegexp(
"llvmir-intrinsics-access-group-regexp",
llvm::cl::desc("Mark intrinsics that match the specified "
"regexp as taking an access group metadata"),
llvm::cl::cat(IntrinsicGenCat));
// Used to represent the indices of overloadable operands/results.
using IndicesTy = llvm::SmallBitVector;
@ -185,6 +192,10 @@ void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
LLVMIntrinsic intr(record);
llvm::Regex accessGroupMatcher(accessGroupRegexp);
bool requiresAccessGroup =
!accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName());
// Prepare strings for traits, if any.
llvm::SmallVector<llvm::StringRef, 2> traits;
if (intr.isCommutative())
@ -195,6 +206,8 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
// Prepare strings for operands.
llvm::SmallVector<llvm::StringRef, 8> operands(intr.getNumOperands(),
"LLVM_Type");
if (requiresAccessGroup)
operands.push_back("OptionalAttr<SymbolRefArrayAttr>:$access_groups");
// Emit the definition.
os << "def LLVM_" << intr.getProperRecordName() << " : " << opBaseClass
@ -204,7 +217,8 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
os << ", ";
printBracketedRange(traits, os);
os << ", " << intr.getNumResults() << ">, Arguments<(ins"
os << ", " << intr.getNumResults() << ", "
<< (requiresAccessGroup ? "1" : "0") << ">, Arguments<(ins"
<< (operands.empty() ? "" : " ");
llvm::interleaveComma(operands, os);
os << ")>;\n\n";