From 54c99842079997b0fe208acdab01e540c0d81b51 Mon Sep 17 00:00:00 2001 From: Michal Terepeta Date: Thu, 18 Nov 2021 09:41:57 +0100 Subject: [PATCH] [mlir][Python] Fix generation of accessors for Optional Previously, in case there was only one `Optional` operand/result within the list, we would always return `None` from the accessor, e.g., for a single optional result we would generate: ``` return self.operation.results[0] if len(self.operation.results) > 1 else None ``` But what we really want is to return `None` only if the length of `results` is smaller than the total number of element groups (i.e., the optional operand/result is in fact missing). This commit also renames a few local variables in the generator to make the distinction between `isVariadic()` and `isVariableLength()` a bit more clear. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D113855 --- mlir/python/mlir/dialects/_linalg_ops_ext.py | 9 ------ mlir/test/mlir-tblgen/op-python-bindings.td | 2 +- mlir/test/python/dialects/linalg/ops.py | 5 +-- mlir/test/python/dialects/python_test.py | 18 +++++++++++ mlir/test/python/python_test_ops.td | 5 +++ mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 31 ++++++++++--------- 6 files changed, 42 insertions(+), 28 deletions(-) diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index b7641c0a4b53..d6c57547ee16 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -36,15 +36,6 @@ class FillOp: OpView.__init__(self, op) linalgDialect = Context.current.get_dialect_descriptor("linalg") fill_builtin_region(linalgDialect, self.operation) - # TODO: self.result is None. When len(results) == 1 we expect it to be - # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug - # in the generator of _linalg_ops_gen.py where we have: - # ``` - # def result(self): - # return self.operation.results[0] \ - # if len(self.operation.results) > 1 else None - # ``` - class InitTensorOp: """Extends the linalg.init_tensor op.""" diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index becce13050a1..aa9977e047f1 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -304,7 +304,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> { // CHECK: @builtins.property // CHECK: def optional(self): - // CHECK: return self.operation.operands[1] if len(self.operation.operands) > 2 else None + // CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1] } diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index d788292f3424..e5b96c260eaa 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -68,10 +68,7 @@ def testFill(): @builtin.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32)) def fill_tensor(out): zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result - # TODO: FillOp.result is None. When len(results) == 1 we expect it to - # be results[0] as per _linalg_ops_gen.py. This seems like an - # orthogonal bug in the generator of _linalg_ops_gen.py. - return linalg.FillOp(output=out, value=zero).results[0] + return linalg.FillOp(output=out, value=zero).result # CHECK-LABEL: func @fill_buffer # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32> diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 2267b59cd4d7..f9da91fba4cd 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -207,3 +207,21 @@ def resultTypesDefinedByTraits(): print(implied.flt.type) # CHECK: index print(implied.index.type) + + +# CHECK-LABEL: TEST: testOptionalOperandOp +@run +def testOptionalOperandOp(): + with Context() as ctx, Location.unknown(): + test.register_python_test_dialect(ctx) + + module = Module.create() + with InsertionPoint(module.body): + + op1 = test.OptionalOperandOp(None) + # CHECK: op1.input is None: True + print(f"op1.input is None: {op1.input is None}") + + op2 = test.OptionalOperandOp(op1) + # CHECK: op2.input is None: False + print(f"op2.input is None: {op2.input is None}") diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td index 0f947e7e536b..6ee71dbf8b12 100644 --- a/mlir/test/python/python_test_ops.td +++ b/mlir/test/python/python_test_ops.td @@ -76,4 +76,9 @@ def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op", let results = (outs AnyType:$one, AnyType:$two, AnyType:$three); } +def OptionalOperandOp : TestOp<"optional_operand_op"> { + let arguments = (ins Optional:$input); + let results = (outs I32:$result); +} + #endif // PYTHON_TEST_OPS diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 8babff25db07..fb634a1be395 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -109,10 +109,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py( /// {1} is either 'operand' or 'result'; /// {2} is the total number of element groups; /// {3} is the position of the current group in the group list. +/// This works if we have only one variable-length group (and it's the optional +/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is +/// smaller than the total number of groups. constexpr const char *opOneOptionalTemplate = R"Py( @builtins.property def {0}(self): - return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2} else None + return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}] )Py"; /// Template for the variadic group accessor in the single variadic group case: @@ -311,7 +314,7 @@ static std::string attrSizedTraitForKind(const char *kind) { /// `operand` or `result` and is used verbatim in the emitted code. static void emitElementAccessors( const Operator &op, raw_ostream &os, const char *kind, - llvm::function_ref getNumVariadic, + llvm::function_ref getNumVariableLength, llvm::function_ref getNumElements, llvm::function_ref getElement) { @@ -326,12 +329,12 @@ static void emitElementAccessors( llvm::StringRef(kind).drop_front()); std::string attrSizedTrait = attrSizedTraitForKind(kind); - unsigned numVariadic = getNumVariadic(op); + unsigned numVariableLength = getNumVariableLength(op); - // If there is only one variadic element group, its size can be inferred from - // the total number of elements. If there are none, the generation is - // straightforward. - if (numVariadic <= 1) { + // If there is only one variable-length element group, its size can be + // inferred from the total number of elements. If there are none, the + // generation is straightforward. + if (numVariableLength <= 1) { bool seenVariableLength = false; for (int i = 0, e = getNumElements(op); i < e; ++i) { const NamedTypeConstraint &element = getElement(op, i); @@ -364,7 +367,7 @@ static void emitElementAccessors( const NamedTypeConstraint &element = getElement(op, i); if (!element.name.empty()) { os << llvm::formatv(opVariadicEqualPrefixTemplate, - sanitizeName(element.name), kind, numVariadic, + sanitizeName(element.name), kind, numVariableLength, numPrecedingSimple, numPrecedingVariadic); os << llvm::formatv(element.isVariableLength() ? opVariadicEqualVariadicTemplate @@ -414,20 +417,20 @@ static const NamedTypeConstraint &getResult(const Operator &op, int i) { /// Emits accessors to Op operands. static void emitOperandAccessors(const Operator &op, raw_ostream &os) { - auto getNumVariadic = [](const Operator &oper) { + auto getNumVariableLengthOperands = [](const Operator &oper) { return oper.getNumVariableLengthOperands(); }; - emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands, - getOperand); + emitElementAccessors(op, os, "operand", getNumVariableLengthOperands, + getNumOperands, getOperand); } /// Emits accessors Op results. static void emitResultAccessors(const Operator &op, raw_ostream &os) { - auto getNumVariadic = [](const Operator &oper) { + auto getNumVariableLengthResults = [](const Operator &oper) { return oper.getNumVariableLengthResults(); }; - emitElementAccessors(op, os, "result", getNumVariadic, getNumResults, - getResult); + emitElementAccessors(op, os, "result", getNumVariableLengthResults, + getNumResults, getResult); } /// Emits accessors to Op attributes.