[MLIR] Change default builders generated by TableGen to use TypeRange for result types

- Change the default builders to use TypeRange instead of ArrayRef<Type>
- Custom builders defined in LinalgStructuredOps now conflict with the default
  separate param ones, but the default collective params one is still needed. Resolve
  this by replicating the collective param builder as a custom builder and skipping
  the generation of default builders for these ops.

Differential Revision: https://reviews.llvm.org/D87926
This commit is contained in:
Rahul Joshi 2020-09-22 20:58:30 -07:00
parent e90343ada3
commit 9744606614
5 changed files with 37 additions and 28 deletions

View File

@ -2006,8 +2006,8 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
//
// ```c++
// static void build(OpBuilder &, OperationState &odsState,
// ArrayRef<Type> resultTypes,
// ArrayRef<Value> operands,
// TypeRange resultTypes,
// ValueRange operands,
// ArrayRef<NamedAttribute> attributes);
// ```
list<OpBuilder> builders = ?;

View File

@ -81,9 +81,9 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: ::mlir::FloatAttr attr2Attr()
// CHECK: ::llvm::Optional< ::llvm::APFloat > attr2();
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value val);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount)
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount)
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions)
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount)
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount)
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions)
// CHECK: static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
// CHECK: void print(::mlir::OpAsmPrinter &p);
// CHECK: ::mlir::LogicalResult verify();
@ -180,8 +180,8 @@ def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> {
// CHECK_LABEL: class NS_HCollectiveParamsOp :
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})
// Check suppression of "separate arg, separate result" build method for an op
// with single variadic arg and single variadic result (since it will be
@ -192,8 +192,8 @@ def NS_HCollectiveParamsSuppress0Op : NS_Op<"op_collective_suppress0", []> {
}
// CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op :
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// Check suppression of "separate arg, collective result" build method for an op
// with single variadic arg and non variadic result (since it will be
@ -204,8 +204,8 @@ def NS_HCollectiveParamsSuppress1Op : NS_Op<"op_collective_suppress1", []> {
}
// CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op :
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// Check suppression of "separate arg, collective result" build method for an op
// with single variadic arg and > 1 variadic result (since it will be
@ -217,9 +217,9 @@ def NS_HCollectiveParamsSuppress2Op : NS_Op<"op_collective_suppress2", [SameVari
let results = (outs Variadic<I32>:$b, Variadic<F32>:$c);
}
// CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op :
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::llvm::ArrayRef<::mlir::Type> c, ::mlir::ValueRange a);
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::TypeRange c, ::mlir::ValueRange a);
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// Check default value of `attributes` for the `genUseOperandAsResultTypeCollectiveParamBuilder` builder
def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperandsAndResultType]> {
@ -228,8 +228,8 @@ def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperands
}
// CHECK_LABEL: class NS_IOp :
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
@ -241,8 +241,8 @@ def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterface
// CHECK_LABEL: class NS_JOp :
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// Check that default builders can be suppressed.

View File

@ -14,7 +14,7 @@ def OpA : NS_Op<"one_normal_result_op", []> {
}
// CHECK-LABEL: void OpA::build
// CHECK: ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands
// CHECK: ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands
// CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types");
// CHECK-NEXT: odsState.addTypes(resultTypes);
@ -39,7 +39,7 @@ def OpC : NS_Op<"three_normal_result_op", []> {
// CHECK-NEXT: odsState.addTypes(resultType1)
// CHECK-NEXT: odsState.addTypes(z)
// CHECK: void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes) {
// CHECK: void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes) {
// CHECK-NEXT: assert(resultTypes.size() == 3u && "mismatched number of results");
// CHECK-NEXT: odsState.addTypes(resultTypes);
@ -67,7 +67,7 @@ def OpF : NS_Op<"one_variadic_result_op", []> {
}
// CHECK-LABEL: void OpF::build
// CHECK-SAME: ::llvm::ArrayRef<::mlir::Type> x
// CHECK-SAME: ::mlir::TypeRange x
// CHECK-NOT: assert
// CHECK: odsState.addTypes(x);
@ -78,12 +78,12 @@ def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
// CHECK-LABEL: OpG definitions
// CHECK: void OpG::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::llvm::ArrayRef<::mlir::Type> y)
// CHECK: void OpG::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::mlir::TypeRange y)
// CHECK-NEXT: odsState.addTypes(x);
// CHECK-NEXT: odsState.addTypes(y);
// CHECK: void OpG::build
// CHECK: ::llvm::ArrayRef<::mlir::Type> resultTypes
// CHECK: ::mlir::TypeRange resultTypes
// CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types");
// CHECK-NEXT: odsState.addTypes(resultTypes);

View File

@ -1455,8 +1455,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
let skipDefaultBuilders = 1;
let builders = [ OpBuilder<
"OpBuilder &b, OperationState &result,"
"OpBuilder &b, OperationState &result, "
"ValueRange inputs, ValueRange outputBuffers",
[{{
result.addOperands(inputs);
@ -1493,6 +1494,14 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
TypeRange(outputBuffers),
TypeRange(initTensors),
resultTensorTypes);
}]>, OpBuilder<
"OpBuilder &b, OperationState &result, TypeRange resultTensorTypes,"
"ValueRange operands, ArrayRef<NamedAttribute> attributes = {{}",
[{{
result.addOperands(operands);
result.addAttributes(attributes);
result.addTypes(resultTensorTypes);
(void)result.addRegion();
}]>
];
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];

View File

@ -1252,7 +1252,7 @@ void OpEmitter::genCollectiveParamBuilder() {
SmallVector<OpMethodParameter, 4> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes");
paramList.emplace_back("::mlir::TypeRange", "resultTypes");
paramList.emplace_back("::mlir::ValueRange", "operands");
// Provide default value for `attributes` when its the last parameter
StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
@ -1322,8 +1322,8 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
if (resultName.empty())
resultName = std::string(formatv("resultType{0}", i));
StringRef type = result.isVariadic() ? "::llvm::ArrayRef<::mlir::Type>"
: "::mlir::Type";
StringRef type =
result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (result.isOptional())
properties = OpMethodParameter::PP_Optional;
@ -1333,7 +1333,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
}
} break;
case TypeParamKind::Collective: {
paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes");
paramList.emplace_back("::mlir::TypeRange", "resultTypes");
resultTypeNames.push_back("resultTypes");
} break;
}