mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-02 10:49:22 +00:00
[mlir][Python] Fix generation of accessors for Optional
Previously, in case there was only one `Optional` operand/result within the list, we would always return `None` from the accessor, e.g., for a single optional result we would generate: ``` return self.operation.results[0] if len(self.operation.results) > 1 else None ``` But what we really want is to return `None` only if the length of `results` is smaller than the total number of element groups (i.e., the optional operand/result is in fact missing). This commit also renames a few local variables in the generator to make the distinction between `isVariadic()` and `isVariableLength()` a bit more clear. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D113855
This commit is contained in:
parent
b10562612f
commit
54c9984207
@ -36,15 +36,6 @@ class FillOp:
|
||||
OpView.__init__(self, op)
|
||||
linalgDialect = Context.current.get_dialect_descriptor("linalg")
|
||||
fill_builtin_region(linalgDialect, self.operation)
|
||||
# TODO: self.result is None. When len(results) == 1 we expect it to be
|
||||
# results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
|
||||
# in the generator of _linalg_ops_gen.py where we have:
|
||||
# ```
|
||||
# def result(self):
|
||||
# return self.operation.results[0] \
|
||||
# if len(self.operation.results) > 1 else None
|
||||
# ```
|
||||
|
||||
|
||||
class InitTensorOp:
|
||||
"""Extends the linalg.init_tensor op."""
|
||||
|
@ -304,7 +304,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
|
||||
|
||||
// CHECK: @builtins.property
|
||||
// CHECK: def optional(self):
|
||||
// CHECK: return self.operation.operands[1] if len(self.operation.operands) > 2 else None
|
||||
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
|
||||
|
||||
}
|
||||
|
||||
|
@ -68,10 +68,7 @@ def testFill():
|
||||
@builtin.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32))
|
||||
def fill_tensor(out):
|
||||
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
|
||||
# TODO: FillOp.result is None. When len(results) == 1 we expect it to
|
||||
# be results[0] as per _linalg_ops_gen.py. This seems like an
|
||||
# orthogonal bug in the generator of _linalg_ops_gen.py.
|
||||
return linalg.FillOp(output=out, value=zero).results[0]
|
||||
return linalg.FillOp(output=out, value=zero).result
|
||||
|
||||
# CHECK-LABEL: func @fill_buffer
|
||||
# CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
|
||||
|
@ -207,3 +207,21 @@ def resultTypesDefinedByTraits():
|
||||
print(implied.flt.type)
|
||||
# CHECK: index
|
||||
print(implied.index.type)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOptionalOperandOp
|
||||
@run
|
||||
def testOptionalOperandOp():
|
||||
with Context() as ctx, Location.unknown():
|
||||
test.register_python_test_dialect(ctx)
|
||||
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
op1 = test.OptionalOperandOp(None)
|
||||
# CHECK: op1.input is None: True
|
||||
print(f"op1.input is None: {op1.input is None}")
|
||||
|
||||
op2 = test.OptionalOperandOp(op1)
|
||||
# CHECK: op2.input is None: False
|
||||
print(f"op2.input is None: {op2.input is None}")
|
||||
|
@ -76,4 +76,9 @@ def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op",
|
||||
let results = (outs AnyType:$one, AnyType:$two, AnyType:$three);
|
||||
}
|
||||
|
||||
def OptionalOperandOp : TestOp<"optional_operand_op"> {
|
||||
let arguments = (ins Optional<AnyType>:$input);
|
||||
let results = (outs I32:$result);
|
||||
}
|
||||
|
||||
#endif // PYTHON_TEST_OPS
|
||||
|
@ -109,10 +109,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
|
||||
/// {1} is either 'operand' or 'result';
|
||||
/// {2} is the total number of element groups;
|
||||
/// {3} is the position of the current group in the group list.
|
||||
/// This works if we have only one variable-length group (and it's the optional
|
||||
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
|
||||
/// smaller than the total number of groups.
|
||||
constexpr const char *opOneOptionalTemplate = R"Py(
|
||||
@builtins.property
|
||||
def {0}(self):
|
||||
return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2} else None
|
||||
return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
|
||||
)Py";
|
||||
|
||||
/// Template for the variadic group accessor in the single variadic group case:
|
||||
@ -311,7 +314,7 @@ static std::string attrSizedTraitForKind(const char *kind) {
|
||||
/// `operand` or `result` and is used verbatim in the emitted code.
|
||||
static void emitElementAccessors(
|
||||
const Operator &op, raw_ostream &os, const char *kind,
|
||||
llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
|
||||
llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
|
||||
llvm::function_ref<int(const Operator &)> getNumElements,
|
||||
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
|
||||
getElement) {
|
||||
@ -326,12 +329,12 @@ static void emitElementAccessors(
|
||||
llvm::StringRef(kind).drop_front());
|
||||
std::string attrSizedTrait = attrSizedTraitForKind(kind);
|
||||
|
||||
unsigned numVariadic = getNumVariadic(op);
|
||||
unsigned numVariableLength = getNumVariableLength(op);
|
||||
|
||||
// If there is only one variadic element group, its size can be inferred from
|
||||
// the total number of elements. If there are none, the generation is
|
||||
// straightforward.
|
||||
if (numVariadic <= 1) {
|
||||
// If there is only one variable-length element group, its size can be
|
||||
// inferred from the total number of elements. If there are none, the
|
||||
// generation is straightforward.
|
||||
if (numVariableLength <= 1) {
|
||||
bool seenVariableLength = false;
|
||||
for (int i = 0, e = getNumElements(op); i < e; ++i) {
|
||||
const NamedTypeConstraint &element = getElement(op, i);
|
||||
@ -364,7 +367,7 @@ static void emitElementAccessors(
|
||||
const NamedTypeConstraint &element = getElement(op, i);
|
||||
if (!element.name.empty()) {
|
||||
os << llvm::formatv(opVariadicEqualPrefixTemplate,
|
||||
sanitizeName(element.name), kind, numVariadic,
|
||||
sanitizeName(element.name), kind, numVariableLength,
|
||||
numPrecedingSimple, numPrecedingVariadic);
|
||||
os << llvm::formatv(element.isVariableLength()
|
||||
? opVariadicEqualVariadicTemplate
|
||||
@ -414,20 +417,20 @@ static const NamedTypeConstraint &getResult(const Operator &op, int i) {
|
||||
|
||||
/// Emits accessors to Op operands.
|
||||
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
|
||||
auto getNumVariadic = [](const Operator &oper) {
|
||||
auto getNumVariableLengthOperands = [](const Operator &oper) {
|
||||
return oper.getNumVariableLengthOperands();
|
||||
};
|
||||
emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
|
||||
getOperand);
|
||||
emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
|
||||
getNumOperands, getOperand);
|
||||
}
|
||||
|
||||
/// Emits accessors Op results.
|
||||
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
|
||||
auto getNumVariadic = [](const Operator &oper) {
|
||||
auto getNumVariableLengthResults = [](const Operator &oper) {
|
||||
return oper.getNumVariableLengthResults();
|
||||
};
|
||||
emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
|
||||
getResult);
|
||||
emitElementAccessors(op, os, "result", getNumVariableLengthResults,
|
||||
getNumResults, getResult);
|
||||
}
|
||||
|
||||
/// Emits accessors to Op attributes.
|
||||
|
Loading…
Reference in New Issue
Block a user