[mlir][irdl] Add irdl.base op (#76400)

The `irdl.base` op represent an attribute constraint that will check
that the
base of a type or attribute is the expected one (e.g. `IntegerType`) .

Example:

```mlir
irdl.dialect @cmath {
  irdl.type @complex {
    %0 = irdl.base "!builtin.integer"
    irdl.parameters(%0)
  }

  irdl.type @complex_wrapper {
    %0 = irdl.base @complex
    irdl.parameters(%0)
  }
}
```

The above program defines a `cmath.complex` type that expects a single
parameter, which is a type with base name `builtin.integer`, which is
the
name of an `IntegerType` type.
It also defines a `cmath.complex_wrapper` type that expects a single
parameter, which is a type of base type `cmath.complex`.
This commit is contained in:
Fehr Mathieu 2024-01-18 16:31:40 +00:00 committed by GitHub
parent d124b02242
commit 914cfa4138
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 350 additions and 17 deletions

View File

@ -451,6 +451,57 @@ def IRDL_IsOp : IRDL_ConstraintOp<"is",
let assemblyFormat = " $expected ` ` attr-dict ";
}
def IRDL_BaseOp : IRDL_ConstraintOp<"base",
[ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Constraints an attribute/type base";
let description = [{
`irdl.base` defines a constraint that only accepts a single type
or attribute base, e.g. an `IntegerType`. The attribute base is defined
either by a symbolic reference to the corresponding IRDL definition,
or by the name of the base. Named bases are prefixed with `!` or `#`
respectively for types and attributes.
Example:
```mlir
irdl.dialect @cmath {
irdl.type @complex {
%0 = irdl.base "!builtin.integer"
irdl.parameters(%0)
}
irdl.type @complex_wrapper {
%0 = irdl.base @complex
irdl.parameters(%0)
}
}
```
The above program defines a `cmath.complex` type that expects a single
parameter, which is a type with base name `builtin.integer`, which is the
name of an `IntegerType` type.
It also defines a `cmath.complex_wrapper` type that expects a single
parameter, which is a type of base type `cmath.complex`.
}];
let arguments = (ins OptionalAttr<SymbolRefAttr>:$base_ref,
OptionalAttr<StrAttr>:$base_name);
let results = (outs IRDL_AttributeType:$output);
let assemblyFormat = " ($base_ref^)? ($base_name^)? ` ` attr-dict";
let builders = [
OpBuilder<(ins "SymbolRefAttr":$base_ref), [{
build($_builder, $_state, base_ref, {});
}]>,
OpBuilder<(ins "StringAttr":$base_name), [{
build($_builder, $_state, {}, base_name);
}]>,
];
let hasVerifier = 1;
}
def IRDL_ParametricOp : IRDL_ConstraintOp<"parametric",
[ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>, Pure]> {
let summary = "Constraints an attribute/type base and its parameters";

View File

@ -99,6 +99,48 @@ private:
Attribute expectedAttribute;
};
/// A constraint that checks that an attribute is of a given attribute base
/// (e.g. IntegerAttr).
class BaseAttrConstraint : public Constraint {
public:
BaseAttrConstraint(TypeID baseTypeID, StringRef baseName)
: baseTypeID(baseTypeID), baseName(baseName) {}
virtual ~BaseAttrConstraint() = default;
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Attribute attr,
ConstraintVerifier &context) const override;
private:
/// The expected base attribute typeID.
TypeID baseTypeID;
/// The base attribute name, only used for error reporting.
StringRef baseName;
};
/// A constraint that checks that a type is of a given type base (e.g.
/// IntegerType).
class BaseTypeConstraint : public Constraint {
public:
BaseTypeConstraint(TypeID baseTypeID, StringRef baseName)
: baseTypeID(baseTypeID), baseName(baseName) {}
virtual ~BaseTypeConstraint() = default;
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Attribute attr,
ConstraintVerifier &context) const override;
private:
/// The expected base type typeID.
TypeID baseTypeID;
/// The base type name, only used for error reporting.
StringRef baseName;
};
/// A constraint that checks that an attribute is of a
/// specific dynamic attribute definition, and that all of its parameters
/// satisfy the given constraints.

View File

@ -117,6 +117,39 @@ LogicalResult AttributesOp::verify() {
return success();
}
LogicalResult BaseOp::verify() {
std::optional<StringRef> baseName = getBaseName();
std::optional<SymbolRefAttr> baseRef = getBaseRef();
if (baseName.has_value() == baseRef.has_value())
return emitOpError() << "the base type or attribute should be specified by "
"either a name or a reference";
if (baseName &&
(baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
return emitOpError() << "the base type or attribute name should start with "
"'!' or '#'";
return success();
}
LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
std::optional<SymbolRefAttr> baseRef = getBaseRef();
if (!baseRef)
return success();
TypeOp typeOp = symbolTable.lookupNearestSymbolFrom<TypeOp>(*this, *baseRef);
if (typeOp)
return success();
AttributeOp attrOp =
symbolTable.lookupNearestSymbolFrom<AttributeOp>(*this, *baseRef);
if (attrOp)
return success();
return emitOpError() << "'" << *baseRef
<< "' does not refer to a type or attribute definition";
}
/// Parse a value with its variadicity first. By default, the variadicity is
/// single.
///

View File

@ -37,6 +37,60 @@ std::unique_ptr<Constraint> IsOp::getVerifier(
return std::make_unique<IsConstraint>(getExpectedAttr());
}
std::unique_ptr<Constraint> BaseOp::getVerifier(
ArrayRef<Value> valueToConstr,
DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
&attrs) {
MLIRContext *ctx = getContext();
// Case where the input is a symbol reference.
// This corresponds to the case where the base is an IRDL type or attribute.
if (auto baseRef = getBaseRef()) {
Operation *defOp =
SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
// Type case.
if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
DynamicTypeDefinition *typeDef = types.at(typeOp).get();
auto name = StringAttr::get(ctx, typeDef->getDialect()->getNamespace() +
"." + typeDef->getName().str());
return std::make_unique<BaseTypeConstraint>(typeDef->getTypeID(), name);
}
// Attribute case.
auto attrOp = cast<AttributeOp>(defOp);
DynamicAttrDefinition *attrDef = attrs.at(attrOp).get();
auto name = StringAttr::get(ctx, attrDef->getDialect()->getNamespace() +
"." + attrDef->getName().str());
return std::make_unique<BaseAttrConstraint>(attrDef->getTypeID(), name);
}
// Case where the input is string literal.
// This corresponds to the case where the base is a registered type or
// attribute.
StringRef baseName = getBaseName().value();
// Type case.
if (baseName[0] == '!') {
auto abstractType = AbstractType::lookup(baseName.drop_front(1), ctx);
if (!abstractType) {
emitError() << "no registered type with name " << baseName;
return nullptr;
}
return std::make_unique<BaseTypeConstraint>(abstractType->get().getTypeID(),
abstractType->get().getName());
}
auto abstractAttr = AbstractAttribute::lookup(baseName.drop_front(1), ctx);
if (!abstractAttr) {
emitError() << "no registered attribute with name " << baseName;
return nullptr;
}
return std::make_unique<BaseAttrConstraint>(abstractAttr->get().getTypeID(),
abstractAttr->get().getName());
}
std::unique_ptr<Constraint> ParametricOp::getVerifier(
ArrayRef<Value> valueToConstr,
DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,

View File

@ -68,6 +68,39 @@ LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
return failure();
}
LogicalResult
BaseAttrConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
Attribute attr, ConstraintVerifier &context) const {
if (attr.getTypeID() == baseTypeID)
return success();
if (emitError)
return emitError() << "expected base attribute '" << baseName
<< "' but got '" << attr.getAbstractAttribute().getName()
<< "'";
return failure();
}
LogicalResult
BaseTypeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
Attribute attr, ConstraintVerifier &context) const {
auto typeAttr = dyn_cast<TypeAttr>(attr);
if (!typeAttr) {
if (emitError)
return emitError() << "expected type, got attribute '" << attr;
return failure();
}
Type type = typeAttr.getValue();
if (type.getTypeID() == baseTypeID)
return success();
if (emitError)
return emitError() << "expected base type '" << baseName << "' but got '"
<< type.getAbstractType().getName() << "'";
return failure();
}
LogicalResult DynParametricAttrConstraint::verify(
function_ref<InFlightDiagnostic()> emitError, Attribute attr,
ConstraintVerifier &context) const {

View File

@ -0,0 +1,43 @@
// RUN: mlir-opt %s -verify-diagnostics -split-input-file
// Testing invalid IRDL IRs
func.func private @foo()
irdl.dialect @testd {
irdl.type @type {
// expected-error@+1 {{'@foo' does not refer to a type or attribute definition}}
%0 = irdl.base @foo
irdl.parameters(%0)
}
}
// -----
irdl.dialect @testd {
irdl.type @type {
// expected-error@+1 {{the base type or attribute name should start with '!' or '#'}}
%0 = irdl.base "builtin.integer"
irdl.parameters(%0)
}
}
// -----
irdl.dialect @testd {
irdl.type @type {
// expected-error@+1 {{the base type or attribute name should start with '!' or '#'}}
%0 = irdl.base ""
irdl.parameters(%0)
}
}
// -----
irdl.dialect @testd {
irdl.type @type {
// expected-error@+1 {{the base type or attribute should be specified by either a name}}
%0 = irdl.base
irdl.parameters(%0)
}
}

View File

@ -11,6 +11,15 @@ irdl.dialect @testd {
irdl.parameters(%0)
}
// CHECK: irdl.attribute @parametric_attr {
// CHECK: %[[v0:[^ ]*]] = irdl.any
// CHECK: irdl.parameters(%[[v0]])
// CHECK: }
irdl.attribute @parametric_attr {
%0 = irdl.any
irdl.parameters(%0)
}
// CHECK: irdl.type @attr_in_type_out {
// CHECK: %[[v0:[^ ]*]] = irdl.any
// CHECK: irdl.parameters(%[[v0]])
@ -66,15 +75,40 @@ irdl.dialect @testd {
irdl.results(%0)
}
// CHECK: irdl.operation @dynbase {
// CHECK: %[[v0:[^ ]*]] = irdl.any
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @parametric<%[[v0]]>
// CHECK: irdl.operation @dyn_type_base {
// CHECK: %[[v1:[^ ]*]] = irdl.base @parametric
// CHECK: irdl.results(%[[v1]])
// CHECK: }
irdl.operation @dynbase {
%0 = irdl.any
%1 = irdl.parametric @parametric<%0>
irdl.results(%1)
irdl.operation @dyn_type_base {
%0 = irdl.base @parametric
irdl.results(%0)
}
// CHECK: irdl.operation @dyn_attr_base {
// CHECK: %[[v1:[^ ]*]] = irdl.base @parametric_attr
// CHECK: irdl.attributes {"attr1" = %[[v1]]}
// CHECK: }
irdl.operation @dyn_attr_base {
%0 = irdl.base @parametric_attr
irdl.attributes {"attr1" = %0}
}
// CHECK: irdl.operation @named_type_base {
// CHECK: %[[v1:[^ ]*]] = irdl.base "!builtin.integer"
// CHECK: irdl.results(%[[v1]])
// CHECK: }
irdl.operation @named_type_base {
%0 = irdl.base "!builtin.integer"
irdl.results(%0)
}
// CHECK: irdl.operation @named_attr_base {
// CHECK: %[[v1:[^ ]*]] = irdl.base "#builtin.integer"
// CHECK: irdl.attributes {"attr1" = %[[v1]]}
// CHECK: }
irdl.operation @named_attr_base {
%0 = irdl.base "#builtin.integer"
irdl.attributes {"attr1" = %0}
}
// CHECK: irdl.operation @dynparams {

View File

@ -120,24 +120,67 @@ func.func @succeededAnyConstraint() {
// -----
//===----------------------------------------------------------------------===//
// Dynamic base constraint
// Base constraints
//===----------------------------------------------------------------------===//
func.func @succeededDynBaseConstraint() {
// CHECK: "testd.dynbase"() : () -> !testd.parametric<i32>
"testd.dynbase"() : () -> !testd.parametric<i32>
// CHECK: "testd.dynbase"() : () -> !testd.parametric<i64>
"testd.dynbase"() : () -> !testd.parametric<i64>
// CHECK: "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i64>>
"testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i64>>
// CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<i32>
"testd.dyn_type_base"() : () -> !testd.parametric<i32>
// CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<i64>
"testd.dyn_type_base"() : () -> !testd.parametric<i64>
// CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<!testd.parametric<i64>>
"testd.dyn_type_base"() : () -> !testd.parametric<!testd.parametric<i64>>
// CHECK: "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i32>} : () -> ()
"testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i32>} : () -> ()
// CHECK: "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i64>} : () -> ()
"testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i64>} : () -> ()
return
}
// -----
func.func @failedDynBaseConstraint() {
// expected-error@+1 {{expected base type 'testd.parametric' but got 'i32'}}
"testd.dynbase"() : () -> i32
func.func @failedDynTypeBaseConstraint() {
// expected-error@+1 {{expected base type 'testd.parametric' but got 'builtin.integer'}}
"testd.dyn_type_base"() : () -> i32
return
}
// -----
func.func @failedDynAttrBaseConstraintNotType() {
// expected-error@+1 {{expected base attribute 'testd.parametric_attr' but got 'builtin.type'}}
"testd.dyn_attr_base"() {attr1 = i32}: () -> ()
return
}
// -----
func.func @succeededNamedBaseConstraint() {
// CHECK: "testd.named_type_base"() : () -> i32
"testd.named_type_base"() : () -> i32
// CHECK: "testd.named_type_base"() : () -> i64
"testd.named_type_base"() : () -> i64
// CHECK: "testd.named_attr_base"() {attr1 = 0 : i32} : () -> ()
"testd.named_attr_base"() {attr1 = 0 : i32} : () -> ()
// CHECK: "testd.named_attr_base"() {attr1 = 0 : i64} : () -> ()
"testd.named_attr_base"() {attr1 = 0 : i64} : () -> ()
return
}
// -----
func.func @failedNamedTypeBaseConstraint() {
// expected-error@+1 {{expected base type 'builtin.integer' but got 'builtin.vector'}}
"testd.named_type_base"() : () -> vector<i32>
return
}
// -----
func.func @failedDynAttrBaseConstraintNotType() {
// expected-error@+1 {{expected base attribute 'builtin.integer' but got 'builtin.type'}}
"testd.named_attr_base"() {attr1 = i32}: () -> ()
return
}