Mirror of https://github.com/capstone-engine/llvm-capstone.git (synced 2024-11-29 16:41:27 +00:00)
Revert "Add convenience C++ helper to manipulate ranked strided memref"
This reverts commit 11f32a41c2.
The build is broken because this commit conflicts with the refactoring of
the DialectRegistry APIs in the context. It'll reland shortly after
fixing the API usage.
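For context, the reverted unit tests register dialects directly on the registry owned by the MLIRContext; presumably this is the usage that conflicts with the refactored DialectRegistry APIs mentioned above. A minimal sketch of that call pattern, lifted from the test code below (surrounding setup such as `moduleStr` omitted):

    MLIRContext context;
    // Populate the dialect registry held by the context (usage from the reverted test).
    registerAllDialects(context.getDialectRegistry());
    // Parse an MLIR module against that context.
    auto module = parseSourceString(moduleStr, &context);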
This commit is contained in: parent 0c254b4a69, commit e49967fbd9.
CRunnerUtils.h:

@@ -31,15 +31,11 @@
#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
#endif // _WIN32

#include <array>
#include <cassert>
#include <cstdint>
#include <initializer_list>

//===----------------------------------------------------------------------===//
// Codegen-compatible structures for Vector type.
//===----------------------------------------------------------------------===//
namespace mlir {
namespace detail {

constexpr bool isPowerOf2(int N) { return (!(N & (N - 1))); }

@@ -69,8 +65,9 @@ private:
template <typename T, int Dim>
struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
  Vector1D() {
    static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error");
    static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
    static_assert(detail::nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]),
                  "size error");
    static_assert(detail::nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
                  "size error");
  }
  inline T &operator[](unsigned i) { return vector[i]; }
@@ -78,10 +75,9 @@ struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {

private:
  T vector[Dim];
  char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
  char padding[detail::nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
};
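// Illustrative note (not part of the original header): for T = float and
// Dim = 3, sizeof(T[Dim]) is 12 bytes and the next power of 2 is 16, so the
// padding array above is 4 bytes and the whole struct occupies 16 bytes.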
} // end namespace detail
} // end namespace mlir

// N-D vectors recurse down to 1-D.
template <typename T, int Dim, int... Dims>
@@ -99,9 +95,7 @@ private:
// We insert explicit padding in to account for this.
template <typename T, int Dim>
struct Vector<T, Dim>
    : public mlir::detail::Vector1D<T, Dim,
                                    mlir::detail::isPowerOf2(sizeof(T[Dim]))> {
};
    : public detail::Vector1D<T, Dim, detail::isPowerOf2(sizeof(T[Dim]))> {};

template <int D1, typename T>
using Vector1D = Vector<T, D1>;
@@ -121,9 +115,6 @@ void dropFront(int64_t arr[N], int64_t *res) {
//===----------------------------------------------------------------------===//
// Codegen-compatible structures for StridedMemRef type.
//===----------------------------------------------------------------------===//
template <typename T, int Rank>
class StridedMemrefIterator;

/// StridedMemRef descriptor type with static rank.
template <typename T, int N>
struct StridedMemRefType {
@@ -132,23 +123,6 @@ struct StridedMemRefType {
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];

  template <typename Range>
  T &operator[](Range indices) {
    assert(indices.size() == N &&
           "indices should match rank in memref subscript");
    int64_t curOffset = offset;
    for (int dim = N - 1; dim >= 0; --dim) {
      int64_t currentIndex = *(indices.begin() + dim);
      assert(currentIndex < sizes[dim] && "Index overflow");
      curOffset += currentIndex * strides[dim];
    }
    return data[curOffset];
  }
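  // Illustrative note (not part of the original header): for a rank-2 memref
  // with sizes = {3, 4} and contiguous strides = {4, 1}, indices {1, 2} give
  // curOffset = offset + 1 * 4 + 2 * 1 = offset + 6, i.e. row-major order.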

  StridedMemrefIterator<T, N> begin() { return {*this}; }
  StridedMemrefIterator<T, N> end() { return {*this, -1}; }

  // This operator[] is extremely slow and only for sugaring purposes.
  StridedMemRefType<T, N - 1> operator[](int64_t idx) {
    StridedMemRefType<T, N - 1> res;
@@ -169,17 +143,6 @@ struct StridedMemRefType<T, 1> {
  int64_t offset;
  int64_t sizes[1];
  int64_t strides[1];

  template <typename Range>
  T &operator[](Range indices) {
    assert(indices.size() == 1 &&
           "indices should match rank in memref subscript");
    return (*this)[*indices.begin()];
  }

  StridedMemrefIterator<T, 1> begin() { return {*this}; }
  StridedMemrefIterator<T, 1> end() { return {*this, -1}; }

  T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
};

@@ -189,99 +152,6 @@ struct StridedMemRefType<T, 0> {
  T *basePtr;
  T *data;
  int64_t offset;

  template <typename Range>
  T &operator[](Range indices) {
    assert(indices.empty() &&
           "Expect empty indices for 0-rank memref subscript");
    return data[offset];
  }

  StridedMemrefIterator<T, 0> begin() { return {*this}; }
  StridedMemrefIterator<T, 0> end() { return {*this, 1}; }
};

/// Iterate over all elements in a strided memref.
template <typename T, int Rank>
class StridedMemrefIterator {
public:
  StridedMemrefIterator(StridedMemRefType<T, Rank> &descriptor,
                        int64_t offset = 0)
      : offset(offset), descriptor(descriptor) {}
  StridedMemrefIterator<T, Rank> &operator++() {
    int dim = Rank - 1;
    while (dim >= 0 && indices[dim] == (descriptor.sizes[dim] - 1)) {
      offset -= indices[dim] * descriptor.strides[dim];
      indices[dim] = 0;
      --dim;
    }
    if (dim < 0) {
      offset = -1;
      return *this;
    }
    ++indices[dim];
    offset += descriptor.strides[dim];
    return *this;
  }
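  // Illustrative note (not part of the original header): for a contiguous
  // 2x3 memref (sizes = {2, 3}, strides = {3, 1}) the iterator visits indices
  // (0,0), (0,1), (0,2), (1,0), ... and keeps `offset` in sync by adding
  // strides[dim] on each step and subtracting indices[dim] * strides[dim]
  // whenever a dimension wraps; once every dimension has wrapped, offset is
  // set to -1, which is what end() compares against.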

  T &operator*() { return descriptor.data[offset]; }
  T *operator->() { return &descriptor.data[offset]; }

  const std::array<int64_t, Rank> &getIndices() { return indices; }

  bool operator==(const StridedMemrefIterator &other) const {
    return other.offset == offset && &other.descriptor == &descriptor;
  }

  bool operator!=(const StridedMemrefIterator &other) const {
    return !(*this == other);
  }

private:
  /// Offset in the buffer. This can be derived from the indices and the
  /// descriptor.
  int64_t offset = 0;
  /// Array of indices in the multi-dimensional memref.
  std::array<int64_t, Rank> indices = {};
  /// Descriptor for the strided memref.
  StridedMemRefType<T, Rank> &descriptor;
};

/// Iterate over all elements in a 0-ranked strided memref.
template <typename T>
class StridedMemrefIterator<T, 0> {
public:
  StridedMemrefIterator(StridedMemRefType<T, 0> &descriptor, int64_t offset = 0)
      : elt(descriptor.data + offset) {}

  StridedMemrefIterator<T, 0> &operator++() {
    ++elt;
    return *this;
  }

  T &operator*() { return *elt; }
  T *operator->() { return elt; }

  // There are no indices for a 0-ranked memref, but this API is provided for
  // consistency with the general case.
  const std::array<int64_t, 0> &getIndices() {
    // Since this is a 0-array of indices we can keep a single global const
    // copy.
    static const std::array<int64_t, 0> indices = {};
    return indices;
  }

  bool operator==(const StridedMemrefIterator &other) const {
    return other.elt == elt;
  }

  bool operator!=(const StridedMemrefIterator &other) const {
    return !(*this == other);
  }

private:
  /// Pointer to the single element in the zero-ranked memref.
  T *elt;
};

//===----------------------------------------------------------------------===//
MemRefUtils.h (removed by this revert):

@@ -1,214 +0,0 @@
//===- MemRefUtils.h - Memref helpers to invoke MLIR JIT code ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utils for MLIR ABI interfacing with frameworks.
//
// The templated free functions below make it possible to allocate dense
// contiguous buffers with shapes that interoperate properly with the MLIR
// codegen ABI.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"

#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <functional>
#include <initializer_list>
#include <memory>

#ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
#define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_

namespace mlir {
using AllocFunType = llvm::function_ref<void *(size_t)>;

namespace detail {

/// Given a shape with sizes greater than 0 along all dimensions, returns the
/// distance, in number of elements, between a slice in a dimension and the next
/// slice in the same dimension.
/// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
template <size_t N>
inline std::array<int64_t, N> makeStrides(ArrayRef<int64_t> shape) {
  assert(shape.size() == N && "expect shape specification to match rank");
  std::array<int64_t, N> res;
  int64_t running = 1;
  for (int64_t idx = N - 1; idx >= 0; --idx) {
    assert(shape[idx] && "size must be non-negative for all shape dimensions");
    res[idx] = running;
    running *= shape[idx];
  }
  return res;
}
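// Illustrative walk-through (not part of the original header): for
// shape = {3, 4, 5} the loop assigns res[2] = 1, then running = 5,
// res[1] = 5, running = 20, and res[0] = 20, i.e. strides {20, 5, 1},
// matching the example in the comment above.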

/// Build a `StridedMemRefDescriptor<T, N>` that matches the MLIR ABI.
/// This is an implementation detail that is kept in sync with MLIR codegen
/// conventions. Additionally takes a `shapeAlloc` array which
/// is used instead of `shape` to allocate "more aligned" data and compute the
/// corresponding strides.
template <int N, typename T>
typename std::enable_if<(N >= 1), StridedMemRefType<T, N>>::type
makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape,
                            ArrayRef<int64_t> shapeAlloc) {
  assert(shape.size() == N);
  assert(shapeAlloc.size() == N);
  StridedMemRefType<T, N> descriptor;
  descriptor.basePtr = static_cast<T *>(ptr);
  descriptor.data = static_cast<T *>(alignedPtr);
  descriptor.offset = 0;
  std::copy(shape.begin(), shape.end(), descriptor.sizes);
  auto strides = makeStrides<N>(shapeAlloc);
  std::copy(strides.begin(), strides.end(), descriptor.strides);
  return descriptor;
}

/// Build a `StridedMemRefDescriptor<T, 0>` that matches the MLIR ABI.
/// This is an implementation detail that is kept in sync with MLIR codegen
/// conventions. Additionally takes a `shapeAlloc` array which
/// is used instead of `shape` to allocate "more aligned" data and compute the
/// corresponding strides.
template <int N, typename T>
typename std::enable_if<(N == 0), StridedMemRefType<T, 0>>::type
makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape = {},
                            ArrayRef<int64_t> shapeAlloc = {}) {
  assert(shape.size() == N);
  assert(shapeAlloc.size() == N);
  StridedMemRefType<T, 0> descriptor;
  descriptor.basePtr = static_cast<T *>(ptr);
  descriptor.data = static_cast<T *>(alignedPtr);
  descriptor.offset = 0;
  return descriptor;
}

/// Align `nElements` of type T with an optional `alignment`.
/// This replaces a portable `posix_memalign`.
/// `alignment` must be a power of 2 and greater than the size of T. By default
/// the alignment is sizeof(T).
template <typename T>
std::pair<T *, T *>
allocAligned(size_t nElements, AllocFunType allocFun = &::malloc,
             llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>()) {
  assert(sizeof(T) < (1ul << 32) && "Elemental type overflows");
  auto size = nElements * sizeof(T);
  auto desiredAlignment = alignment.getValueOr(nextPowerOf2(sizeof(T)));
  assert((desiredAlignment & (desiredAlignment - 1)) == 0);
  assert(desiredAlignment >= sizeof(T));
  T *data = reinterpret_cast<T *>(allocFun(size + desiredAlignment));
  uintptr_t addr = reinterpret_cast<uintptr_t>(data);
  uintptr_t rem = addr % desiredAlignment;
  T *alignedData = (rem == 0)
                       ? data
                       : reinterpret_cast<T *>(addr + (desiredAlignment - rem));
  assert(reinterpret_cast<uintptr_t>(alignedData) % desiredAlignment == 0);
  return std::make_pair(data, alignedData);
}
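// Illustrative walk-through (not part of the original header): for
// allocAligned<float>(100, ::malloc, 64), size is 400 bytes, the call
// over-allocates 400 + 64 bytes, and the raw address is rounded up to the
// next multiple of 64. The raw pointer ends up as `basePtr` and the aligned
// pointer as `data` in the descriptor builders above.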

} // namespace detail

//===----------------------------------------------------------------------===//
// Public API
//===----------------------------------------------------------------------===//

/// Convenient callback to "visit" a memref element by element.
/// This takes a reference to an individual element as well as the coordinates.
/// It can be used in conjuction with a StridedMemrefIterator.
template <typename T>
using ElementWiseVisitor = llvm::function_ref<void(T &ptr, ArrayRef<int64_t>)>;

/// Owning MemRef type that abstracts over the runtime type for ranked strided
/// memref.
template <typename T, int Rank>
class OwningMemRef {
public:
  using DescriptorType = StridedMemRefType<T, Rank>;
  using FreeFunType = std::function<void(DescriptorType)>;

  /// Allocate a new dense StridedMemrefRef with a given `shape`. An optional
  /// `shapeAlloc` array can be supplied to "pad" every dimension individually.
  /// If an ElementWiseVisitor is provided, it will be used to initialize the
  /// data, else the memory will be zero-initialized. The alloc and free method
  /// used to manage the data allocation can be optionally provided, and default
  /// to malloc/free.
  OwningMemRef(
      ArrayRef<int64_t> shape, ArrayRef<int64_t> shapeAlloc = {},
      ElementWiseVisitor<T> init = {},
      llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>(),
      AllocFunType allocFun = &::malloc,
      std::function<void(StridedMemRefType<T, Rank>)> freeFun =
          [](StridedMemRefType<T, Rank> descriptor) {
            ::free(descriptor.data);
          })
      : freeFunc(freeFun) {
    if (shapeAlloc.empty())
      shapeAlloc = shape;
    assert(shape.size() == Rank);
    assert(shapeAlloc.size() == Rank);
    for (unsigned i = 0; i < Rank; ++i)
      assert(shape[i] <= shapeAlloc[i] &&
             "shapeAlloc must be greater than or equal to shape");
    int64_t nElements = 1;
    for (int64_t s : shapeAlloc)
      nElements *= s;
    T *data, *alignedData;
    std::tie(data, alignedData) =
        detail::allocAligned<T>(nElements, allocFun, alignment);
    descriptor = detail::makeStridedMemRefDescriptor<Rank>(data, alignedData,
                                                           shape, shapeAlloc);
    if (init) {
      for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
                                          end = descriptor.end();
           it != end; ++it)
        init(*it, it.getIndices());
    } else {
      memset(descriptor.data, 0,
             nElements * sizeof(T) +
                 alignment.getValueOr(detail::nextPowerOf2(sizeof(T))));
    }
  }
  /// Take ownership of an existing descriptor with a custom deleter.
  OwningMemRef(DescriptorType descriptor, FreeFunType freeFunc)
      : freeFunc(freeFunc), descriptor(descriptor) {}
  ~OwningMemRef() {
    if (freeFunc)
      freeFunc(descriptor);
  }
  OwningMemRef(const OwningMemRef &) = delete;
  OwningMemRef &operator=(const OwningMemRef &) = delete;
  OwningMemRef &operator=(const OwningMemRef &&other) {
    freeFunc = other.freeFunc;
    descriptor = other.descriptor;
    other.freeFunc = nullptr;
    memset(0, &other.descriptor, sizeof(other.descriptor));
  }
  OwningMemRef(OwningMemRef &&other) { *this = std::move(other); }

  DescriptorType &operator*() { return descriptor; }
  DescriptorType *operator->() { return &descriptor; }
  T &operator[](std::initializer_list<int64_t> indices) {
    return descriptor[std::move(indices)];
  }

private:
  /// Custom deleter used to release the data buffer manager with the descriptor
  /// below.
  FreeFunType freeFunc;
  /// The descriptor is an instance of StridedMemRefType<T, rank>.
  DescriptorType descriptor;
};

} // namespace mlir

#endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
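Taken together, the removed header boils down to a small owning wrapper around the strided-memref descriptor. A minimal usage sketch, distilled from the unit tests removed below (the shape values are illustrative):

    // Allocate a 3x7 float memref, padded to 4x8 in memory, and initialize
    // each element from its coordinates; the buffer is freed on destruction.
    int64_t shape[] = {3, 7};
    int64_t shapeAlloc[] = {4, 8};
    OwningMemRef<float, 2> A(shape, shapeAlloc,
                             [](float &elt, ArrayRef<int64_t> indices) {
                               elt = 7 * indices[0] + indices[1];
                             });
    // Element access goes through the descriptor's strided indexing.
    float x = A[{1, 2}];
    // Pass &*A (the descriptor) to JIT-compiled code, as the tests below do.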
ExecutionEngine unit tests (reverted changes):

@@ -13,7 +13,6 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
@@ -90,163 +89,4 @@ TEST(MLIRExecutionEngine, SubtractFloat) {
  ASSERT_EQ(result, 42.f);
}

TEST(NativeMemRefJit, ZeroRankMemref) {
  OwningMemRef<float, 0> A({});
  A[{}] = 42.;
  ASSERT_EQ(*A->data, 42);
  A[{}] = 0;
  std::string moduleStr = R"mlir(
  func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
    %cst42 = constant 42.0 : f32
    store %cst42, %arg0[] : memref<f32>
    return
  }
  )mlir";
  MLIRContext context;
  registerAllDialects(context.getDialectRegistry());
  auto module = parseSourceString(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  ASSERT_TRUE(!!jitOrError);
  auto jit = std::move(jitOrError.get());

  llvm::Error error = jit->invoke("zero_ranked", &*A);
  ASSERT_TRUE(!error);
  EXPECT_EQ((A[{}]), 42.);
  for (float &elt : *A)
    EXPECT_EQ(&elt, &(A[{}]));
}

TEST(NativeMemRefJit, RankOneMemref) {
  int64_t shape[] = {9};
  OwningMemRef<float, 1> A(shape);
  int count = 1;
  for (float &elt : *A) {
    EXPECT_EQ(&elt, &(A[{count - 1}]));
    elt = count++;
  }

  std::string moduleStr = R"mlir(
  func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
    %cst42 = constant 42.0 : f32
    %cst5 = constant 5 : index
    store %cst42, %arg0[%cst5] : memref<?xf32>
    return
  }
  )mlir";
  MLIRContext context;
  registerAllDialects(context.getDialectRegistry());
  auto module = parseSourceString(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  ASSERT_TRUE(!!jitOrError);
  auto jit = std::move(jitOrError.get());

  llvm::Error error = jit->invoke("one_ranked", &*A);
  ASSERT_TRUE(!error);
  count = 1;
  for (float &elt : *A) {
    if (count == 6)
      EXPECT_EQ(elt, 42.);
    else
      EXPECT_EQ(elt, count);
    count++;
  }
}

TEST(NativeMemRefJit, BasicMemref) {
  constexpr int K = 3;
  constexpr int M = 7;
  // Prepare arguments beforehand.
  auto init = [=](float &elt, ArrayRef<int64_t> indices) {
    assert(indices.size() == 2);
    elt = M * indices[0] + indices[1];
  };
  int64_t shape[] = {K, M};
  int64_t shapeAlloc[] = {K + 1, M + 1};
  OwningMemRef<float, 2> A(shape, shapeAlloc, init);
  ASSERT_EQ(A->sizes[0], K);
  ASSERT_EQ(A->sizes[1], M);
  ASSERT_EQ(A->strides[0], M + 1);
  ASSERT_EQ(A->strides[1], 1);
  for (int i = 0; i < K; ++i)
    for (int j = 0; j < M; ++j)
      EXPECT_EQ((A[{i, j}]), i * M + j);

  std::string moduleStr = R"mlir(
  func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
    %x = constant 2 : index
    %y = constant 1 : index
    %cst42 = constant 42.0 : f32
    store %cst42, %arg0[%y, %x] : memref<?x?xf32>
    store %cst42, %arg1[%x, %y] : memref<?x?xf32>
    return
  }
  )mlir";
  MLIRContext context;
  registerAllDialects(context.getDialectRegistry());
  OwningModuleRef module = parseSourceString(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  ASSERT_TRUE(!!jitOrError);
  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());

  llvm::Error error = jit->invoke("rank2_memref", &*A, &*A);
  ASSERT_TRUE(!error);
  EXPECT_EQ((A[{1, 2}]), 42.);
  EXPECT_EQ((A[{2, 1}]), 42.);
}

// A helper function that will be called from the JIT
static void memref_multiply(::StridedMemRefType<float, 2> *memref,
                            int32_t coefficient) {
  for (float &elt : *memref)
    elt *= coefficient;
}

TEST(NativeMemRefJit, JITCallback) {
  constexpr int K = 2;
  constexpr int M = 2;
  int64_t shape[] = {K, M};
  int64_t shapeAlloc[] = {K + 1, M + 1};
  OwningMemRef<float, 2> A(shape, shapeAlloc);
  int count = 1;
  for (float &elt : *A)
    elt = count++;

  std::string moduleStr = R"mlir(
  func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface }
  func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
    %unranked = memref_cast %arg0: memref<?x?xf32> to memref<*xf32>
    call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
    return
  }
  )mlir";
  MLIRContext context;
  registerAllDialects(context.getDialectRegistry());
  auto module = parseSourceString(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  ASSERT_TRUE(!!jitOrError);
  auto jit = std::move(jitOrError.get());
  // Define any extra symbols so they're available at runtime.
  jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
    llvm::orc::SymbolMap symbolMap;
    symbolMap[interner("_mlir_ciface_callback")] =
        llvm::JITEvaluatedSymbol::fromPointer(memref_multiply);
    return symbolMap;
  });

  int32_t coefficient = 3.;
  llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient);
  ASSERT_TRUE(!error);
  count = 1;
  for (float elt : *A)
    ASSERT_EQ(elt, coefficient * count++);
}

#endif // _WIN32