[mlir][Python] Fix generation of accessors for Optional

Previously, in case there was only one `Optional` operand/result within
the list, we would always return `None` from the accessor, e.g., for a
single optional result we would generate:

```
return self.operation.results[0] if len(self.operation.results) > 1 else None
```

But what we really want is to return `None` only if the length of
`results` is smaller than the total number of element groups (i.e.,
the optional operand/result is in fact missing).

This commit also renames a few local variables in the generator to make
the distinction between `isVariadic()` and `isVariableLength()` a bit
more clear.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D113855
This commit is contained in:
Michal Terepeta 2021-11-18 09:41:57 +01:00 committed by Alex Zinenko
parent b10562612f
commit 54c9984207
6 changed files with 42 additions and 28 deletions

View File

@ -36,15 +36,6 @@ class FillOp:
OpView.__init__(self, op)
linalgDialect = Context.current.get_dialect_descriptor("linalg")
fill_builtin_region(linalgDialect, self.operation)
# TODO: self.result is None. When len(results) == 1 we expect it to be
# results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
# in the generator of _linalg_ops_gen.py where we have:
# ```
# def result(self):
# return self.operation.results[0] \
# if len(self.operation.results) > 1 else None
# ```
class InitTensorOp:
"""Extends the linalg.init_tensor op."""

View File

@ -304,7 +304,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: @builtins.property
// CHECK: def optional(self):
// CHECK: return self.operation.operands[1] if len(self.operation.operands) > 2 else None
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
}

View File

@ -68,10 +68,7 @@ def testFill():
@builtin.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32))
def fill_tensor(out):
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
# TODO: FillOp.result is None. When len(results) == 1 we expect it to
# be results[0] as per _linalg_ops_gen.py. This seems like an
# orthogonal bug in the generator of _linalg_ops_gen.py.
return linalg.FillOp(output=out, value=zero).results[0]
return linalg.FillOp(output=out, value=zero).result
# CHECK-LABEL: func @fill_buffer
# CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>

View File

@ -207,3 +207,21 @@ def resultTypesDefinedByTraits():
print(implied.flt.type)
# CHECK: index
print(implied.index.type)
# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
op1 = test.OptionalOperandOp(None)
# CHECK: op1.input is None: True
print(f"op1.input is None: {op1.input is None}")
op2 = test.OptionalOperandOp(op1)
# CHECK: op2.input is None: False
print(f"op2.input is None: {op2.input is None}")

View File

@ -76,4 +76,9 @@ def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op",
let results = (outs AnyType:$one, AnyType:$two, AnyType:$three);
}
def OptionalOperandOp : TestOp<"optional_operand_op"> {
let arguments = (ins Optional<AnyType>:$input);
let results = (outs I32:$result);
}
#endif // PYTHON_TEST_OPS

View File

@ -109,10 +109,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
@builtins.property
def {0}(self):
return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2} else None
return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
/// Template for the variadic group accessor in the single variadic group case:
@ -311,7 +314,7 @@ static std::string attrSizedTraitForKind(const char *kind) {
/// `operand` or `result` and is used verbatim in the emitted code.
static void emitElementAccessors(
const Operator &op, raw_ostream &os, const char *kind,
llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) {
@ -326,12 +329,12 @@ static void emitElementAccessors(
llvm::StringRef(kind).drop_front());
std::string attrSizedTrait = attrSizedTraitForKind(kind);
unsigned numVariadic = getNumVariadic(op);
unsigned numVariableLength = getNumVariableLength(op);
// If there is only one variadic element group, its size can be inferred from
// the total number of elements. If there are none, the generation is
// straightforward.
if (numVariadic <= 1) {
// If there is only one variable-length element group, its size can be
// inferred from the total number of elements. If there are none, the
// generation is straightforward.
if (numVariableLength <= 1) {
bool seenVariableLength = false;
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
@ -364,7 +367,7 @@ static void emitElementAccessors(
const NamedTypeConstraint &element = getElement(op, i);
if (!element.name.empty()) {
os << llvm::formatv(opVariadicEqualPrefixTemplate,
sanitizeName(element.name), kind, numVariadic,
sanitizeName(element.name), kind, numVariableLength,
numPrecedingSimple, numPrecedingVariadic);
os << llvm::formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
@ -414,20 +417,20 @@ static const NamedTypeConstraint &getResult(const Operator &op, int i) {
/// Emits accessors to Op operands.
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
auto getNumVariadic = [](const Operator &oper) {
auto getNumVariableLengthOperands = [](const Operator &oper) {
return oper.getNumVariableLengthOperands();
};
emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
getOperand);
emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
getNumOperands, getOperand);
}
/// Emits accessors Op results.
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
auto getNumVariadic = [](const Operator &oper) {
auto getNumVariableLengthResults = [](const Operator &oper) {
return oper.getNumVariableLengthResults();
};
emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
getResult);
emitElementAccessors(op, os, "result", getNumVariableLengthResults,
getNumResults, getResult);
}
/// Emits accessors to Op attributes.