[mlir][ODS] Add support for optional operands and results with a new Optional directive.

Summary: This revision adds support for specifying operands or results as "optional". This is a special case of variadic where the number of elements is either 0 or 1. Operands and results of this kind will have accessors generated using Value instead of the range types, making it more natural to interface with.

Differential Revision: https://reviews.llvm.org/D77863
This commit is contained in:
River Riddle 2020-04-10 14:11:45 -07:00
parent 2a922da3a9
commit aba1acc89c
21 changed files with 379 additions and 121 deletions

View File

@ -221,11 +221,28 @@ To declare a variadic operand, wrap the `TypeConstraint` for the operand with
Normally operations have no variadic operands or just one variadic operand. For
the latter case, it is easy to deduce which dynamic operands are for the static
variadic operand definition. But if an operation has more than one variadic
operands, it would be impossible to attribute dynamic operands to the
corresponding static variadic operand definitions without further information
from the operation. Therefore, the `SameVariadicOperandSize` trait is needed to
indicate that all variadic operands have the same number of dynamic values.
variadic operand definition. Though, if an operation has more than one variable
length operands (either optional or variadic), it would be impossible to
attribute dynamic operands to the corresponding static variadic operand
definitions without further information from the operation. Therefore, either
the `SameVariadicOperandSize` or `AttrSizedOperandSegments` trait is needed to
indicate that all variable length operands have the same number of dynamic
values.
#### Optional operands
To declare an optional operand, wrap the `TypeConstraint` for the operand with
`Optional<...>`.
Normally operations have no optional operands or just one optional operand. For
the latter case, it is easy to deduce which dynamic operands are for the static
operand definition. Though, if an operation has more than one variable length
operands (either optional or variadic), it would be impossible to attribute
dynamic operands to the corresponding static variadic operand definitions
without further information from the operation. Therefore, either the
`SameVariadicOperandSize` or `AttrSizedOperandSegments` trait is needed to
indicate that all variable length operands have the same number of dynamic
values.
#### Optional attributes
@ -693,7 +710,7 @@ information. An optional group is defined by wrapping a set of elements within
the group.
- Any attribute variable may be used, but only optional attributes can be
marked as the anchor.
- Only variadic, i.e. optional, operand arguments can be used.
- Only variadic or optional operand arguments can be used.
- The operands to a type directive must be defined within the optional
group.

View File

@ -297,13 +297,17 @@ class DialectType<Dialect d, Pred condition, string descr = ""> :
}
// A variadic type constraint. It expands to zero or more of the base type. This
// class is used for supporting variadic operands/results. An op can declare no
// more than one variadic operand/result, and that operand/result must be the
// last one in the operand/result list.
// class is used for supporting variadic operands/results.
class Variadic<Type type> : TypeConstraint<type.predicate, type.description> {
Type baseType = type;
}
// An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results.
class Optional<Type type> : TypeConstraint<type.predicate, type.description> {
Type baseType = type;
}
// A type that can be constructed using MLIR::Builder.
// Note that this does not "inherit" from Type because it would require
// duplicating Type subclasses for buildable and non-buildable cases to avoid

View File

@ -621,6 +621,9 @@ public:
/// Parse a type.
virtual ParseResult parseType(Type &result) = 0;
/// Parse an optional type.
virtual OptionalParseResult parseOptionalType(Type &result) = 0;
/// Parse a type of a specific type.
template <typename TypeT>
ParseResult parseType(TypeT &result) {

View File

@ -43,8 +43,14 @@ struct NamedAttribute {
struct NamedTypeConstraint {
// Returns true if this operand/result has constraint to be satisfied.
bool hasPredicate() const;
// Returns true if this is an optional type constraint. This is a special case
// of variadic for 0 or 1 type.
bool isOptional() const;
// Returns true if this operand/result is variadic.
bool isVariadic() const;
// Returns true if this is a variable length type constraint. This is either
// variadic or optional.
bool isVariableLength() const { return isOptional() || isVariadic(); }
llvm::StringRef name;
TypeConstraint constraint;

View File

@ -88,7 +88,7 @@ public:
using value_iterator = NamedTypeConstraint *;
using value_range = llvm::iterator_range<value_iterator>;
// Returns true if this op has variadic operands or results.
// Returns true if this op has variable length operands or results.
bool isVariadic() const;
// Returns true if default builders should not be generated.
@ -115,8 +115,8 @@ public:
// Returns the `index`-th result's decorators.
var_decorator_range getResultDecorators(int index) const;
// Returns the number of variadic results in this operation.
unsigned getNumVariadicResults() const;
// Returns the number of variable length results in this operation.
unsigned getNumVariableLengthResults() const;
// Op attribute iterators.
using attribute_iterator = const NamedAttribute *;
@ -142,7 +142,7 @@ public:
}
// Returns the number of variadic operands in this operation.
unsigned getNumVariadicOperands() const;
unsigned getNumVariableLengthOperands() const;
// Returns the total number of arguments.
int getNumArgs() const { return arguments.size(); }

View File

@ -34,9 +34,16 @@ public:
static bool classof(const Constraint *c) { return c->getKind() == CK_Type; }
// Returns true if this is an optional type constraint.
bool isOptional() const;
// Returns true if this is a variadic type constraint.
bool isVariadic() const;
// Returns true if this is a variable length type constraint. This is either
// variadic or optional.
bool isVariableLength() const { return isOptional() || isVariadic(); }
// Returns the builder call for this constraint if this is a buildable type,
// returns None otherwise.
Optional<StringRef> getBuilderCall() const;

View File

@ -227,6 +227,9 @@ public:
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
/// Optionally parse a type.
OptionalParseResult parseOptionalType(Type &type);
/// Parse an arbitrary type.
Type parseType();
@ -899,6 +902,31 @@ ParseResult Parser::parseToken(Token::Kind expectedToken,
// Type Parsing
//===----------------------------------------------------------------------===//
/// Optionally parse a type.
OptionalParseResult Parser::parseOptionalType(Type &type) {
// There are many different starting tokens for a type, check them here.
switch (getToken().getKind()) {
case Token::l_paren:
case Token::kw_memref:
case Token::kw_tensor:
case Token::kw_complex:
case Token::kw_tuple:
case Token::kw_vector:
case Token::inttype:
case Token::kw_bf16:
case Token::kw_f16:
case Token::kw_f32:
case Token::kw_f64:
case Token::kw_index:
case Token::kw_none:
case Token::exclamation_identifier:
return failure(!(type = parseType()));
default:
return llvm::None;
}
}
/// Parse an arbitrary type.
///
/// type ::= function-type
@ -4509,6 +4537,11 @@ public:
return failure(!(result = parser.parseType()));
}
/// Parse an optional type.
OptionalParseResult parseOptionalType(Type &result) override {
return parser.parseOptionalType(result);
}
/// Parse an arrow followed by a type list.
ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override {
if (parseArrow() || parser.parseFunctionResultTypes(result))

View File

@ -15,6 +15,10 @@ bool tblgen::NamedTypeConstraint::hasPredicate() const {
return !constraint.getPredicate().isNull();
}
bool tblgen::NamedTypeConstraint::isOptional() const {
return constraint.isOptional();
}
bool tblgen::NamedTypeConstraint::isVariadic() const {
return constraint.isVariadic();
}

View File

@ -81,10 +81,6 @@ StringRef tblgen::Operator::getExtraClassDeclaration() const {
const llvm::Record &tblgen::Operator::getDef() const { return def; }
bool tblgen::Operator::isVariadic() const {
return getNumVariadicOperands() != 0 || getNumVariadicResults() != 0;
}
bool tblgen::Operator::skipDefaultBuilders() const {
return def.getValueAsBit("skipDefaultBuilders");
}
@ -119,16 +115,16 @@ auto tblgen::Operator::getResultDecorators(int index) const
return *result->getValueAsListInit("decorators");
}
unsigned tblgen::Operator::getNumVariadicResults() const {
return std::count_if(
results.begin(), results.end(),
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
unsigned tblgen::Operator::getNumVariableLengthResults() const {
return llvm::count_if(results, [](const NamedTypeConstraint &c) {
return c.constraint.isVariableLength();
});
}
unsigned tblgen::Operator::getNumVariadicOperands() const {
return std::count_if(
operands.begin(), operands.end(),
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
unsigned tblgen::Operator::getNumVariableLengthOperands() const {
return llvm::count_if(operands, [](const NamedTypeConstraint &c) {
return c.constraint.isVariableLength();
});
}
tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {

View File

@ -255,7 +255,7 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
// If this operand is variadic, then return a range. Otherwise, return the
// value itself.
if (operand->isVariadic()) {
if (operand->isVariableLength()) {
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
return std::string(repl);

View File

@ -26,6 +26,10 @@ TypeConstraint::TypeConstraint(const llvm::Record *record)
TypeConstraint::TypeConstraint(const llvm::DefInit *init)
: TypeConstraint(init->getDef()) {}
bool TypeConstraint::isOptional() const {
return def->isSubClassOf("Optional");
}
bool TypeConstraint::isVariadic() const {
return def->isSubClassOf("Variadic");
}
@ -34,7 +38,7 @@ bool TypeConstraint::isVariadic() const {
// returns None otherwise.
Optional<StringRef> TypeConstraint::getBuilderCall() const {
const llvm::Record *baseType = def;
if (isVariadic())
if (isVariableLength())
baseType = baseType->getValueAsDef("baseType");
// Check to see if this type constraint has a builder call.

View File

@ -1179,39 +1179,41 @@ def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> {
}
// Test various mixings of result type formatting.
class FormatResultBase<string name, string fmt> : TEST_Op<name> {
class FormatResultBase<string suffix, string fmt>
: TEST_Op<"format_result_" # suffix # "_op"> {
let results = (outs I64:$buildable_res, AnyMemRef:$result);
let assemblyFormat = fmt;
}
def FormatResultAOp : FormatResultBase<"format_result_a_op", [{
def FormatResultAOp : FormatResultBase<"a", [{
type($result) attr-dict
}]>;
def FormatResultBOp : FormatResultBase<"format_result_b_op", [{
def FormatResultBOp : FormatResultBase<"b", [{
type(results) attr-dict
}]>;
def FormatResultCOp : FormatResultBase<"format_result_c_op", [{
def FormatResultCOp : FormatResultBase<"c", [{
functional-type($buildable_res, $result) attr-dict
}]>;
// Test various mixings of operand type formatting.
class FormatOperandBase<string name, string fmt> : TEST_Op<name> {
class FormatOperandBase<string suffix, string fmt>
: TEST_Op<"format_operand_" # suffix # "_op"> {
let arguments = (ins I64:$buildable, AnyMemRef:$operand);
let assemblyFormat = fmt;
}
def FormatOperandAOp : FormatOperandBase<"format_operand_a_op", [{
def FormatOperandAOp : FormatOperandBase<"a", [{
operands `:` type(operands) attr-dict
}]>;
def FormatOperandBOp : FormatOperandBase<"format_operand_b_op", [{
def FormatOperandBOp : FormatOperandBase<"b", [{
operands `:` type($operand) attr-dict
}]>;
def FormatOperandCOp : FormatOperandBase<"format_operand_c_op", [{
def FormatOperandCOp : FormatOperandBase<"c", [{
$buildable `,` $operand `:` type(operands) attr-dict
}]>;
def FormatOperandDOp : FormatOperandBase<"format_operand_d_op", [{
def FormatOperandDOp : FormatOperandBase<"d", [{
$buildable `,` $operand `:` type($operand) attr-dict
}]>;
def FormatOperandEOp : FormatOperandBase<"format_operand_e_op", [{
def FormatOperandEOp : FormatOperandBase<"e", [{
$buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
}]>;
@ -1220,6 +1222,25 @@ def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
let assemblyFormat = "$targets attr-dict";
}
// Test various mixings of optional operand and result type formatting.
class FormatOptionalOperandResultOpBase<string suffix, string fmt>
: TEST_Op<"format_optional_operand_result_" # suffix # "_op",
[AttrSizedOperandSegments]> {
let arguments = (ins Optional<I64>:$optional, Variadic<I64>:$variadic);
let results = (outs Optional<I64>:$optional_res);
let assemblyFormat = fmt;
}
def FormatOptionalOperandResultAOp : FormatOptionalOperandResultOpBase<"a", [{
`(` $optional `:` type($optional) `)` `:` type($optional_res)
(`[` $variadic^ `]`)? attr-dict
}]>;
def FormatOptionalOperandResultBOp : FormatOptionalOperandResultOpBase<"b", [{
(`(` $optional^ `:` type($optional) `)`)? `:` type($optional_res)
(`[` $variadic^ `]`)? attr-dict
}]>;
//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//

View File

@ -112,6 +112,16 @@ def NS_DOp : NS_Op<"op_with_two_operands", []> {
// CHECK-LABEL: NS::DOp declarations
// CHECK: OpTrait::NOperands<2>::Impl
def NS_EOp : NS_Op<"op_with_optionals", []> {
let arguments = (ins Optional<I32>:$a);
let results = (outs Optional<F32>:$b);
}
// CHECK-LABEL: NS::EOp declarations
// CHECK: Value a();
// CHECK: Value b();
// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a)
// Check that default builders can be suppressed.
// ---

View File

@ -222,7 +222,7 @@ def OptionalInvalidF : TestFormat_Op<"optional_invalid_f", [{
def OptionalInvalidG : TestFormat_Op<"optional_invalid_g", [{
($attr^) attr-dict
}]>, Arguments<(ins I64Attr:$attr)>;
// CHECK: error: only variadic operands can be used within an optional group
// CHECK: error: only variable length operands can be used within an optional group
def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{
($arg^) attr-dict
}]>, Arguments<(ins I64:$arg)>;
@ -327,6 +327,17 @@ def ZCoverageInvalidF : TestFormat_Op<"variable_invalid_f", [{
}]> {
let successors = (successor AnySuccessor:$successor);
}
// CHECK: error: type of operand #0, named 'operand', is not buildable and a buildable type cannot be inferred
// CHECK: note: suggest adding a type constraint to the operation or adding a 'type($operand)' directive to the custom assembly format
def ZCoverageInvalidG : TestFormat_Op<"variable_invalid_g", [{
operands attr-dict
}]>, Arguments<(ins Optional<I64>:$operand)>;
// CHECK: error: type of result #0, named 'result', is not buildable and a buildable type cannot be inferred
// CHECK: note: suggest adding a type constraint to the operation or adding a 'type($result)' directive to the custom assembly format
def ZCoverageInvalidH : TestFormat_Op<"variable_invalid_h", [{
attr-dict
}]>, Results<(outs Optional<I64>:$result)>;
// CHECK-NOT: error
def ZCoverageValidA : TestFormat_Op<"variable_valid_a", [{
$operand type($operand) type($result) attr-dict

View File

@ -18,6 +18,10 @@ test.format_attr_dict_w_keyword attributes {attr = 10 : i64}
// CHECK: test.format_buildable_type_op %[[I64]]
%ignored = test.format_buildable_type_op %i64
//===----------------------------------------------------------------------===//
// Format results
//===----------------------------------------------------------------------===//
// CHECK: test.format_result_a_op memref<1xf64>
%ignored_a:2 = test.format_result_a_op memref<1xf64>
@ -27,6 +31,10 @@ test.format_attr_dict_w_keyword attributes {attr = 10 : i64}
// CHECK: test.format_result_c_op (i64) -> memref<1xf64>
%ignored_c:2 = test.format_result_c_op (i64) -> memref<1xf64>
//===----------------------------------------------------------------------===//
// Format operands
//===----------------------------------------------------------------------===//
// CHECK: test.format_operand_a_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64>
test.format_operand_a_op %i64, %memref : i64, memref<1xf64>
@ -42,6 +50,10 @@ test.format_operand_d_op %i64, %memref : memref<1xf64>
// CHECK: test.format_operand_e_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64>
test.format_operand_e_op %i64, %memref : i64, memref<1xf64>
//===----------------------------------------------------------------------===//
// Format successors
//===----------------------------------------------------------------------===//
"foo.successor_test_region"() ( {
^bb0:
// CHECK: test.format_successor_a_op ^bb1 {attr}
@ -57,3 +69,28 @@ test.format_operand_e_op %i64, %memref : i64, memref<1xf64>
}) { arg_names = ["i", "j", "k"] } : () -> ()
//===----------------------------------------------------------------------===//
// Format optional operands and results
//===----------------------------------------------------------------------===//
// CHECK: test.format_optional_operand_result_a_op(%[[I64]] : i64) : i64
test.format_optional_operand_result_a_op(%i64 : i64) : i64
// CHECK: test.format_optional_operand_result_a_op( : ) : i64
test.format_optional_operand_result_a_op( : ) : i64
// CHECK: test.format_optional_operand_result_a_op(%[[I64]] : i64) :
// CHECK-NOT: i64
test.format_optional_operand_result_a_op(%i64 : i64) :
// CHECK: test.format_optional_operand_result_a_op(%[[I64]] : i64) : [%[[I64]], %[[I64]]]
test.format_optional_operand_result_a_op(%i64 : i64) : [%i64, %i64]
// CHECK: test.format_optional_operand_result_b_op(%[[I64]] : i64) : i64
test.format_optional_operand_result_b_op(%i64 : i64) : i64
// CHECK: test.format_optional_operand_result_b_op : i64
test.format_optional_operand_result_b_op( : ) : i64
// CHECK: test.format_optional_operand_result_b_op : i64
test.format_optional_operand_result_b_op : i64

View File

@ -16,7 +16,8 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
}
// CHECK-LABEL: OpA::verify
// CHECK: for (Value v : getODSOperands(0)) {
// CHECK: auto valueGroup0 = getODSOperands(0);
// CHECK: for (Value v : valueGroup0) {
// CHECK: if (!((v.getType().isInteger(32) || v.getType().isF32())))
def OpB : NS_Op<"op_for_And_PredOpTrait", [
@ -90,5 +91,6 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
}
// CHECK-LABEL: OpK::verify
// CHECK: for (Value v : getODSOperands(0)) {
// CHECK: auto valueGroup0 = getODSOperands(0);
// CHECK: for (Value v : valueGroup0) {
// CHECK: if (!(((v.getType().isa<TensorType>())) && (((v.getType().cast<ShapedType>().getElementType().isF32())) || ((v.getType().cast<ShapedType>().getElementType().isSignlessInteger(32))))))

View File

@ -75,7 +75,7 @@ static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
if (numOperands == 0)
return false;
const auto &operand = op.getOperand(numOperands - 1);
return operand.isVariadic() && operand.name == name;
return operand.isVariableLength() && operand.name == name;
}
// Check if `result` is a known name of a result of `op`.

View File

@ -452,7 +452,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
StringRef rangeSizeCall,
StringRef getOperandCallPattern) {
const int numOperands = op.getNumOperands();
const int numVariadicOperands = op.getNumVariadicOperands();
const int numVariadicOperands = op.getNumVariableLengthOperands();
const int numNormalOperands = numOperands - numVariadicOperands;
const auto *sameVariadicSize =
@ -493,9 +493,9 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
// calculation at run-time.
llvm::SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(numOperands);
for (int i = 0; i < numOperands; ++i) {
isVariadic.push_back(llvm::toStringRef(op.getOperand(i).isVariadic()));
}
for (int i = 0; i < numOperands; ++i)
isVariadic.push_back(op.getOperand(i).isVariableLength() ? "true"
: "false");
std::string isVariadicList = llvm::join(isVariadic, ", ");
m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
@ -511,11 +511,15 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
if (operand.name.empty())
continue;
if (operand.isVariadic()) {
if (operand.isOptional()) {
auto &m = opClass.newMethod("Value", operand.name);
m.body() << " auto operands = getODSOperands(" << i << ");\n"
<< " return operands.empty() ? Value() : *operands.begin();";
} else if (operand.isVariadic()) {
auto &m = opClass.newMethod(rangeType, operand.name);
m.body() << " return getODSOperands(" << i << ");";
} else {
auto &m = opClass.newMethod("Value ", operand.name);
auto &m = opClass.newMethod("Value", operand.name);
m.body() << " return *getODSOperands(" << i << ").begin();";
}
}
@ -534,7 +538,7 @@ void OpEmitter::genNamedOperandGetters() {
void OpEmitter::genNamedResultGetters() {
const int numResults = op.getNumResults();
const int numVariadicResults = op.getNumVariadicResults();
const int numVariadicResults = op.getNumVariableLengthResults();
const int numNormalResults = numResults - numVariadicResults;
// If we have more than one variadic results, we need more complicated logic
@ -573,9 +577,9 @@ void OpEmitter::genNamedResultGetters() {
} else {
llvm::SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(numResults);
for (int i = 0; i < numResults; ++i) {
isVariadic.push_back(llvm::toStringRef(op.getResult(i).isVariadic()));
}
for (int i = 0; i < numResults; ++i)
isVariadic.push_back(op.getResult(i).isVariableLength() ? "true"
: "false");
std::string isVariadicList = llvm::join(isVariadic, ", ");
m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
@ -589,11 +593,15 @@ void OpEmitter::genNamedResultGetters() {
if (result.name.empty())
continue;
if (result.isVariadic()) {
if (result.isOptional()) {
auto &m = opClass.newMethod("Value", result.name);
m.body() << " auto results = getODSResults(" << i << ");\n"
<< " return results.empty() ? Value() : *results.begin();";
} else if (result.isVariadic()) {
auto &m = opClass.newMethod("Operation::result_range", result.name);
m.body() << " return getODSResults(" << i << ");";
} else {
auto &m = opClass.newMethod("Value ", result.name);
auto &m = opClass.newMethod("Value", result.name);
m.body() << " return *getODSResults(" << i << ").begin();";
}
}
@ -706,6 +714,8 @@ void OpEmitter::genSeparateArgParamBuilder() {
return;
case TypeParamKind::Separate:
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
if (op.getResult(i).isOptional())
body << " if (" << resultNames[i] << ")\n ";
body << " " << builderOpState << ".addTypes(" << resultNames[i]
<< ");\n";
}
@ -713,12 +723,12 @@ void OpEmitter::genSeparateArgParamBuilder() {
case TypeParamKind::Collective:
body << " "
<< "assert(resultTypes.size() "
<< (op.getNumVariadicResults() == 0 ? "==" : ">=") << " "
<< (op.getNumResults() - op.getNumVariadicResults())
<< (op.getNumVariableLengthResults() == 0 ? "==" : ">=") << " "
<< (op.getNumResults() - op.getNumVariableLengthResults())
<< "u && \"mismatched number of results\");\n";
body << " " << builderOpState << ".addTypes(resultTypes);\n";
return;
};
}
llvm_unreachable("unhandled TypeParamKind");
};
@ -731,7 +741,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
// Emit separate arg build with collective type, unless there is only one
// variadic result, in which case the above would have already generated
// the same build method.
if (!(op.getNumResults() == 1 && op.getResult(0).isVariadic()))
if (!(op.getNumResults() == 1 && op.getResult(0).isVariableLength()))
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
}
}
@ -739,7 +749,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
// If this op has a variadic result, we cannot generate this builder because
// we don't know how many results to create.
if (op.getNumVariadicResults() != 0)
if (op.getNumVariableLengthResults() != 0)
return;
int numResults = op.getNumResults();
@ -887,7 +897,7 @@ void OpEmitter::genBuilder() {
// 3. one having a stand-alone parameter for each operand and attribute,
// use the first operand or attribute's type as all result types
// to facilitate different call patterns.
if (op.getNumVariadicResults() == 0) {
if (op.getNumVariableLengthResults() == 0) {
if (op.getTrait("OpTrait::SameOperandsAndResultType")) {
genUseOperandAsResultTypeSeparateParamBuilder();
genUseOperandAsResultTypeCollectiveParamBuilder();
@ -899,11 +909,11 @@ void OpEmitter::genBuilder() {
void OpEmitter::genCollectiveParamBuilder() {
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariadicResults();
int numVariadicResults = op.getNumVariableLengthResults();
int numNonVariadicResults = numResults - numVariadicResults;
int numOperands = op.getNumOperands();
int numVariadicOperands = op.getNumVariadicOperands();
int numVariadicOperands = op.getNumVariableLengthOperands();
int numNonVariadicOperands = numOperands - numVariadicOperands;
// Signature
std::string params = std::string("Builder *, OperationState &") +
@ -972,7 +982,12 @@ void OpEmitter::buildParamList(std::string &paramList,
if (resultName.empty())
resultName = std::string(formatv("resultType{0}", i));
paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type ");
if (result.isOptional())
paramList.append(", /*optional*/Type ");
else if (result.isVariadic())
paramList.append(", ArrayRef<Type> ");
else
paramList.append(", Type ");
paramList.append(resultName);
resultTypeNames.emplace_back(std::move(resultName));
@ -1018,7 +1033,12 @@ void OpEmitter::buildParamList(std::string &paramList,
auto argument = op.getArg(i);
if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &operand = op.getOperand(numOperands);
paramList.append(operand.isVariadic() ? ", ValueRange " : ", Value ");
if (operand.isOptional())
paramList.append(", /*optional*/Value ");
else if (operand.isVariadic())
paramList.append(", ValueRange ");
else
paramList.append(", Value ");
paramList.append(getArgumentName(op, numOperands));
++numOperands;
} else {
@ -1076,8 +1096,10 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
bool isRawValueAttr) {
// Push all operands to the result.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
body << " " << builderOpState << ".addOperands(" << getArgumentName(op, i)
<< ");\n";
std::string argName = getArgumentName(op, i);
if (op.getOperand(i).isOptional())
body << " if (" << argName << ")\n ";
body << " " << builderOpState << ".addOperands(" << argName << ");\n";
}
// If the operation has the operand segment size attribute, add it here.
@ -1086,7 +1108,9 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
<< ".addAttribute(\"operand_segment_sizes\", "
"odsBuilder->getI32VectorAttr({";
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
if (op.getOperand(i).isVariadic())
if (op.getOperand(i).isOptional())
body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
else if (op.getOperand(i).isVariadic())
body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
else
body << "1";
@ -1160,7 +1184,7 @@ void OpEmitter::genCanonicalizerDecls() {
void OpEmitter::genFolderDecls() {
bool hasSingleResult =
op.getNumResults() == 1 && op.getNumVariadicResults() == 0;
op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
if (def.getValueAsBit("hasFolder")) {
if (hasSingleResult) {
@ -1434,17 +1458,33 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
body << " unsigned index = 0; (void)index;\n";
for (auto staticValue : llvm::enumerate(values)) {
if (!staticValue.value().hasPredicate())
bool hasPredicate = staticValue.value().hasPredicate();
bool isOptional = staticValue.value().isOptional();
if (!hasPredicate && !isOptional)
continue;
// Emit a loop to check all the dynamic values in the pack.
body << formatv(" for (Value v : getODS{0}{1}s({2})) {{\n",
body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
// Capitalize the first letter to match the function name
valueKind.substr(0, 1).upper(), valueKind.substr(1),
staticValue.index());
auto constraint = staticValue.value().constraint;
// If the constraint is optional check that the value group has at most 1
// value.
if (isOptional) {
body << formatv(" if (valueGroup{0}.size() > 1)\n"
" return emitOpError(\"{1} group starting at #\") "
"<< index << \" requires 0 or 1 element, but found \" << "
"valueGroup{0}.size();\n",
staticValue.index(), valueKind);
}
// Otherwise, if there is no predicate there is nothing left to do.
if (!hasPredicate)
continue;
// Emit a loop to check all the dynamic values in the pack.
body << " for (Value v : valueGroup" << staticValue.index() << ") {\n";
auto constraint = staticValue.value().constraint;
body << " (void)v;\n"
<< " if (!("
<< tgfmt(constraint.getConditionTemplate(),
@ -1569,7 +1609,7 @@ void OpEmitter::genTraits() {
// Add result size trait.
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariadicResults();
int numVariadicResults = op.getNumVariableLengthResults();
addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
// Add successor size trait.
@ -1579,7 +1619,7 @@ void OpEmitter::genTraits() {
// Add variadic size trait and normal op traits.
int numOperands = op.getNumOperands();
int numVariadicOperands = op.getNumVariadicOperands();
int numVariadicOperands = op.getNumVariableLengthOperands();
// Add operand size trait.
if (numVariadicOperands != 0) {

View File

@ -395,6 +395,17 @@ const char *const variadicOperandParserCode = R"(
if (parser.parseOperandList({0}Operands))
return failure();
)";
const char *const optionalOperandParserCode = R"(
{
OpAsmParser::OperandType operand;
OptionalParseResult parseResult = parser.parseOptionalOperand(operand);
if (parseResult.hasValue()) {
if (failed(*parseResult))
return failure();
{0}Operands.push_back(operand);
}
}
)";
const char *const operandParserCode = R"(
if (parser.parseOperand({0}RawOperands[0]))
return failure();
@ -407,6 +418,17 @@ const char *const variadicTypeParserCode = R"(
if (parser.parseTypeList({0}Types))
return failure();
)";
const char *const optionalTypeParserCode = R"(
{
Type optionalType;
OptionalParseResult parseResult = parser.parseOptionalType(optionalType);
if (parseResult.hasValue()) {
if (failed(*parseResult))
return failure();
{0}Types.push_back(optionalType);
}
}
)";
const char *const typeParserCode = R"(
if (parser.parseType({0}RawTypes[0]))
return failure();
@ -456,18 +478,40 @@ const char *successorParserCode = R"(
return failure();
)";
namespace {
/// The type of length for a given parse argument.
enum class ArgumentLengthKind {
/// The argument is variadic, and may contain 0->N elements.
Variadic,
/// The argument is optional, and may contain 0 or 1 elements.
Optional,
/// The argument is a single element, i.e. always represents 1 element.
Single
};
} // end anonymous namespace
/// Get the length kind for the given constraint.
static ArgumentLengthKind
getArgumentLengthKind(const NamedTypeConstraint *var) {
if (var->isOptional())
return ArgumentLengthKind::Optional;
if (var->isVariadic())
return ArgumentLengthKind::Variadic;
return ArgumentLengthKind::Single;
}
/// Get the name used for the type list for the given type directive operand.
/// 'isVariadic' is set to true if the operand has variadic types.
static StringRef getTypeListName(Element *arg, bool &isVariadic) {
/// 'lengthKind' to the corresponding kind for the given argument.
static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
if (auto *operand = dyn_cast<OperandVariable>(arg)) {
isVariadic = operand->getVar()->isVariadic();
lengthKind = getArgumentLengthKind(operand->getVar());
return operand->getVar()->name;
}
if (auto *result = dyn_cast<ResultVariable>(arg)) {
isVariadic = result->getVar()->isVariadic();
lengthKind = getArgumentLengthKind(result->getVar());
return result->getVar()->name;
}
isVariadic = true;
lengthKind = ArgumentLengthKind::Variadic;
if (isa<OperandsDirective>(arg))
return "allOperand";
if (isa<ResultsDirective>(arg))
@ -502,7 +546,7 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
genElementParserStorage(&childElement, body);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
StringRef name = operand->getVar()->name;
if (operand->getVar()->isVariadic()) {
if (operand->getVar()->isVariableLength()) {
body << " SmallVector<OpAsmParser::OperandType, 4> " << name
<< "Operands;\n";
} else {
@ -515,15 +559,15 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
" (void){0}OperandsLoc;\n",
name);
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
bool variadic = false;
StringRef name = getTypeListName(dir->getOperand(), variadic);
if (variadic)
ArgumentLengthKind lengthKind;
StringRef name = getTypeListName(dir->getOperand(), lengthKind);
if (lengthKind != ArgumentLengthKind::Single)
body << " SmallVector<Type, 1> " << name << "Types;\n";
else
body << llvm::formatv(" Type {0}RawTypes[1];\n", name)
<< llvm::formatv(" ArrayRef<Type> {0}Types({0}RawTypes);\n", name);
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
bool ignored = false;
ArgumentLengthKind ignored;
body << " ArrayRef<Type> " << getTypeListName(dir->getInputs(), ignored)
<< "Types;\n";
body << " ArrayRef<Type> " << getTypeListName(dir->getResults(), ignored)
@ -592,9 +636,14 @@ static void genElementParser(Element *element, OpMethodBody &body,
body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
attrTypeStr);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
bool isVariadic = operand->getVar()->isVariadic();
body << formatv(isVariadic ? variadicOperandParserCode : operandParserCode,
operand->getVar()->name);
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name;
if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv(variadicOperandParserCode, name);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv(optionalOperandParserCode, name);
else
body << formatv(operandParserCode, name);
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
bool isVariadic = successor->getVar()->isVariadic();
body << formatv(isVariadic ? successorListParserCode : successorParserCode,
@ -614,12 +663,16 @@ static void genElementParser(Element *element, OpMethodBody &body,
} else if (isa<SuccessorsDirective>(element)) {
body << llvm::formatv(successorListParserCode, "full");
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
bool isVariadic = false;
StringRef listName = getTypeListName(dir->getOperand(), isVariadic);
body << formatv(isVariadic ? variadicTypeParserCode : typeParserCode,
listName);
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv(variadicTypeParserCode, listName);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv(optionalTypeParserCode, listName);
else
body << formatv(typeParserCode, listName);
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
bool ignored = false;
ArgumentLengthKind ignored;
body << formatv(functionalTypeParserCode,
getTypeListName(dir->getInputs(), ignored),
getTypeListName(dir->getResults(), ignored));
@ -817,7 +870,7 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
<< "builder.getI32VectorAttr({";
auto interleaveFn = [&](const NamedTypeConstraint &operand) {
// If the operand is variadic emit the parsed size.
if (operand.isVariadic())
if (operand.isVariableLength())
body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
else
body << "1";
@ -885,6 +938,10 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
if (var->isVariadic())
return body << var->name << "().getTypes()";
if (var->isOptional())
return body << llvm::formatv(
"({0}() ? ArrayRef<Type>({0}().getType()) : ArrayRef<Type>())",
var->name);
return body << "ArrayRef<Type>(" << var->name << "().getType())";
}
@ -900,11 +957,16 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
// Emit the check for the presence of the anchor element.
Element *anchor = optional->getAnchor();
if (AttributeVariable *attrVar = dyn_cast<AttributeVariable>(anchor))
body << " if (getAttr(\"" << attrVar->getVar()->name << "\")) {\n";
else
body << " if (!" << cast<OperandVariable>(anchor)->getVar()->name
<< "().empty()) {\n";
if (auto *operand = dyn_cast<OperandVariable>(anchor)) {
const NamedTypeConstraint *var = operand->getVar();
if (var->isOptional())
body << " if (" << var->name << "()) {\n";
else if (var->isVariadic())
body << " if (!" << var->name << "().empty()) {\n";
} else {
body << " if (getAttr(\""
<< cast<AttributeVariable>(anchor)->getVar()->name << "\")) {\n";
}
// Emit each of the elements.
for (Element &childElement : optional->getElements())
@ -945,7 +1007,12 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
else
body << " p.printAttribute(" << var->name << "Attr());\n";
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
body << " p << " << operand->getVar()->name << "();\n";
if (operand->getVar()->isOptional()) {
body << " if (Value value = " << operand->getVar()->name << "())\n"
<< " p << value;\n";
} else {
body << " p << " << operand->getVar()->name << "();\n";
}
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
const NamedSuccessor *var = successor->getVar();
if (var->isVariadic())
@ -1521,14 +1588,12 @@ LogicalResult FormatParser::verifyOperands(
// Similarly to results, allow a custom builder for resolving the type if
// we aren't using the 'operands' directive.
Optional<StringRef> builder = operand.constraint.getBuilderCall();
if (!builder || (hasAllOperands && operand.isVariadic())) {
if (!builder || (hasAllOperands && operand.isVariableLength())) {
return emitErrorAndNote(
loc,
"type of operand #" + Twine(i) + ", named '" + operand.name +
"', is not buildable and a buildable " +
"type cannot be inferred",
"suggest adding a type constraint "
"to the operation or adding a "
"', is not buildable and a buildable type cannot be inferred",
"suggest adding a type constraint to the operation or adding a "
"'type($" +
operand.name + ")' directive to the " + "custom assembly format");
}
@ -1559,18 +1624,16 @@ LogicalResult FormatParser::verifyResults(
continue;
}
// If the result is not variadic, allow for the case where the type has a
// builder that we can use.
// If the result is not variable length, allow for the case where the type
// has a builder that we can use.
NamedTypeConstraint &result = op.getResult(i);
Optional<StringRef> builder = result.constraint.getBuilderCall();
if (!builder || result.constraint.isVariadic()) {
if (!builder || result.isVariableLength()) {
return emitErrorAndNote(
loc,
"type of result #" + Twine(i) + ", named '" + result.name +
"', is not buildable and a buildable " +
"type cannot be inferred",
"suggest adding a type constraint "
"to the operation or adding a "
"', is not buildable and a buildable type cannot be inferred",
"suggest adding a type constraint to the operation or adding a "
"'type($" +
result.name + ")' directive to the " + "custom assembly format");
}
@ -1842,9 +1905,9 @@ LogicalResult FormatParser::parseOptionalChildElement(
// Only optional-like(i.e. variadic) operands can be within an optional
// group.
.Case<OperandVariable>([&](OperandVariable *ele) {
if (!ele->getVar()->isVariadic())
return emitError(childLoc, "only variadic operands can be used within"
" an optional group");
if (!ele->getVar()->isVariableLength())
return emitError(childLoc, "only variable length operands can be "
"used within an optional group");
seenVariables.insert(ele->getVar());
return success();
})

View File

@ -243,7 +243,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
if (operand->isVariadic()) {
if (operand->isVariableLength()) {
auto error = formatv("use nested DAG construct to match op {0}'s "
"variadic operand #{1} unsupported now",
op.getOperationName(), i);
@ -296,7 +296,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
// of op definition.
Constraint constraint = matcher.getAsConstraint();
if (operand->constraint != constraint) {
if (operand->isVariadic()) {
if (operand->isVariableLength()) {
auto error = formatv(
"further constrain op {0}'s variadic operand #{1} unsupported now",
op.getOperationName(), argIndex);

View File

@ -807,11 +807,11 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
if (valueArg->isVariadic()) {
if (valueArg->isVariableLength()) {
if (i != e - 1) {
PrintFatalError(loc,
"SPIR-V ops can have Variadic<..> argument only if "
"it's the last argument");
PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
"Optional<...> arguments only if "
"it's the last argument");
}
os << tabs
<< formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);
@ -829,7 +829,7 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
words, wordIndex);
os << tabs << " }\n";
os << tabs << formatv(" {0}.push_back(arg);\n", operands);
if (!valueArg->isVariadic()) {
if (!valueArg->isVariableLength()) {
os << tabs << formatv(" {0}++;\n", wordIndex);
}
operandNum++;