llvm-capstone/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
River Riddle ae40d62541 [mlir] Refactor ElementsAttr's value access API
There are several aspects of the API that either aren't easy to use, or make it
deceptively easy to do the wrong thing. The main change of this commit
is to remove all of the `getValue<T>`/`getFlatValue<T>` from ElementsAttr
and instead provide operator[] methods on the ranges returned by
`getValues<T>`. This provides a much more convenient API for the value
ranges. It also removes the easy-to-be-inefficient nature of
getValue/getFlatValue, which under the hood would construct a new range for
the type `T`. Constructing a range is not necessarily cheap in all cases, and
could lead to very poor performance if used within a loop; i.e. if you were to
naively write something like:

```
DenseElementsAttr attr = ...;
for (int i = 0; i < size; ++i) {
  // We are internally rebuilding the APFloat value range on each iteration!!
  APFloat it = attr.getFlatValue<APFloat>(i);
}
```

Differential Revision: https://reviews.llvm.org/D113229
2021-11-09 00:15:08 +00:00

90 lines
3.2 KiB
C++

//===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/Sequence.h"
using namespace mlir;
using namespace mlir::detail;
//===----------------------------------------------------------------------===//
/// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
//===----------------------------------------------------------------------===//
// ElementsAttr
//===----------------------------------------------------------------------===//
/// Returns the type held by this attribute, cast to a ShapedType.
ShapedType ElementsAttr::getType() const {
  Type attrType = Attribute::getType();
  return attrType.cast<ShapedType>();
}
/// Returns the element type of the shaped type held by `elementsAttr`.
Type ElementsAttr::getElementType(Attribute elementsAttr) {
  ShapedType shapedType = elementsAttr.getType().cast<ShapedType>();
  return shapedType.getElementType();
}
/// Returns the number of elements reported by the shaped type held by
/// `elementsAttr`.
int64_t ElementsAttr::getNumElements(Attribute elementsAttr) {
  ShapedType shapedType = elementsAttr.getType().cast<ShapedType>();
  return shapedType.getNumElements();
}
/// Returns true if `index` is a valid multi-dimensional index into `type`:
/// it must have one entry per dimension, and each entry must lie within the
/// extent of its dimension. As a special case, a rank-0 type also accepts the
/// single-entry index `[0]`.
bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
  const int64_t rank = type.getRank();
  const int64_t numIndices = static_cast<int64_t>(index.size());

  // Rank-0 special case: `[0]` is accepted in addition to the empty index.
  if (rank == 0 && numIndices == 1 && index[0] == 0)
    return true;

  // Otherwise the number of indices has to match the rank of the type.
  if (numIndices != rank)
    return false;

  // Every index has to fall inside the corresponding shape dimension. The
  // index is compared as a signed value, so entries that do not fit in an
  // int64_t show up as negative and are rejected.
  ArrayRef<int64_t> shape = type.getShape();
  for (int64_t dim = 0; dim < rank; ++dim) {
    const int64_t position = static_cast<int64_t>(index[dim]);
    if (position < 0 || position >= shape[dim])
      return false;
  }
  return true;
}
/// Overload that validates `index` against the shaped type held by
/// `elementsAttr`.
bool ElementsAttr::isValidIndex(Attribute elementsAttr,
                                ArrayRef<uint64_t> index) {
  ShapedType shapedType = elementsAttr.getType().cast<ShapedType>();
  return isValidIndex(shapedType, index);
}
/// Converts the multi-dimensional `index` into a flat, row-major 1D index for
/// the shaped type `type`. `index` is expected to be valid for `type`; this
/// is only checked by an assertion.
uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
  ShapedType shapeType = type.cast<ShapedType>();
  assert(isValidIndex(shapeType, index) &&
         "expected valid multi-dimensional index");

  auto rank = shapeType.getRank();
  ArrayRef<int64_t> shape = shapeType.getShape();

  // Walk the dimensions from innermost to outermost, accumulating each index
  // scaled by the product of the extents of all dimensions inside of it. For
  // a rank-0 type the loop body never runs and the result is 0.
  uint64_t flatIndex = 0;
  uint64_t stride = 1;
  int dim = rank - 1;
  while (dim >= 0) {
    flatIndex += index[dim] * stride;
    stride *= shape[dim];
    --dim;
  }
  return flatIndex;
}
//===----------------------------------------------------------------------===//
// MemRefLayoutAttrInterface
//===----------------------------------------------------------------------===//
/// Verifies that the affine map `m` can be used as a layout for `shape`: the
/// number of map dimensions must equal the number of entries in `shape`.
/// On mismatch, a diagnostic is produced through `emitError` and returned.
LogicalResult mlir::detail::verifyAffineMapAsLayout(
    AffineMap m, ArrayRef<int64_t> shape,
    function_ref<InFlightDiagnostic()> emitError) {
  // Matching dimension counts is the only requirement checked here.
  if (m.getNumDims() == shape.size())
    return success();

  return emitError() << "memref layout mismatch between rank and affine map: "
                     << shape.size() << " != " << m.getNumDims();
}