From f1e039617b29471e119692e7feaaee74e86dd52f Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Wed, 18 Jul 2018 16:29:21 -0700 Subject: [PATCH] Support for AffineMapAttr. PiperOrigin-RevId: 205157390 --- mlir/include/mlir/IR/Attributes.h | 22 ++++++- mlir/include/mlir/IR/Builders.h | 2 + mlir/lib/IR/AsmPrinter.cpp | 97 +++++++++++++++++++++++++++---- mlir/lib/IR/Attributes.cpp | 53 ----------------- mlir/lib/IR/Builders.cpp | 4 ++ mlir/lib/IR/MLIRContext.cpp | 11 ++++ mlir/lib/Parser/Parser.cpp | 5 ++ mlir/test/IR/parser.mlir | 9 +++ 8 files changed, 136 insertions(+), 67 deletions(-) delete mode 100644 mlir/lib/IR/Attributes.cpp diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 2ff129b96500..47d0b9f7b251 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -22,7 +22,8 @@ #include "llvm/ADT/ArrayRef.h" namespace mlir { - class MLIRContext; +class MLIRContext; +class AffineMap; /// Instances of the Attribute class are immutable, uniqued, immortal, and owned /// by MLIRContext. As such, they are passed around by raw non-const pointer. @@ -34,6 +35,7 @@ public: Float, String, Array, + AffineMap, // TODO: Function references. }; @@ -147,7 +149,23 @@ private: ArrayRef value; }; +class AffineMapAttr : public Attribute { +public: + static AffineMapAttr *get(AffineMap *value, MLIRContext *context); + + AffineMap *getValue() const { + return value; + } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const Attribute *attr) { + return attr->getKind() == Kind::AffineMap; + } +private: + AffineMapAttr(AffineMap *value) : Attribute(Kind::AffineMap), value(value) {} + AffineMap *value; +}; + } // end namespace mlir. #endif - diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 9297fd3d380a..c41a886b06bb 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -37,6 +37,7 @@ class IntegerAttr; class FloatAttr; class StringAttr; class ArrayAttr; +class AffineMapAttr; class AffineMap; class AffineExpr; class AffineConstantExpr; @@ -74,6 +75,7 @@ public: FloatAttr *getFloatAttr(double value); StringAttr *getStringAttr(StringRef bytes); ArrayAttr *getArrayAttr(ArrayRef value); + AffineMapAttr *getAffineMapAttr(AffineMap *value); // Affine Expressions and Affine Map. AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount, diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 6a2d9321d0c3..0e5d10148f55 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -50,6 +50,7 @@ public: void initialize(const Module *module); void print(const Module *module); + void print(const Attribute *attr) const; void print(const Type *type) const; void print(const Function *fn); void print(const ExtFunction *fn); @@ -77,6 +78,11 @@ private: void visitCFGFunction(const CFGFunction *fn); void visitMLFunction(const MLFunction *fn); void visitType(const Type *type); + void visitAttribute(const Attribute *attr); + void visitOperation(const Operation *op); + + void printAffineMapId(int affineMapId) const; + void printAffineMapReference(const AffineMap* affineMap) const; raw_ostream &os; DenseMap affineMapIds; @@ -113,6 +119,22 @@ void ModuleState::visitType(const Type *type) { } } +void ModuleState::visitAttribute(const Attribute *attr) { + if (isa(attr)) { + recordAffineMapReference(cast(attr)->getValue()); + } else if (isa(attr)) { + for (auto elt : cast(attr)->getValue()) { + visitAttribute(elt); + } + } +} + +void ModuleState::visitOperation(const Operation *op) { + for (auto elt : op->getAttrs()) { + visitAttribute(elt.second); + } +} + void ModuleState::visitExtFunction(const ExtFunction *fn) { visitType(fn->getType()); } @@ -120,11 +142,16 @@ void ModuleState::visitExtFunction(const ExtFunction *fn) { void ModuleState::visitCFGFunction(const CFGFunction *fn) { visitType(fn->getType()); // TODO Visit function body instructions. + for (auto &block : *fn) { + for (auto &op : block.getOperations()) { + visitOperation(&op); + } + } } void ModuleState::visitMLFunction(const MLFunction *fn) { visitType(fn->getType()); - // TODO Visit function body statements. + // TODO Visit function body statements (and attributes if required). } void ModuleState::visitFunction(const Function *fn) { @@ -151,13 +178,24 @@ void ModuleState::print(const Function *fn) { } // Prints affine map identifier. -static void printAffineMapId(unsigned affineMapId, raw_ostream &os) { +void ModuleState::printAffineMapId(int affineMapId) const { os << "#map" << affineMapId; } +void ModuleState::printAffineMapReference(const AffineMap* affineMap) const { + const int mapId = getAffineMapId(affineMap); + if (mapId >= 0) { + // Map will be printed at top of module so print reference to its id. + printAffineMapId(mapId); + } else { + // Map not in module state so print inline. + affineMap->print(os); + } +} + void ModuleState::print(const Module *module) { for (const auto &mapAndId : affineMapIds) { - printAffineMapId(mapAndId.second, os); + printAffineMapId(mapAndId.second); os << " = "; mapAndId.first->print(os); os << '\n'; @@ -165,6 +203,37 @@ void ModuleState::print(const Module *module) { for (auto *fn : module->functionList) print(fn); } +void ModuleState::print(const Attribute *attr) const { + switch (attr->getKind()) { + case Attribute::Kind::Bool: + os << (cast(attr)->getValue() ? "true" : "false"); + break; + case Attribute::Kind::Integer: + os << cast(attr)->getValue(); + break; + case Attribute::Kind::Float: + // FIXME: this isn't precise, we should print with a hex format. + os << cast(attr)->getValue(); + break; + case Attribute::Kind::String: + // FIXME: should escape the string. + os << '"' << cast(attr)->getValue() << '"'; + break; + case Attribute::Kind::Array: { + auto elts = cast(attr)->getValue(); + os << '['; + interleave(elts, + [&](Attribute *attr) { print(attr); }, + [&]() { os << ", "; }); + os << ']'; + break; + } + case Attribute::Kind::AffineMap: + printAffineMapReference(cast(attr)->getValue()); + break; + } +} + void ModuleState::print(const Type *type) const { switch (type->getKind()) { case Type::Kind::AffineInt: @@ -243,14 +312,7 @@ void ModuleState::print(const Type *type) const { os << *v->getElementType(); for (auto map : v->getAffineMaps()) { os << ", "; - const int mapId = getAffineMapId(map); - if (mapId >= 0) { - // Map will be printed at top of module so print reference to its id. - printAffineMapId(mapId, os); - } else { - // Map not in module state so print inline. - map->print(os); - } + printAffineMapReference(map); } os << ", " << v->getMemorySpace(); os << '>'; @@ -338,7 +400,9 @@ void FunctionState::printOperation(const Operation *op) { os << '{'; interleave( attrs, - [&](NamedAttribute attr) { os << attr.first << ": " << *attr.second; }, + [&](NamedAttribute attr) { + os << attr.first << ": "; + moduleState->print(attr.second); }, [&]() { os << ", "; }); os << '}'; } @@ -553,6 +617,15 @@ void ModuleState::print(const MLFunction *fn) { // print and dump methods //===----------------------------------------------------------------------===// +void Attribute::print(raw_ostream &os) const { + ModuleState moduleState(os); + moduleState.print(this); +} + +void Attribute::dump() const { + print(llvm::errs()); +} + void Type::print(raw_ostream &os) const { ModuleState moduleState(os); moduleState.print(this); diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp deleted file mode 100644 index df23424ed25d..000000000000 --- a/mlir/lib/IR/Attributes.cpp +++ /dev/null @@ -1,53 +0,0 @@ -//===- Attributes.cpp - MLIR Attribute Implementation ---------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/IR/Attributes.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Support/STLExtras.h" -using namespace mlir; - -void Attribute::print(raw_ostream &os) const { - switch (getKind()) { - case Kind::Bool: - os << (cast(this)->getValue() ? "true" : "false"); - break; - case Kind::Integer: - os << cast(this)->getValue(); - break; - case Kind::Float: - // FIXME: this isn't precise, we should print with a hex format. - os << cast(this)->getValue(); - break; - case Kind::String: - // FIXME: should escape the string. - os << '"' << cast(this)->getValue() << '"'; - break; - case Kind::Array: { - auto elts = cast(this)->getValue(); - os << '['; - interleave(elts, - [&](Attribute *attr) { attr->print(os); }, - [&]() { os << ", "; }); - os << ']'; - break; - } - } -} - -void Attribute::dump() const { - print(llvm::errs()); -} diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index e9bea2a73245..8d27991b8fd0 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -94,6 +94,10 @@ ArrayAttr *Builder::getArrayAttr(ArrayRef value) { return ArrayAttr::get(value, context); } +AffineMapAttr *Builder::getAffineMapAttr(AffineMap *value) { + return AffineMapAttr::get(value, context); +} + //===----------------------------------------------------------------------===// // Affine Expressions and Affine Map. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index eb06e949ea59..df3d01aad8d3 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -230,6 +230,7 @@ public: StringMap stringAttrs; using ArrayAttrSet = DenseSet; ArrayAttrSet arrayAttrs; + DenseMap affineMapAttrs; using AttributeListSet = DenseSet; AttributeListSet attributeLists; @@ -541,6 +542,16 @@ ArrayAttr *ArrayAttr::get(ArrayRef value, MLIRContext *context) { return *existing.first = result; } +AffineMapAttr *AffineMapAttr::get(AffineMap* value, MLIRContext *context) { + auto *&result = context->getImpl().affineMapAttrs[value]; + if (result) + return result; + + result = context->getImpl().allocator.Allocate(); + new (result) AffineMapAttr(value); + return result; +} + /// Perform a three-way comparison between the names of the specified /// NamedAttributes. static int compareNamedAttributes(const NamedAttribute *lhs, diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 3d5c9088a403..d220787d616d 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -598,6 +598,11 @@ Attribute *Parser::parseAttribute() { return builder.getArrayAttr(elements); } default: + // Try to parse affine map reference. + auto* affineMap = parseAffineMapReference(); + if (affineMap != nullptr) + return builder.getAffineMapAttr(affineMap); + // TODO: Handle floating point. return (emitError("expected constant attribute value"), nullptr); } diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 6e0be249ef95..7a1cfce1fc22 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -136,6 +136,15 @@ bb42: // CHECK: bb0: // CHECK: "foo"(){a: 1, b: -423, c: [true, false]} : () -> () "foo"(){a: 1, b: -423, c: [true, false] } : () -> () + // CHECK: "foo"(){map1: #map{{[0-9]+}}} + "foo"(){map1: #map1} : () -> () + + // CHECK: "foo"(){map2: #map{{[0-9]+}}} + "foo"(){map2: (d0, d1, d2) -> (d0, d1, d2)} : () -> () + + // CHECK: "foo"(){map12: [#map{{[0-9]+}}, #map{{[0-9]+}}]} + "foo"(){map12: [#map1, #map2]} : () -> () + // CHECK: "foo"(){cfgfunc: [], i123: 7, if: "foo"} : () -> () "foo"(){if: "foo", cfgfunc: [], i123: 7} : () -> ()