Support for AffineMapAttr.

PiperOrigin-RevId: 205157390
This commit is contained in:
MLIR Team 2018-07-18 16:29:21 -07:00 committed by jpienaar
parent b3fa7d0e9f
commit f1e039617b
8 changed files with 136 additions and 67 deletions

View File

@ -22,7 +22,8 @@
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
class MLIRContext;
class MLIRContext;
class AffineMap;
/// Instances of the Attribute class are immutable, uniqued, immortal, and owned
/// by MLIRContext. As such, they are passed around by raw non-const pointer.
@ -34,6 +35,7 @@ public:
Float,
String,
Array,
AffineMap,
// TODO: Function references.
};
@ -147,7 +149,23 @@ private:
ArrayRef<Attribute*> value;
};
class AffineMapAttr : public Attribute {
public:
static AffineMapAttr *get(AffineMap *value, MLIRContext *context);
AffineMap *getValue() const {
return value;
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::AffineMap;
}
private:
AffineMapAttr(AffineMap *value) : Attribute(Kind::AffineMap), value(value) {}
AffineMap *value;
};
} // end namespace mlir.
#endif

View File

@ -37,6 +37,7 @@ class IntegerAttr;
class FloatAttr;
class StringAttr;
class ArrayAttr;
class AffineMapAttr;
class AffineMap;
class AffineExpr;
class AffineConstantExpr;
@ -74,6 +75,7 @@ public:
FloatAttr *getFloatAttr(double value);
StringAttr *getStringAttr(StringRef bytes);
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
AffineMapAttr *getAffineMapAttr(AffineMap *value);
// Affine Expressions and Affine Map.
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,

View File

@ -50,6 +50,7 @@ public:
void initialize(const Module *module);
void print(const Module *module);
void print(const Attribute *attr) const;
void print(const Type *type) const;
void print(const Function *fn);
void print(const ExtFunction *fn);
@ -77,6 +78,11 @@ private:
void visitCFGFunction(const CFGFunction *fn);
void visitMLFunction(const MLFunction *fn);
void visitType(const Type *type);
void visitAttribute(const Attribute *attr);
void visitOperation(const Operation *op);
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(const AffineMap* affineMap) const;
raw_ostream &os;
DenseMap<const AffineMap *, int> affineMapIds;
@ -113,6 +119,22 @@ void ModuleState::visitType(const Type *type) {
}
}
void ModuleState::visitAttribute(const Attribute *attr) {
if (isa<AffineMapAttr>(attr)) {
recordAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
} else if (isa<ArrayAttr>(attr)) {
for (auto elt : cast<ArrayAttr>(attr)->getValue()) {
visitAttribute(elt);
}
}
}
void ModuleState::visitOperation(const Operation *op) {
for (auto elt : op->getAttrs()) {
visitAttribute(elt.second);
}
}
void ModuleState::visitExtFunction(const ExtFunction *fn) {
visitType(fn->getType());
}
@ -120,11 +142,16 @@ void ModuleState::visitExtFunction(const ExtFunction *fn) {
void ModuleState::visitCFGFunction(const CFGFunction *fn) {
visitType(fn->getType());
// TODO Visit function body instructions.
for (auto &block : *fn) {
for (auto &op : block.getOperations()) {
visitOperation(&op);
}
}
}
void ModuleState::visitMLFunction(const MLFunction *fn) {
visitType(fn->getType());
// TODO Visit function body statements.
// TODO Visit function body statements (and attributes if required).
}
void ModuleState::visitFunction(const Function *fn) {
@ -151,13 +178,24 @@ void ModuleState::print(const Function *fn) {
}
// Prints affine map identifier.
static void printAffineMapId(unsigned affineMapId, raw_ostream &os) {
void ModuleState::printAffineMapId(int affineMapId) const {
os << "#map" << affineMapId;
}
void ModuleState::printAffineMapReference(const AffineMap* affineMap) const {
const int mapId = getAffineMapId(affineMap);
if (mapId >= 0) {
// Map will be printed at top of module so print reference to its id.
printAffineMapId(mapId);
} else {
// Map not in module state so print inline.
affineMap->print(os);
}
}
void ModuleState::print(const Module *module) {
for (const auto &mapAndId : affineMapIds) {
printAffineMapId(mapAndId.second, os);
printAffineMapId(mapAndId.second);
os << " = ";
mapAndId.first->print(os);
os << '\n';
@ -165,6 +203,37 @@ void ModuleState::print(const Module *module) {
for (auto *fn : module->functionList) print(fn);
}
void ModuleState::print(const Attribute *attr) const {
switch (attr->getKind()) {
case Attribute::Kind::Bool:
os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
break;
case Attribute::Kind::Integer:
os << cast<IntegerAttr>(attr)->getValue();
break;
case Attribute::Kind::Float:
// FIXME: this isn't precise, we should print with a hex format.
os << cast<FloatAttr>(attr)->getValue();
break;
case Attribute::Kind::String:
// FIXME: should escape the string.
os << '"' << cast<StringAttr>(attr)->getValue() << '"';
break;
case Attribute::Kind::Array: {
auto elts = cast<ArrayAttr>(attr)->getValue();
os << '[';
interleave(elts,
[&](Attribute *attr) { print(attr); },
[&]() { os << ", "; });
os << ']';
break;
}
case Attribute::Kind::AffineMap:
printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
break;
}
}
void ModuleState::print(const Type *type) const {
switch (type->getKind()) {
case Type::Kind::AffineInt:
@ -243,14 +312,7 @@ void ModuleState::print(const Type *type) const {
os << *v->getElementType();
for (auto map : v->getAffineMaps()) {
os << ", ";
const int mapId = getAffineMapId(map);
if (mapId >= 0) {
// Map will be printed at top of module so print reference to its id.
printAffineMapId(mapId, os);
} else {
// Map not in module state so print inline.
map->print(os);
}
printAffineMapReference(map);
}
os << ", " << v->getMemorySpace();
os << '>';
@ -338,7 +400,9 @@ void FunctionState::printOperation(const Operation *op) {
os << '{';
interleave(
attrs,
[&](NamedAttribute attr) { os << attr.first << ": " << *attr.second; },
[&](NamedAttribute attr) {
os << attr.first << ": ";
moduleState->print(attr.second); },
[&]() { os << ", "; });
os << '}';
}
@ -553,6 +617,15 @@ void ModuleState::print(const MLFunction *fn) {
// print and dump methods
//===----------------------------------------------------------------------===//
void Attribute::print(raw_ostream &os) const {
ModuleState moduleState(os);
moduleState.print(this);
}
void Attribute::dump() const {
print(llvm::errs());
}
void Type::print(raw_ostream &os) const {
ModuleState moduleState(os);
moduleState.print(this);

View File

@ -1,53 +0,0 @@
//===- Attributes.cpp - MLIR Attribute Implementation ---------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/Attributes.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Support/STLExtras.h"
using namespace mlir;
void Attribute::print(raw_ostream &os) const {
switch (getKind()) {
case Kind::Bool:
os << (cast<BoolAttr>(this)->getValue() ? "true" : "false");
break;
case Kind::Integer:
os << cast<IntegerAttr>(this)->getValue();
break;
case Kind::Float:
// FIXME: this isn't precise, we should print with a hex format.
os << cast<FloatAttr>(this)->getValue();
break;
case Kind::String:
// FIXME: should escape the string.
os << '"' << cast<StringAttr>(this)->getValue() << '"';
break;
case Kind::Array: {
auto elts = cast<ArrayAttr>(this)->getValue();
os << '[';
interleave(elts,
[&](Attribute *attr) { attr->print(os); },
[&]() { os << ", "; });
os << ']';
break;
}
}
}
void Attribute::dump() const {
print(llvm::errs());
}

View File

@ -94,6 +94,10 @@ ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
return ArrayAttr::get(value, context);
}
AffineMapAttr *Builder::getAffineMapAttr(AffineMap *value) {
return AffineMapAttr::get(value, context);
}
//===----------------------------------------------------------------------===//
// Affine Expressions and Affine Map.
//===----------------------------------------------------------------------===//

View File

@ -230,6 +230,7 @@ public:
StringMap<StringAttr*> stringAttrs;
using ArrayAttrSet = DenseSet<ArrayAttr*, ArrayAttrKeyInfo>;
ArrayAttrSet arrayAttrs;
DenseMap<AffineMap*, AffineMapAttr*> affineMapAttrs;
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
@ -541,6 +542,16 @@ ArrayAttr *ArrayAttr::get(ArrayRef<Attribute*> value, MLIRContext *context) {
return *existing.first = result;
}
AffineMapAttr *AffineMapAttr::get(AffineMap* value, MLIRContext *context) {
auto *&result = context->getImpl().affineMapAttrs[value];
if (result)
return result;
result = context->getImpl().allocator.Allocate<AffineMapAttr>();
new (result) AffineMapAttr(value);
return result;
}
/// Perform a three-way comparison between the names of the specified
/// NamedAttributes.
static int compareNamedAttributes(const NamedAttribute *lhs,

View File

@ -598,6 +598,11 @@ Attribute *Parser::parseAttribute() {
return builder.getArrayAttr(elements);
}
default:
// Try to parse affine map reference.
auto* affineMap = parseAffineMapReference();
if (affineMap != nullptr)
return builder.getAffineMapAttr(affineMap);
// TODO: Handle floating point.
return (emitError("expected constant attribute value"), nullptr);
}

View File

@ -136,6 +136,15 @@ bb42: // CHECK: bb0:
// CHECK: "foo"(){a: 1, b: -423, c: [true, false]} : () -> ()
"foo"(){a: 1, b: -423, c: [true, false] } : () -> ()
// CHECK: "foo"(){map1: #map{{[0-9]+}}}
"foo"(){map1: #map1} : () -> ()
// CHECK: "foo"(){map2: #map{{[0-9]+}}}
"foo"(){map2: (d0, d1, d2) -> (d0, d1, d2)} : () -> ()
// CHECK: "foo"(){map12: [#map{{[0-9]+}}, #map{{[0-9]+}}]}
"foo"(){map12: [#map1, #map2]} : () -> ()
// CHECK: "foo"(){cfgfunc: [], i123: 7, if: "foo"} : () -> ()
"foo"(){if: "foo", cfgfunc: [], i123: 7} : () -> ()