mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-04-05 23:02:27 +00:00

This is part of an effort to migrate from llvm::Optional to std::optional: https://discourse.llvm.org/t/deprecating-llvm-optional-x-hasvalue-getvalue-getvalueor/63716
1734 lines
63 KiB
C++
1734 lines
63 KiB
C++
//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TODO: Support for big-endian architectures.
|
|
// TODO: Properly preserve use lists of values.
|
|
|
|
#include "mlir/Bytecode/BytecodeReader.h"
|
|
#include "../Encoding.h"
|
|
#include "mlir/AsmParser/AsmParser.h"
|
|
#include "mlir/Bytecode/BytecodeImplementation.h"
|
|
#include "mlir/IR/BuiltinDialect.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/Verifier.h"
|
|
#include "llvm/ADT/MapVector.h"
|
|
#include "llvm/ADT/ScopeExit.h"
|
|
#include "llvm/ADT/SmallString.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/Support/MemoryBufferRef.h"
|
|
#include "llvm/Support/SaveAndRestore.h"
|
|
#include <optional>
|
|
|
|
#define DEBUG_TYPE "mlir-bytecode-reader"
|
|
|
|
using namespace mlir;
|
|
|
|
/// Stringify the given section ID.
|
|
static std::string toString(bytecode::Section::ID sectionID) {
|
|
switch (sectionID) {
|
|
case bytecode::Section::kString:
|
|
return "String (0)";
|
|
case bytecode::Section::kDialect:
|
|
return "Dialect (1)";
|
|
case bytecode::Section::kAttrType:
|
|
return "AttrType (2)";
|
|
case bytecode::Section::kAttrTypeOffset:
|
|
return "AttrTypeOffset (3)";
|
|
case bytecode::Section::kIR:
|
|
return "IR (4)";
|
|
case bytecode::Section::kResource:
|
|
return "Resource (5)";
|
|
case bytecode::Section::kResourceOffset:
|
|
return "ResourceOffset (6)";
|
|
default:
|
|
return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
|
|
}
|
|
}
|
|
|
|
/// Returns true if the given top-level section ID is optional.
|
|
static bool isSectionOptional(bytecode::Section::ID sectionID) {
|
|
switch (sectionID) {
|
|
case bytecode::Section::kString:
|
|
case bytecode::Section::kDialect:
|
|
case bytecode::Section::kAttrType:
|
|
case bytecode::Section::kAttrTypeOffset:
|
|
case bytecode::Section::kIR:
|
|
return false;
|
|
case bytecode::Section::kResource:
|
|
case bytecode::Section::kResourceOffset:
|
|
return true;
|
|
default:
|
|
llvm_unreachable("unknown section ID");
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// EncodingReader
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
class EncodingReader {
|
|
public:
|
|
explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
|
|
: dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {}
|
|
explicit EncodingReader(StringRef contents, Location fileLoc)
|
|
: EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
|
|
contents.size()},
|
|
fileLoc) {}
|
|
|
|
/// Returns true if the entire section has been read.
|
|
bool empty() const { return dataIt == dataEnd; }
|
|
|
|
/// Returns the remaining size of the bytecode.
|
|
size_t size() const { return dataEnd - dataIt; }
|
|
|
|
/// Align the current reader position to the specified alignment.
|
|
LogicalResult alignTo(unsigned alignment) {
|
|
if (!llvm::isPowerOf2_32(alignment))
|
|
return emitError("expected alignment to be a power-of-two");
|
|
|
|
// Shift the reader position to the next alignment boundary.
|
|
while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) {
|
|
uint8_t padding;
|
|
if (failed(parseByte(padding)))
|
|
return failure();
|
|
if (padding != bytecode::kAlignmentByte) {
|
|
return emitError("expected alignment byte (0xCB), but got: '0x" +
|
|
llvm::utohexstr(padding) + "'");
|
|
}
|
|
}
|
|
|
|
// Ensure the data iterator is now aligned. This case is unlikely because we
|
|
// *just* went through the effort to align the data iterator.
|
|
if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))) {
|
|
return emitError("expected data iterator aligned to ", alignment,
|
|
", but got pointer: '0x" +
|
|
llvm::utohexstr((uintptr_t)dataIt) + "'");
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Emit an error using the given arguments.
|
|
template <typename... Args>
|
|
InFlightDiagnostic emitError(Args &&...args) const {
|
|
return ::emitError(fileLoc).append(std::forward<Args>(args)...);
|
|
}
|
|
InFlightDiagnostic emitError() const { return ::emitError(fileLoc); }
|
|
|
|
/// Parse a single byte from the stream.
|
|
template <typename T>
|
|
LogicalResult parseByte(T &value) {
|
|
if (empty())
|
|
return emitError("attempting to parse a byte at the end of the bytecode");
|
|
value = static_cast<T>(*dataIt++);
|
|
return success();
|
|
}
|
|
/// Parse a range of bytes of 'length' into the given result.
|
|
LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
|
|
if (length > size()) {
|
|
return emitError("attempting to parse ", length, " bytes when only ",
|
|
size(), " remain");
|
|
}
|
|
result = {dataIt, length};
|
|
dataIt += length;
|
|
return success();
|
|
}
|
|
/// Parse a range of bytes of 'length' into the given result, which can be
|
|
/// assumed to be large enough to hold `length`.
|
|
LogicalResult parseBytes(size_t length, uint8_t *result) {
|
|
if (length > size()) {
|
|
return emitError("attempting to parse ", length, " bytes when only ",
|
|
size(), " remain");
|
|
}
|
|
memcpy(result, dataIt, length);
|
|
dataIt += length;
|
|
return success();
|
|
}
|
|
|
|
/// Parse an aligned blob of data, where the alignment was encoded alongside
|
|
/// the data.
|
|
LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
|
|
uint64_t &alignment) {
|
|
uint64_t dataSize;
|
|
if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) ||
|
|
failed(alignTo(alignment)))
|
|
return failure();
|
|
return parseBytes(dataSize, data);
|
|
}
|
|
|
|
/// Parse a variable length encoded integer from the byte stream. The first
|
|
/// encoded byte contains a prefix in the low bits indicating the encoded
|
|
/// length of the value. This length prefix is a bit sequence of '0's followed
|
|
/// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
|
|
/// (not including the prefix byte). All remaining bits in the first byte,
|
|
/// along with all of the bits in additional bytes, provide the value of the
|
|
/// integer encoded in little-endian order.
|
|
LogicalResult parseVarInt(uint64_t &result) {
|
|
// Parse the first byte of the encoding, which contains the length prefix.
|
|
if (failed(parseByte(result)))
|
|
return failure();
|
|
|
|
// Handle the overwhelmingly common case where the value is stored in a
|
|
// single byte. In this case, the first bit is the `1` marker bit.
|
|
if (LLVM_LIKELY(result & 1)) {
|
|
result >>= 1;
|
|
return success();
|
|
}
|
|
|
|
// Handle the overwhelming uncommon case where the value required all 8
|
|
// bytes (i.e. a really really big number). In this case, the marker byte is
|
|
// all zeros: `00000000`.
|
|
if (LLVM_UNLIKELY(result == 0))
|
|
return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result));
|
|
return parseMultiByteVarInt(result);
|
|
}
|
|
|
|
/// Parse a signed variable length encoded integer from the byte stream. A
|
|
/// signed varint is encoded as a normal varint with zigzag encoding applied,
|
|
/// i.e. the low bit of the value is used to indicate the sign.
|
|
LogicalResult parseSignedVarInt(uint64_t &result) {
|
|
if (failed(parseVarInt(result)))
|
|
return failure();
|
|
// Essentially (but using unsigned): (x >> 1) ^ -(x & 1)
|
|
result = (result >> 1) ^ (~(result & 1) + 1);
|
|
return success();
|
|
}
|
|
|
|
/// Parse a variable length encoded integer whose low bit is used to encode an
|
|
/// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
|
|
LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
|
|
if (failed(parseVarInt(result)))
|
|
return failure();
|
|
flag = result & 1;
|
|
result >>= 1;
|
|
return success();
|
|
}
|
|
|
|
/// Skip the first `length` bytes within the reader.
|
|
LogicalResult skipBytes(size_t length) {
|
|
if (length > size()) {
|
|
return emitError("attempting to skip ", length, " bytes when only ",
|
|
size(), " remain");
|
|
}
|
|
dataIt += length;
|
|
return success();
|
|
}
|
|
|
|
/// Parse a null-terminated string into `result` (without including the NUL
|
|
/// terminator).
|
|
LogicalResult parseNullTerminatedString(StringRef &result) {
|
|
const char *startIt = (const char *)dataIt;
|
|
const char *nulIt = (const char *)memchr(startIt, 0, size());
|
|
if (!nulIt)
|
|
return emitError(
|
|
"malformed null-terminated string, no null character found");
|
|
|
|
result = StringRef(startIt, nulIt - startIt);
|
|
dataIt = (const uint8_t *)nulIt + 1;
|
|
return success();
|
|
}
|
|
|
|
/// Parse a section header, placing the kind of section in `sectionID` and the
|
|
/// contents of the section in `sectionData`.
|
|
LogicalResult parseSection(bytecode::Section::ID §ionID,
|
|
ArrayRef<uint8_t> §ionData) {
|
|
uint8_t sectionIDAndHasAlignment;
|
|
uint64_t length;
|
|
if (failed(parseByte(sectionIDAndHasAlignment)) ||
|
|
failed(parseVarInt(length)))
|
|
return failure();
|
|
|
|
// Extract the section ID and whether the section is aligned. The high bit
|
|
// of the ID is the alignment flag.
|
|
sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment &
|
|
0b01111111);
|
|
bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
|
|
|
|
// Check that the section is actually valid before trying to process its
|
|
// data.
|
|
if (sectionID >= bytecode::Section::kNumSections)
|
|
return emitError("invalid section ID: ", unsigned(sectionID));
|
|
|
|
// Process the section alignment if present.
|
|
if (hasAlignment) {
|
|
uint64_t alignment;
|
|
if (failed(parseVarInt(alignment)) || failed(alignTo(alignment)))
|
|
return failure();
|
|
}
|
|
|
|
// Parse the actual section data.
|
|
return parseBytes(static_cast<size_t>(length), sectionData);
|
|
}
|
|
|
|
private:
|
|
/// Parse a variable length encoded integer from the byte stream. This method
|
|
/// is a fallback when the number of bytes used to encode the value is greater
|
|
/// than 1, but less than the max (9). The provided `result` value can be
|
|
/// assumed to already contain the first byte of the value.
|
|
/// NOTE: This method is marked noinline to avoid pessimizing the common case
|
|
/// of single byte encoding.
|
|
LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) {
|
|
// Count the number of trailing zeros in the marker byte, this indicates the
|
|
// number of trailing bytes that are part of the value. We use `uint32_t`
|
|
// here because we only care about the first byte, and so that be actually
|
|
// get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
|
|
// implementation).
|
|
uint32_t numBytes =
|
|
llvm::countTrailingZeros<uint32_t>(result, llvm::ZB_Undefined);
|
|
assert(numBytes > 0 && numBytes <= 7 &&
|
|
"unexpected number of trailing zeros in varint encoding");
|
|
|
|
// Parse in the remaining bytes of the value.
|
|
if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1)))
|
|
return failure();
|
|
|
|
// Shift out the low-order bits that were used to mark how the value was
|
|
// encoded.
|
|
result >>= (numBytes + 1);
|
|
return success();
|
|
}
|
|
|
|
/// The current data iterator, and an iterator to the end of the buffer.
|
|
const uint8_t *dataIt, *dataEnd;
|
|
|
|
/// A location for the bytecode used to report errors.
|
|
Location fileLoc;
|
|
};
|
|
} // namespace
|
|
|
|
/// Resolve an index into the given entry list. `entry` may either be a
|
|
/// reference, in which case it is assigned to the corresponding value in
|
|
/// `entries`, or a pointer, in which case it is assigned to the address of the
|
|
/// element in `entries`.
|
|
template <typename RangeT, typename T>
|
|
static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries,
|
|
uint64_t index, T &entry,
|
|
StringRef entryStr) {
|
|
if (index >= entries.size())
|
|
return reader.emitError("invalid ", entryStr, " index: ", index);
|
|
|
|
// If the provided entry is a pointer, resolve to the address of the entry.
|
|
if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
|
|
entry = entries[index];
|
|
else
|
|
entry = &entries[index];
|
|
return success();
|
|
}
|
|
|
|
/// Parse and resolve an index into the given entry list.
|
|
template <typename RangeT, typename T>
|
|
static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
|
|
T &entry, StringRef entryStr) {
|
|
uint64_t entryIdx;
|
|
if (failed(reader.parseVarInt(entryIdx)))
|
|
return failure();
|
|
return resolveEntry(reader, entries, entryIdx, entry, entryStr);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// StringSectionReader
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This class is used to read references to the string section from the
|
|
/// bytecode.
|
|
class StringSectionReader {
|
|
public:
|
|
/// Initialize the string section reader with the given section data.
|
|
LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
|
|
|
|
/// Parse a shared string from the string section. The shared string is
|
|
/// encoded using an index to a corresponding string in the string section.
|
|
LogicalResult parseString(EncodingReader &reader, StringRef &result) {
|
|
return parseEntry(reader, strings, result, "string");
|
|
}
|
|
|
|
private:
|
|
/// The table of strings referenced within the bytecode file.
|
|
SmallVector<StringRef> strings;
|
|
};
|
|
} // namespace
|
|
|
|
/// Build the string table from the raw string section.
///
/// The section starts with a varint string count, followed by one varint size
/// per string, followed by the string data. Sizes include the trailing null
/// character and are stored in reverse order, so the table is filled from the
/// last string to the first while walking backwards through the data.
///
/// Fails with a diagnostic at `fileLoc` if the count or any size is
/// inconsistent with the amount of data in the section.
LogicalResult StringSectionReader::initialize(Location fileLoc,
                                              ArrayRef<uint8_t> sectionData) {
  EncodingReader stringReader(sectionData, fileLoc);

  // Parse the number of strings in the section.
  uint64_t numStrings;
  if (failed(stringReader.parseVarInt(numStrings)))
    return failure();

  // Every string needs at least one byte for its size varint, so a count larger
  // than the remaining data is malformed. Checking here keeps a corrupt count
  // from driving an arbitrarily large allocation below.
  if (numStrings > stringReader.size()) {
    return stringReader.emitError("string count (", numStrings,
                                  ") exceeds the available data size");
  }
  strings.resize(numStrings);

  // Parse each of the strings. The sizes of the strings are encoded in reverse
  // order, so that's the order we populate the table.
  size_t stringDataEndOffset = sectionData.size();
  for (StringRef &string : llvm::reverse(strings)) {
    uint64_t stringSize;
    if (failed(stringReader.parseVarInt(stringSize)))
      return failure();

    // The encoded size includes the null character, so it can never be zero.
    // Without this check `stringSize - 1` below would wrap around.
    if (stringSize == 0) {
      return stringReader.emitError(
          "string size must include the null terminator");
    }
    if (stringDataEndOffset < stringSize) {
      return stringReader.emitError(
          "string size exceeds the available data size");
    }

    // Extract the string from the data, dropping the null character.
    size_t stringOffset = stringDataEndOffset - stringSize;
    string = StringRef(
        reinterpret_cast<const char *>(sectionData.data() + stringOffset),
        stringSize - 1);
    stringDataEndOffset = stringOffset;
  }

  // Check that the only remaining data was for the strings, i.e. the reader
  // should be at the same offset as the first string.
  if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
    return stringReader.emitError("unexpected trailing data between the "
                                  "offsets for strings and their data");
  }
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BytecodeDialect
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This struct represents a dialect entry within the bytecode.
|
|
struct BytecodeDialect {
|
|
/// Load the dialect into the provided context if it hasn't been loaded yet.
|
|
/// Returns failure if the dialect couldn't be loaded *and* the provided
|
|
/// context does not allow unregistered dialects. The provided reader is used
|
|
/// for error emission if necessary.
|
|
LogicalResult load(EncodingReader &reader, MLIRContext *ctx) {
|
|
if (dialect)
|
|
return success();
|
|
Dialect *loadedDialect = ctx->getOrLoadDialect(name);
|
|
if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
|
|
return reader.emitError(
|
|
"dialect '", name,
|
|
"' is unknown. If this is intended, please call "
|
|
"allowUnregisteredDialects() on the MLIRContext, or use "
|
|
"-allow-unregistered-dialect with the MLIR tool used.");
|
|
}
|
|
dialect = loadedDialect;
|
|
|
|
// If the dialect was actually loaded, check to see if it has a bytecode
|
|
// interface.
|
|
if (loadedDialect)
|
|
interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
|
|
return success();
|
|
}
|
|
|
|
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
|
|
/// only be called after `load`.
|
|
Dialect *getLoadedDialect() const {
|
|
assert(dialect &&
|
|
"expected `load` to be invoked before `getLoadedDialect`");
|
|
return *dialect;
|
|
}
|
|
|
|
/// The loaded dialect entry. This field is None if we haven't attempted to
|
|
/// load, nullptr if we failed to load, otherwise the loaded dialect.
|
|
std::optional<Dialect *> dialect;
|
|
|
|
/// The bytecode interface of the dialect, or nullptr if the dialect does not
|
|
/// implement the bytecode interface. This field should only be checked if the
|
|
/// `dialect` field is not std::nullopt.
|
|
const BytecodeDialectInterface *interface = nullptr;
|
|
|
|
/// The name of the dialect.
|
|
StringRef name;
|
|
};
|
|
|
|
/// This struct represents an operation name entry within the bytecode.
|
|
struct BytecodeOperationName {
|
|
BytecodeOperationName(BytecodeDialect *dialect, StringRef name)
|
|
: dialect(dialect), name(name) {}
|
|
|
|
/// The loaded operation name, or std::nullopt if it hasn't been processed
|
|
/// yet.
|
|
std::optional<OperationName> opName;
|
|
|
|
/// The dialect that owns this operation name.
|
|
BytecodeDialect *dialect;
|
|
|
|
/// The name of the operation, without the dialect prefix.
|
|
StringRef name;
|
|
};
|
|
} // namespace
|
|
|
|
/// Parse a single dialect group encoded in the byte stream.
|
|
static LogicalResult parseDialectGrouping(
|
|
EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
|
|
function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
|
|
// Parse the dialect and the number of entries in the group.
|
|
BytecodeDialect *dialect;
|
|
if (failed(parseEntry(reader, dialects, dialect, "dialect")))
|
|
return failure();
|
|
uint64_t numEntries;
|
|
if (failed(reader.parseVarInt(numEntries)))
|
|
return failure();
|
|
|
|
for (uint64_t i = 0; i < numEntries; ++i)
|
|
if (failed(entryCallback(dialect)))
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ResourceSectionReader
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This class is used to read the resource section from the bytecode.
|
|
class ResourceSectionReader {
|
|
public:
|
|
/// Initialize the resource section reader with the given section data.
|
|
LogicalResult initialize(Location fileLoc, const ParserConfig &config,
|
|
MutableArrayRef<BytecodeDialect> dialects,
|
|
StringSectionReader &stringReader,
|
|
ArrayRef<uint8_t> sectionData,
|
|
ArrayRef<uint8_t> offsetSectionData);
|
|
|
|
/// Parse a dialect resource handle from the resource section.
|
|
LogicalResult parseResourceHandle(EncodingReader &reader,
|
|
AsmDialectResourceHandle &result) {
|
|
return parseEntry(reader, dialectResources, result, "resource handle");
|
|
}
|
|
|
|
private:
|
|
/// The table of dialect resources within the bytecode file.
|
|
SmallVector<AsmDialectResourceHandle> dialectResources;
|
|
};
|
|
|
|
class ParsedResourceEntry : public AsmParsedResourceEntry {
|
|
public:
|
|
ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
|
|
EncodingReader &reader, StringSectionReader &stringReader)
|
|
: key(key), kind(kind), reader(reader), stringReader(stringReader) {}
|
|
~ParsedResourceEntry() override = default;
|
|
|
|
StringRef getKey() const final { return key; }
|
|
|
|
InFlightDiagnostic emitError() const final { return reader.emitError(); }
|
|
|
|
AsmResourceEntryKind getKind() const final { return kind; }
|
|
|
|
FailureOr<bool> parseAsBool() const final {
|
|
if (kind != AsmResourceEntryKind::Bool)
|
|
return emitError() << "expected a bool resource entry, but found a "
|
|
<< toString(kind) << " entry instead";
|
|
|
|
bool value;
|
|
if (failed(reader.parseByte(value)))
|
|
return failure();
|
|
return value;
|
|
}
|
|
FailureOr<std::string> parseAsString() const final {
|
|
if (kind != AsmResourceEntryKind::String)
|
|
return emitError() << "expected a string resource entry, but found a "
|
|
<< toString(kind) << " entry instead";
|
|
|
|
StringRef string;
|
|
if (failed(stringReader.parseString(reader, string)))
|
|
return failure();
|
|
return string.str();
|
|
}
|
|
|
|
FailureOr<AsmResourceBlob>
|
|
parseAsBlob(BlobAllocatorFn allocator) const final {
|
|
if (kind != AsmResourceEntryKind::Blob)
|
|
return emitError() << "expected a blob resource entry, but found a "
|
|
<< toString(kind) << " entry instead";
|
|
|
|
ArrayRef<uint8_t> data;
|
|
uint64_t alignment;
|
|
if (failed(reader.parseBlobAndAlignment(data, alignment)))
|
|
return failure();
|
|
|
|
// Allocate memory for the blob using the provided allocator and copy the
|
|
// data into it.
|
|
// FIXME: If the current holder of the bytecode can ensure its lifetime
|
|
// (e.g. when mmap'd), we should not copy the data. We should use the data
|
|
// from the bytecode directly.
|
|
AsmResourceBlob blob = allocator(data.size(), alignment);
|
|
assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
|
|
blob.isMutable() &&
|
|
"blob allocator did not return a properly aligned address");
|
|
memcpy(blob.getMutableData().data(), data.data(), data.size());
|
|
return blob;
|
|
}
|
|
|
|
private:
|
|
StringRef key;
|
|
AsmResourceEntryKind kind;
|
|
EncodingReader &reader;
|
|
StringSectionReader &stringReader;
|
|
};
|
|
} // namespace
|
|
|
|
template <typename T>
|
|
static LogicalResult
|
|
parseResourceGroup(Location fileLoc, bool allowEmpty,
|
|
EncodingReader &offsetReader, EncodingReader &resourceReader,
|
|
StringSectionReader &stringReader, T *handler,
|
|
function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
|
|
uint64_t numResources;
|
|
if (failed(offsetReader.parseVarInt(numResources)))
|
|
return failure();
|
|
|
|
for (uint64_t i = 0; i < numResources; ++i) {
|
|
StringRef key;
|
|
AsmResourceEntryKind kind;
|
|
uint64_t resourceOffset;
|
|
ArrayRef<uint8_t> data;
|
|
if (failed(stringReader.parseString(offsetReader, key)) ||
|
|
failed(offsetReader.parseVarInt(resourceOffset)) ||
|
|
failed(offsetReader.parseByte(kind)) ||
|
|
failed(resourceReader.parseBytes(resourceOffset, data)))
|
|
return failure();
|
|
|
|
// Process the resource key.
|
|
if ((processKeyFn && failed(processKeyFn(key))))
|
|
return failure();
|
|
|
|
// If the resource data is empty and we allow it, don't error out when
|
|
// parsing below, just skip it.
|
|
if (allowEmpty && data.empty())
|
|
continue;
|
|
|
|
// Ignore the entry if we don't have a valid handler.
|
|
if (!handler)
|
|
continue;
|
|
|
|
// Otherwise, parse the resource value.
|
|
EncodingReader entryReader(data, fileLoc);
|
|
ParsedResourceEntry entry(key, kind, entryReader, stringReader);
|
|
if (failed(handler->parseResource(entry)))
|
|
return failure();
|
|
if (!entryReader.empty()) {
|
|
return entryReader.emitError(
|
|
"unexpected trailing bytes in resource entry '", key, "'");
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Process the resource section.
///
/// The offset section begins with the number of external resource groups.
/// Each of those groups is keyed by a string and parsed with the resource
/// parser that `config` registers for that key; groups without a parser are
/// skipped with a warning. The rest of the offset section describes dialect
/// resource groups. Each one names a dialect, which must be loadable and
/// implement `OpAsmDialectInterface`, and every resource key in it is declared
/// with that dialect and recorded in `dialectResources`.
LogicalResult
ResourceSectionReader::initialize(Location fileLoc, const ParserConfig &config,
                                  MutableArrayRef<BytecodeDialect> dialects,
                                  StringSectionReader &stringReader,
                                  ArrayRef<uint8_t> sectionData,
                                  ArrayRef<uint8_t> offsetSectionData) {
  EncodingReader resourceReader(sectionData, fileLoc);
  EncodingReader offsetReader(offsetSectionData, fileLoc);

  // The offset section starts with the count of external resource groups.
  uint64_t externalGroupCount;
  if (failed(offsetReader.parseVarInt(externalGroupCount)))
    return failure();

  // Shorthand for `parseResourceGroup` that fills in the shared arguments.
  auto parseGroup = [&](auto *groupHandler, bool allowEmpty = false,
                        function_ref<LogicalResult(StringRef)> keyFn = {}) {
    return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
                              stringReader, groupHandler, keyFn);
  };

  // External resources come first.
  for (uint64_t groupIdx = 0; groupIdx != externalGroupCount; ++groupIdx) {
    StringRef providerKey;
    if (failed(stringReader.parseString(offsetReader, providerKey)))
      return failure();

    // Look up the parser registered for this group. A missing parser is not an
    // error: the group is still walked so that its data gets skipped.
    // TODO: Should we require handling external resources in some scenarios?
    AsmResourceParser *externalHandler = config.getResourceParser(providerKey);
    if (!externalHandler) {
      emitWarning(fileLoc) << "ignoring unknown external resources for '"
                           << providerKey << "'";
    }

    if (failed(parseGroup(externalHandler)))
      return failure();
  }

  // Whatever remains in the offset section describes dialect resources.
  MLIRContext *ctx = fileLoc->getContext();
  while (!offsetReader.empty()) {
    BytecodeDialect *dialect;
    if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")))
      return failure();
    if (failed(dialect->load(resourceReader, ctx)))
      return failure();

    // Dialect resources require the dialect itself to be available.
    Dialect *loadedDialect = dialect->getLoadedDialect();
    if (!loadedDialect) {
      return resourceReader.emitError()
             << "dialect '" << dialect->name << "' is unknown";
    }

    // ...and that dialect has to implement the assembly interface.
    const auto *dialectHandler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
    if (!dialectHandler) {
      return resourceReader.emitError()
             << "unexpected resources for dialect '" << dialect->name << "'";
    }

    // Declare every resource with the dialect before its value is processed,
    // and remember the handle that the dialect returns.
    auto declareKeyFn = [&](StringRef key) -> LogicalResult {
      FailureOr<AsmDialectResourceHandle> declared =
          dialectHandler->declareResource(key);
      if (succeeded(declared)) {
        dialectResources.push_back(*declared);
        return success();
      }
      return resourceReader.emitError()
             << "unknown 'resource' key '" << key << "' for dialect '"
             << dialect->name << "'";
    };

    // Empty resources are allowed here; they act as plain declarations.
    if (failed(parseGroup(dialectHandler, /*allowEmpty=*/true, declareKeyFn)))
      return failure();
  }

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Attribute/Type Reader
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This class provides support for reading attribute and type entries from the
|
|
/// bytecode. Attribute and Type entries are read lazily on demand, so we use
|
|
/// this reader to manage when to actually parse them from the bytecode.
|
|
class AttrTypeReader {
|
|
/// This class represents a single attribute or type entry.
|
|
template <typename T>
|
|
struct Entry {
|
|
/// The entry, or null if it hasn't been resolved yet.
|
|
T entry = {};
|
|
/// The parent dialect of this entry.
|
|
BytecodeDialect *dialect = nullptr;
|
|
/// A flag indicating if the entry was encoded using a custom encoding,
|
|
/// instead of using the textual assembly format.
|
|
bool hasCustomEncoding = false;
|
|
/// The raw data of this entry in the bytecode.
|
|
ArrayRef<uint8_t> data;
|
|
};
|
|
using AttrEntry = Entry<Attribute>;
|
|
using TypeEntry = Entry<Type>;
|
|
|
|
public:
|
|
AttrTypeReader(StringSectionReader &stringReader,
|
|
ResourceSectionReader &resourceReader, Location fileLoc)
|
|
: stringReader(stringReader), resourceReader(resourceReader),
|
|
fileLoc(fileLoc) {}
|
|
|
|
/// Initialize the attribute and type information within the reader.
|
|
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
|
|
ArrayRef<uint8_t> sectionData,
|
|
ArrayRef<uint8_t> offsetSectionData);
|
|
|
|
/// Resolve the attribute or type at the given index. Returns nullptr on
|
|
/// failure.
|
|
Attribute resolveAttribute(size_t index) {
|
|
return resolveEntry(attributes, index, "Attribute");
|
|
}
|
|
Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
|
|
|
|
/// Parse a reference to an attribute or type using the given reader.
|
|
LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
|
|
uint64_t attrIdx;
|
|
if (failed(reader.parseVarInt(attrIdx)))
|
|
return failure();
|
|
result = resolveAttribute(attrIdx);
|
|
return success(!!result);
|
|
}
|
|
LogicalResult parseType(EncodingReader &reader, Type &result) {
|
|
uint64_t typeIdx;
|
|
if (failed(reader.parseVarInt(typeIdx)))
|
|
return failure();
|
|
result = resolveType(typeIdx);
|
|
return success(!!result);
|
|
}
|
|
|
|
template <typename T>
|
|
LogicalResult parseAttribute(EncodingReader &reader, T &result) {
|
|
Attribute baseResult;
|
|
if (failed(parseAttribute(reader, baseResult)))
|
|
return failure();
|
|
if ((result = baseResult.dyn_cast<T>()))
|
|
return success();
|
|
return reader.emitError("expected attribute of type: ",
|
|
llvm::getTypeName<T>(), ", but got: ", baseResult);
|
|
}
|
|
|
|
private:
|
|
/// Resolve the given entry at `index`.
|
|
template <typename T>
|
|
T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
|
|
StringRef entryType);
|
|
|
|
/// Parse an entry using the given reader that was encoded using the textual
|
|
/// assembly format.
|
|
template <typename T>
|
|
LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
|
|
StringRef entryType);
|
|
|
|
/// Parse an entry using the given reader that was encoded using a custom
|
|
/// bytecode format.
|
|
template <typename T>
|
|
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
|
|
StringRef entryType);
|
|
|
|
/// The string section reader used to resolve string references when parsing
|
|
/// custom encoded attribute/type entries.
|
|
StringSectionReader &stringReader;
|
|
|
|
/// The resource section reader used to resolve resource references when
|
|
/// parsing custom encoded attribute/type entries.
|
|
ResourceSectionReader &resourceReader;
|
|
|
|
/// The set of attribute and type entries.
|
|
SmallVector<AttrEntry> attributes;
|
|
SmallVector<TypeEntry> types;
|
|
|
|
/// A location used for error emission.
|
|
Location fileLoc;
|
|
};
|
|
|
|
class DialectReader : public DialectBytecodeReader {
|
|
public:
|
|
/// Bind this dialect reader to the section readers it forwards to and to the
/// byte stream (`reader`) that it consumes. All four are held by reference.
DialectReader(AttrTypeReader &attrTypeReader,
              StringSectionReader &stringReader,
              ResourceSectionReader &resourceReader, EncodingReader &reader)
    : attrTypeReader(attrTypeReader),
      stringReader(stringReader),
      resourceReader(resourceReader),
      reader(reader) {}
|
|
|
|
/// Emit an error carrying `msg` at the location of the underlying reader.
InFlightDiagnostic emitError(const Twine &msg) override {
  InFlightDiagnostic diagnostic = reader.emitError(msg);
  return diagnostic;
}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// IR
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
/// Read an attribute reference from the stream and resolve it into `result`.
LogicalResult readAttribute(Attribute &result) override {
  const LogicalResult status = attrTypeReader.parseAttribute(reader, result);
  return status;
}
|
|
|
|
/// Read a type reference from the stream and resolve it into `result`.
LogicalResult readType(Type &result) override {
  const LogicalResult status = attrTypeReader.parseType(reader, result);
  return status;
}
|
|
|
|
/// Read a dialect resource handle reference from the stream. Returns failure
/// if the reference cannot be parsed or resolved.
FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
  AsmDialectResourceHandle parsedHandle;
  const LogicalResult status =
      resourceReader.parseResourceHandle(reader, parsedHandle);
  if (succeeded(status))
    return parsedHandle;
  return failure();
}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Primitives
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
LogicalResult readVarInt(uint64_t &result) override {
|
|
return reader.parseVarInt(result);
|
|
}
|
|
|
|
LogicalResult readSignedVarInt(int64_t &result) override {
|
|
uint64_t unsignedResult;
|
|
if (failed(reader.parseSignedVarInt(unsignedResult)))
|
|
return failure();
|
|
result = static_cast<int64_t>(unsignedResult);
|
|
return success();
|
|
}
|
|
|
|
FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override {
|
|
// Small values are encoded using a single byte.
|
|
if (bitWidth <= 8) {
|
|
uint8_t value;
|
|
if (failed(reader.parseByte(value)))
|
|
return failure();
|
|
return APInt(bitWidth, value);
|
|
}
|
|
|
|
// Large values up to 64 bits are encoded using a single varint.
|
|
if (bitWidth <= 64) {
|
|
uint64_t value;
|
|
if (failed(reader.parseSignedVarInt(value)))
|
|
return failure();
|
|
return APInt(bitWidth, value);
|
|
}
|
|
|
|
// Otherwise, for really big values we encode the array of active words in
|
|
// the value.
|
|
uint64_t numActiveWords;
|
|
if (failed(reader.parseVarInt(numActiveWords)))
|
|
return failure();
|
|
SmallVector<uint64_t, 4> words(numActiveWords);
|
|
for (uint64_t i = 0; i < numActiveWords; ++i)
|
|
if (failed(reader.parseSignedVarInt(words[i])))
|
|
return failure();
|
|
return APInt(bitWidth, words);
|
|
}
|
|
|
|
FailureOr<APFloat>
|
|
readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override {
|
|
FailureOr<APInt> intVal =
|
|
readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics));
|
|
if (failed(intVal))
|
|
return failure();
|
|
return APFloat(semantics, *intVal);
|
|
}
|
|
|
|
LogicalResult readString(StringRef &result) override {
|
|
return stringReader.parseString(reader, result);
|
|
}
|
|
|
|
LogicalResult readBlob(ArrayRef<char> &result) override {
|
|
uint64_t dataSize;
|
|
ArrayRef<uint8_t> data;
|
|
if (failed(reader.parseVarInt(dataSize)) ||
|
|
failed(reader.parseBytes(dataSize, data)))
|
|
return failure();
|
|
result = llvm::makeArrayRef(reinterpret_cast<const char *>(data.data()),
|
|
data.size());
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
AttrTypeReader &attrTypeReader;
|
|
StringSectionReader &stringReader;
|
|
ResourceSectionReader &resourceReader;
|
|
EncodingReader &reader;
|
|
};
|
|
} // namespace
|
|
|
|
/// Initialize the attribute and type tables from the offset section.
///
/// `offsetSectionData` starts with the number of attributes and the number of
/// types, followed by dialect groupings that carry the size and encoding flag
/// of every entry. Entries are laid out back to back in `sectionData`, so each
/// entry's data is the slice starting at the running offset. No attribute or
/// type is parsed here; only the raw data, owning dialect, and encoding flag of
/// each entry are recorded.
///
/// Returns failure (after emitting an error) if the offset section is
/// malformed, describes more entries than it declared, points past the end of
/// `sectionData`, or contains trailing data.
LogicalResult
AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
                           ArrayRef<uint8_t> sectionData,
                           ArrayRef<uint8_t> offsetSectionData) {
  EncodingReader offsetReader(offsetSectionData, fileLoc);

  // Parse the number of attribute and type entries.
  uint64_t numAttributes, numTypes;
  if (failed(offsetReader.parseVarInt(numAttributes)) ||
      failed(offsetReader.parseVarInt(numTypes)))
    return failure();
  attributes.resize(numAttributes);
  types.resize(numTypes);

  // A functor used to accumulate the offsets for the entries in the given
  // range. `currentOffset` is shared between the attribute and type ranges, as
  // the types follow the attributes in `sectionData`.
  uint64_t currentOffset = 0;
  auto parseEntries = [&](auto &&range) {
    size_t currentIndex = 0, endIndex = range.size();

    // Parse an individual entry.
    auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
      // The outer loop only re-checks the index between groupings, so guard
      // against being invoked more times than there are entries rather than
      // indexing past the end of `range`.
      if (currentIndex >= endIndex) {
        return offsetReader.emitError(
            "Attribute or Type offset section describes more entries than "
            "expected");
      }
      auto &entry = range[currentIndex++];

      uint64_t entrySize;
      if (failed(offsetReader.parseVarIntWithFlag(entrySize,
                                                  entry.hasCustomEncoding)))
        return failure();

      // Verify that the offset is actually valid. `currentOffset` never
      // exceeds the section size, so the subtraction cannot wrap, whereas
      // `currentOffset + entrySize` could.
      if (entrySize > sectionData.size() - currentOffset) {
        return offsetReader.emitError(
            "Attribute or Type entry offset points past the end of section");
      }

      entry.data = sectionData.slice(currentOffset, entrySize);
      entry.dialect = dialect;
      currentOffset += entrySize;
      return success();
    };
    while (currentIndex != endIndex)
      if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn)))
        return failure();
    return success();
  };

  // Process each of the attributes, and then the types.
  if (failed(parseEntries(attributes)) || failed(parseEntries(types)))
    return failure();

  // Ensure that we read everything from the section.
  if (!offsetReader.empty()) {
    return offsetReader.emitError(
        "unexpected trailing data in the Attribute/Type offset section");
  }
  return success();
}
|
|
|
|
template <typename T>
|
|
T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
|
|
StringRef entryType) {
|
|
if (index >= entries.size()) {
|
|
emitError(fileLoc) << "invalid " << entryType << " index: " << index;
|
|
return {};
|
|
}
|
|
|
|
// If the entry has already been resolved, there is nothing left to do.
|
|
Entry<T> &entry = entries[index];
|
|
if (entry.entry)
|
|
return entry.entry;
|
|
|
|
// Parse the entry.
|
|
EncodingReader reader(entry.data, fileLoc);
|
|
|
|
// Parse based on how the entry was encoded.
|
|
if (entry.hasCustomEncoding) {
|
|
if (failed(parseCustomEntry(entry, reader, entryType)))
|
|
return T();
|
|
} else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
|
|
return T();
|
|
}
|
|
|
|
if (!reader.empty()) {
|
|
reader.emitError("unexpected trailing bytes after " + entryType + " entry");
|
|
return T();
|
|
}
|
|
return entry.entry;
|
|
}
|
|
|
|
template <typename T>
|
|
LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
|
|
StringRef entryType) {
|
|
StringRef asmStr;
|
|
if (failed(reader.parseNullTerminatedString(asmStr)))
|
|
return failure();
|
|
|
|
// Invoke the MLIR assembly parser to parse the entry text.
|
|
size_t numRead = 0;
|
|
MLIRContext *context = fileLoc->getContext();
|
|
if constexpr (std::is_same_v<T, Type>)
|
|
result = ::parseType(asmStr, context, numRead);
|
|
else
|
|
result = ::parseAttribute(asmStr, context, numRead);
|
|
if (!result)
|
|
return failure();
|
|
|
|
// Ensure there weren't dangling characters after the entry.
|
|
if (numRead != asmStr.size()) {
|
|
return reader.emitError("trailing characters found after ", entryType,
|
|
" assembly format: ", asmStr.drop_front(numRead));
|
|
}
|
|
return success();
|
|
}
|
|
|
|
template <typename T>
|
|
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
|
|
EncodingReader &reader,
|
|
StringRef entryType) {
|
|
if (failed(entry.dialect->load(reader, fileLoc.getContext())))
|
|
return failure();
|
|
|
|
// Ensure that the dialect implements the bytecode interface.
|
|
if (!entry.dialect->interface) {
|
|
return reader.emitError("dialect '", entry.dialect->name,
|
|
"' does not implement the bytecode interface");
|
|
}
|
|
|
|
// Ask the dialect to parse the entry.
|
|
DialectReader dialectReader(*this, stringReader, resourceReader, reader);
|
|
if constexpr (std::is_same_v<T, Type>)
|
|
entry.entry = entry.dialect->interface->readType(dialectReader);
|
|
else
|
|
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
|
|
return success(!!entry.entry);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Bytecode Reader
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This class is used to read a bytecode buffer and translate it into MLIR.
|
|
class BytecodeReader {
|
|
public:
|
|
BytecodeReader(Location fileLoc, const ParserConfig &config)
|
|
: config(config), fileLoc(fileLoc),
|
|
attrTypeReader(stringReader, resourceReader, fileLoc),
|
|
// Use the builtin unrealized conversion cast operation to represent
|
|
// forward references to values that aren't yet defined.
|
|
forwardRefOpState(UnknownLoc::get(config.getContext()),
|
|
"builtin.unrealized_conversion_cast", ValueRange(),
|
|
NoneType::get(config.getContext())) {}
|
|
|
|
/// Read the bytecode defined within `buffer` into the given block.
|
|
LogicalResult read(llvm::MemoryBufferRef buffer, Block *block);
|
|
|
|
private:
|
|
/// Return the context for this config.
|
|
MLIRContext *getContext() const { return config.getContext(); }
|
|
|
|
/// Parse the bytecode version.
|
|
LogicalResult parseVersion(EncodingReader &reader);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Dialect Section
|
|
|
|
LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
|
|
|
|
/// Parse an operation name reference using the given reader.
|
|
FailureOr<OperationName> parseOpName(EncodingReader &reader);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Attribute/Type Section
|
|
|
|
/// Parse an attribute or type using the given reader.
|
|
template <typename T>
|
|
LogicalResult parseAttribute(EncodingReader &reader, T &result) {
|
|
return attrTypeReader.parseAttribute(reader, result);
|
|
}
|
|
LogicalResult parseType(EncodingReader &reader, Type &result) {
|
|
return attrTypeReader.parseType(reader, result);
|
|
}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Resource Section
|
|
|
|
LogicalResult
|
|
parseResourceSection(Optional<ArrayRef<uint8_t>> resourceData,
|
|
Optional<ArrayRef<uint8_t>> resourceOffsetData);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// IR Section
|
|
|
|
/// This struct represents the current read state of a range of regions. This
|
|
/// struct is used to enable iterative parsing of regions.
|
|
struct RegionReadState {
|
|
RegionReadState(Operation *op, bool isIsolatedFromAbove)
|
|
: RegionReadState(op->getRegions(), isIsolatedFromAbove) {}
|
|
RegionReadState(MutableArrayRef<Region> regions, bool isIsolatedFromAbove)
|
|
: curRegion(regions.begin()), endRegion(regions.end()),
|
|
isIsolatedFromAbove(isIsolatedFromAbove) {}
|
|
|
|
/// The current regions being read.
|
|
MutableArrayRef<Region>::iterator curRegion, endRegion;
|
|
|
|
/// The number of values defined immediately within this region.
|
|
unsigned numValues = 0;
|
|
|
|
/// The current blocks of the region being read.
|
|
SmallVector<Block *> curBlocks;
|
|
Region::iterator curBlock = {};
|
|
|
|
/// The number of operations remaining to be read from the current block
|
|
/// being read.
|
|
uint64_t numOpsRemaining = 0;
|
|
|
|
/// A flag indicating if the regions being read are isolated from above.
|
|
bool isIsolatedFromAbove = false;
|
|
};
|
|
|
|
LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
|
|
LogicalResult parseRegions(EncodingReader &reader,
|
|
std::vector<RegionReadState> ®ionStack,
|
|
RegionReadState &readState);
|
|
FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
|
|
RegionReadState &readState,
|
|
bool &isIsolatedFromAbove);
|
|
|
|
LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState);
|
|
LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState);
|
|
LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Value Processing
|
|
|
|
/// Parse an operand reference using the given reader. Returns nullptr in the
|
|
/// case of failure.
|
|
Value parseOperand(EncodingReader &reader);
|
|
|
|
/// Sequentially define the given value range.
|
|
LogicalResult defineValues(EncodingReader &reader, ValueRange values);
|
|
|
|
/// Create a value to use for a forward reference.
|
|
Value createForwardRef();
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Fields
|
|
|
|
/// This class represents a single value scope, in which a value scope is
|
|
/// delimited by isolated from above regions.
|
|
struct ValueScope {
|
|
/// Push a new region state onto this scope, reserving enough values for
|
|
/// those defined within the current region of the provided state.
|
|
void push(RegionReadState &readState) {
|
|
nextValueIDs.push_back(values.size());
|
|
values.resize(values.size() + readState.numValues);
|
|
}
|
|
|
|
/// Pop the values defined for the current region within the provided region
|
|
/// state.
|
|
void pop(RegionReadState &readState) {
|
|
values.resize(values.size() - readState.numValues);
|
|
nextValueIDs.pop_back();
|
|
}
|
|
|
|
/// The set of values defined in this scope.
|
|
std::vector<Value> values;
|
|
|
|
/// The ID for the next defined value for each region current being
|
|
/// processed in this scope.
|
|
SmallVector<unsigned, 4> nextValueIDs;
|
|
};
|
|
|
|
/// The configuration of the parser.
|
|
const ParserConfig &config;
|
|
|
|
/// A location to use when emitting errors.
|
|
Location fileLoc;
|
|
|
|
/// The reader used to process attribute and types within the bytecode.
|
|
AttrTypeReader attrTypeReader;
|
|
|
|
/// The version of the bytecode being read.
|
|
uint64_t version = 0;
|
|
|
|
/// The producer of the bytecode being read.
|
|
StringRef producer;
|
|
|
|
/// The table of IR units referenced within the bytecode file.
|
|
SmallVector<BytecodeDialect> dialects;
|
|
SmallVector<BytecodeOperationName> opNames;
|
|
|
|
/// The reader used to process resources within the bytecode.
|
|
ResourceSectionReader resourceReader;
|
|
|
|
/// The table of strings referenced within the bytecode file.
|
|
StringSectionReader stringReader;
|
|
|
|
/// The current set of available IR value scopes.
|
|
std::vector<ValueScope> valueScopes;
|
|
/// A block containing the set of operations defined to create forward
|
|
/// references.
|
|
Block forwardRefOps;
|
|
/// A block containing previously created, and no longer used, forward
|
|
/// reference operations.
|
|
Block openForwardRefOps;
|
|
/// An operation state used when instantiating forward references.
|
|
OperationState forwardRefOpState;
|
|
};
|
|
} // namespace
|
|
|
|
/// Read the bytecode held in `buffer`, adding the parsed operations to
/// `block`. The header (magic number, version, producer) is consumed first,
/// then the raw data of every top-level section is collected, and finally the
/// sections are processed in dependency order: strings, dialects, resources,
/// attributes/types, and the IR itself.
LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
  EncodingReader reader(buffer.getBuffer(), fileLoc);

  // Step over the magic number; it is expected to have been checked already.
  const size_t magicSize = StringRef("ML\xefR").size();
  if (failed(reader.skipBytes(magicSize)))
    return failure();

  // The rest of the header is the version followed by the producer string.
  if (failed(parseVersion(reader)))
    return failure();
  if (failed(reader.parseNullTerminatedString(producer)))
    return failure();

  // From here on, attach a note naming the bytecode version and producer to
  // every diagnostic that is emitted.
  ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
    diag.attachNote() << "in bytecode version " << version
                      << " produced by: " << producer;
    return failure();
  });

  // Gather the raw data of each top-level section. Every section may appear
  // at most once.
  Optional<ArrayRef<uint8_t>> sectionDatas[bytecode::Section::kNumSections];
  while (!reader.empty()) {
    bytecode::Section::ID currentID;
    ArrayRef<uint8_t> currentData;
    if (failed(reader.parseSection(currentID, currentData)))
      return failure();

    Optional<ArrayRef<uint8_t>> &slot = sectionDatas[currentID];
    if (slot) {
      return reader.emitError("duplicate top-level section: ",
                              toString(currentID));
    }
    slot = currentData;
  }

  // Every section that is not optional must have been provided.
  for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
    bytecode::Section::ID requiredID = static_cast<bytecode::Section::ID>(i);
    if (sectionDatas[i] || isSectionOptional(requiredID))
      continue;
    return reader.emitError("missing data for top-level section: ",
                            toString(requiredID));
  }

  // Strings come first, as the remaining sections refer to them.
  if (failed(stringReader.initialize(
          fileLoc, *sectionDatas[bytecode::Section::kString])))
    return failure();

  // Then the dialects and operation names.
  if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
    return failure();

  // Then the resources, which may be absent.
  if (failed(parseResourceSection(
          sectionDatas[bytecode::Section::kResource],
          sectionDatas[bytecode::Section::kResourceOffset])))
    return failure();

  // Then the attribute and type tables.
  if (failed(attrTypeReader.initialize(
          dialects, *sectionDatas[bytecode::Section::kAttrType],
          *sectionDatas[bytecode::Section::kAttrTypeOffset])))
    return failure();

  // With all of the tables available, parse the IR itself.
  return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
}
|
|
|
|
/// Parse the bytecode version into `version` and check it against the version
/// supported by this reader. Only an exact match is accepted; older and newer
/// versions are each rejected with a dedicated error.
LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
  if (failed(reader.parseVarInt(version)))
    return failure();

  uint64_t supportedVersion = bytecode::kVersion;
  if (version == supportedVersion)
    return success();

  if (version < supportedVersion) {
    return reader.emitError("bytecode version ", version,
                            " is older than the current version of ",
                            supportedVersion, ", and upgrade is not supported");
  }
  return reader.emitError("bytecode version ", version,
                          " is newer than the current version ",
                          supportedVersion);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dialect Section
|
|
|
|
/// Parse the dialect section: the number of dialects, the name of each
/// dialect, and then the operation names, which are grouped by dialect. The
/// results are stored in `dialects` and `opNames`.
LogicalResult
BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
  EncodingReader sectionReader(sectionData, fileLoc);

  // The section starts with the number of dialects.
  uint64_t dialectCount;
  if (failed(sectionReader.parseVarInt(dialectCount)))
    return failure();
  dialects.resize(dialectCount);

  // Followed by the name of each dialect, as a string reference.
  for (BytecodeDialect &dialect : dialects)
    if (failed(stringReader.parseString(sectionReader, dialect.name)))
      return failure();

  // The remainder of the section holds the operation names, grouped by the
  // dialect they belong to.
  auto readOpName = [&](BytecodeDialect *dialect) {
    StringRef name;
    if (failed(stringReader.parseString(sectionReader, name)))
      return failure();
    opNames.emplace_back(dialect, name);
    return success();
  };
  while (!sectionReader.empty()) {
    if (failed(parseDialectGrouping(sectionReader, dialects, readOpName)))
      return failure();
  }
  return success();
}
|
|
|
|
/// Parse a reference to an entry of the operation name table and return the
/// corresponding OperationName. The name is built on first use, which loads
/// the owning dialect, and is cached in the table entry afterwards.
FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
  BytecodeOperationName *nameEntry = nullptr;
  if (failed(parseEntry(reader, opNames, nameEntry, "operation name")))
    return failure();

  // Reuse the name if an earlier reference already resolved it.
  if (nameEntry->opName)
    return *nameEntry->opName;

  // Otherwise load the dialect and build "<dialect>.<name>".
  if (failed(nameEntry->dialect->load(reader, getContext())))
    return failure();
  nameEntry->opName.emplace(
      (nameEntry->dialect->name + "." + nameEntry->name).str(), getContext());
  return *nameEntry->opName;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Resource Section
|
|
|
|
/// Parse the resource section and its offset section. The two sections must
/// be either both present or both absent; when both are absent there is
/// nothing to parse and the call succeeds.
LogicalResult BytecodeReader::parseResourceSection(
    Optional<ArrayRef<uint8_t>> resourceData,
    Optional<ArrayRef<uint8_t>> resourceOffsetData) {
  const bool hasResources = resourceData.has_value();
  const bool hasOffsets = resourceOffsetData.has_value();

  // Reject a file that provides only one of the two sections.
  if (hasOffsets && !hasResources)
    return emitError(fileLoc, "unexpected resource offset section when "
                              "resource section is not present");
  if (hasResources && !hasOffsets)
    return emitError(
        fileLoc,
        "expected resource offset section when resource section is present");

  // Without resource sections there is nothing to initialize.
  if (!hasResources)
    return success();

  // Both sections are present: hand them to the resource reader.
  return resourceReader.initialize(fileLoc, config, dialects, stringReader,
                                   *resourceData, *resourceOffsetData);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// IR Section
|
|
|
|
/// Parse the IR section. The operations are parsed into the body of a
/// temporary module, using an explicit stack of region read states rather
/// than recursion. After parsing, any remaining forward reference is reported
/// as an error, the module is verified if the configuration asks for it, and
/// the parsed operations are moved to the end of `block`.
LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
                                             Block *block) {
  EncodingReader reader(sectionData, fileLoc);

  // The regions that are currently in the process of being read.
  std::vector<RegionReadState> regionStack;

  // Seed the stack with the body of a temporary module, which stands in for
  // the top-level block.
  OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
  regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true);
  {
    // This reference is only used before anything else is pushed onto the
    // stack.
    RegionReadState &rootState = regionStack.back();
    rootState.curBlocks.push_back(moduleOp->getBody());
    rootState.curBlock = rootState.curRegion->begin();
    if (failed(parseBlock(reader, rootState)))
      return failure();
    valueScopes.emplace_back();
    valueScopes.back().push(rootState);
  }

  // Keep processing the region on top of the stack until none are left.
  while (!regionStack.empty()) {
    if (failed(parseRegions(reader, regionStack, regionStack.back())))
      return failure();
  }

  // Every forward reference must have been replaced by a real definition.
  if (!forwardRefOps.empty()) {
    return reader.emitError(
        "not all forward unresolved forward operand references");
  }

  // Verify the parsed operations when requested by the configuration.
  if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
    return failure();

  // Move the parsed operations to the end of the destination block.
  auto &sourceOps = moduleOp->getBody()->getOperations();
  auto &targetOps = block->getOperations();
  targetOps.splice(targetOps.end(), sourceOps, sourceOps.begin(),
                   sourceOps.end());
  return success();
}
|
|
|
|
/// Continue parsing the regions described by `readState`.
///
/// Parsing is resumable: whenever an operation with regions is read, a new
/// state for it is pushed onto `regionStack` and this function returns so that
/// the caller can process the new top of the stack first. Once every region of
/// `readState` has been read, the top of `regionStack` is popped, along with
/// the current value scope if the regions were isolated from above.
LogicalResult
BytecodeReader::parseRegions(EncodingReader &reader,
                             std::vector<RegionReadState> &regionStack,
                             RegionReadState &readState) {
  // Read the regions of this operation.
  for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
    // If the current block hasn't been setup yet, parse the header for this
    // region.
    if (readState.curBlock == Region::iterator()) {
      if (failed(parseRegion(reader, readState)))
        return failure();

      // If the region is empty, there is nothing more to do.
      if (readState.curRegion->empty())
        continue;
    }

    // Parse the blocks within the region.
    do {
      while (readState.numOpsRemaining--) {
        // Read in the next operation. We don't read its regions directly, we
        // handle those afterwards as necessary.
        bool isIsolatedFromAbove = false;
        FailureOr<Operation *> op =
            parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
        if (failed(op))
          return failure();

        // If the op has regions, add it to the stack for processing.
        // `readState` is not used again before returning: growing the stack
        // may reallocate the storage it refers to.
        if ((*op)->getNumRegions()) {
          regionStack.emplace_back(*op, isIsolatedFromAbove);

          // If the op is isolated from above, push a new value scope.
          if (isIsolatedFromAbove)
            valueScopes.emplace_back();
          return success();
        }
      }

      // Move to the next block of the region.
      if (++readState.curBlock == readState.curRegion->end())
        break;
      if (failed(parseBlock(reader, readState)))
        return failure();
    } while (true);

    // Reset the current block and any values reserved for this region.
    readState.curBlock = {};
    valueScopes.back().pop(readState);
  }

  // When the regions have been fully parsed, pop them off of the read stack. If
  // the regions were isolated from above, we also pop the last value scope.
  if (readState.isIsolatedFromAbove)
    valueScopes.pop_back();
  regionStack.pop_back();
  return success();
}
|
|
|
|
/// Parse a single operation and append it to the current block of
/// `readState`.
///
/// The encoding is: the operation name, a mask byte describing which of the
/// following components are present, the location, and then the optional
/// attribute dictionary, result types, operands, successors, and region count.
/// Regions are created empty; their contents are not parsed here.
/// `isIsolatedFromAbove` is only written when the operation has regions, in
/// which case it is set from the flag stored alongside the region count.
///
/// Returns the created operation, or failure if any component fails to parse.
FailureOr<Operation *>
BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
                                      RegionReadState &readState,
                                      bool &isIsolatedFromAbove) {
  // Parse the name of the operation.
  FailureOr<OperationName> opName = parseOpName(reader);
  if (failed(opName))
    return failure();

  // Parse the operation mask, which indicates which components of the operation
  // are present.
  uint8_t opMask;
  if (failed(reader.parseByte(opMask)))
    return failure();

  // Parse the location.
  LocationAttr opLoc;
  if (failed(parseAttribute(reader, opLoc)))
    return failure();

  // With the location and name resolved, we can start building the operation
  // state.
  OperationState opState(opLoc, *opName);

  // Parse the attributes of the operation.
  if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
    DictionaryAttr dictAttr;
    if (failed(parseAttribute(reader, dictAttr)))
      return failure();
    opState.attributes = dictAttr;
  }

  // The element counts below are 64-bit values read from the file, so the
  // loops iterate with a 64-bit index. Narrowing the count to `int` could
  // parse fewer elements than were allocated and leave null entries behind.

  // Parse the results of the operation.
  if (opMask & bytecode::OpEncodingMask::kHasResults) {
    uint64_t numResults;
    if (failed(reader.parseVarInt(numResults)))
      return failure();
    opState.types.resize(numResults);
    for (uint64_t i = 0; i < numResults; ++i)
      if (failed(parseType(reader, opState.types[i])))
        return failure();
  }

  // Parse the operands of the operation.
  if (opMask & bytecode::OpEncodingMask::kHasOperands) {
    uint64_t numOperands;
    if (failed(reader.parseVarInt(numOperands)))
      return failure();
    opState.operands.resize(numOperands);
    for (uint64_t i = 0; i < numOperands; ++i)
      if (!(opState.operands[i] = parseOperand(reader)))
        return failure();
  }

  // Parse the successors of the operation. Successors are references to the
  // blocks of the region currently being read.
  if (opMask & bytecode::OpEncodingMask::kHasSuccessors) {
    uint64_t numSuccs;
    if (failed(reader.parseVarInt(numSuccs)))
      return failure();
    opState.successors.resize(numSuccs);
    for (uint64_t i = 0; i < numSuccs; ++i) {
      if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i],
                            "successor")))
        return failure();
    }
  }

  // Parse the regions of the operation. Only the count is read here; the
  // regions are created empty.
  if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
    uint64_t numRegions;
    if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove)))
      return failure();

    opState.regions.reserve(numRegions);
    for (uint64_t i = 0; i < numRegions; ++i)
      opState.regions.push_back(std::make_unique<Region>());
  }

  // Create the operation at the back of the current block.
  Operation *op = Operation::create(opState);
  readState.curBlock->push_back(op);

  // If the operation had results, update the value references.
  if (op->getNumResults() && failed(defineValues(reader, op->getResults())))
    return failure();

  return op;
}
|
|
|
|
/// Parse the header of the current region of `readState`: the number of
/// blocks and, for a non-empty region, the number of values defined within
/// it. All blocks are created up front, space for the region's values is
/// reserved in the current value scope, and the header of the entry block is
/// parsed. For an empty region only the block count is consumed.
LogicalResult BytecodeReader::parseRegion(EncodingReader &reader,
                                          RegionReadState &readState) {
  uint64_t blockCount;
  if (failed(reader.parseVarInt(blockCount)))
    return failure();

  // An empty region has no further data.
  if (blockCount == 0)
    return success();

  // The block count is followed by the number of values the region defines.
  uint64_t valueCount;
  if (failed(reader.parseVarInt(valueCount)))
    return failure();
  readState.numValues = valueCount;

  // Materialize every block before reading any operation, so that the blocks
  // already exist when operations are created.
  auto &regionBlocks = readState.curBlocks;
  regionBlocks.clear();
  regionBlocks.reserve(blockCount);
  for (uint64_t remaining = blockCount; remaining != 0; --remaining) {
    Block *newBlock = new Block();
    regionBlocks.push_back(newBlock);
    readState.curRegion->push_back(newBlock);
  }

  // Reserve the region's values in the current value scope.
  valueScopes.back().push(readState);

  // Start reading at the entry block.
  readState.curBlock = readState.curRegion->begin();
  return parseBlock(reader, readState);
}
|
|
|
|
/// Parse the header of the current block of `readState`: the number of
/// operations it holds, stored into `readState.numOpsRemaining`, together
/// with a flag telling whether block arguments follow. The arguments are
/// parsed here when present; the operations themselves are not.
LogicalResult BytecodeReader::parseBlock(EncodingReader &reader,
                                         RegionReadState &readState) {
  bool hasBlockArgs;
  if (failed(
          reader.parseVarIntWithFlag(readState.numOpsRemaining, hasBlockArgs)))
    return failure();

  // Without arguments the header is complete.
  if (!hasBlockArgs)
    return success();

  // Otherwise the arguments of the block come next.
  return parseBlockArguments(reader, &*readState.curBlock);
}
|
|
|
|
/// Parse the arguments of `block`: the number of arguments, followed by a
/// type and a location for each of them. The arguments are added to the block
/// and then registered as newly defined values.
LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
                                                  Block *block) {
  // The argument list starts with its length.
  uint64_t numArgs;
  if (failed(reader.parseVarInt(numArgs)))
    return failure();

  SmallVector<Type> argTypes;
  SmallVector<Location> argLocs;
  argTypes.reserve(numArgs);
  argLocs.reserve(numArgs);

  // Each argument is encoded as its type followed by its location.
  for (uint64_t argIdx = 0; argIdx != numArgs; ++argIdx) {
    Type currentType;
    if (failed(parseType(reader, currentType)))
      return failure();

    LocationAttr currentLoc;
    if (failed(parseAttribute(reader, currentLoc)))
      return failure();

    argTypes.push_back(currentType);
    argLocs.push_back(currentLoc);
  }

  block->addArguments(argTypes, argLocs);
  return defineValues(reader, block->getArguments());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Value Processing
|
|
|
|
/// Parse a reference to a value of the current value scope. If the referenced
/// value has not been defined yet, a forward reference is created and stored
/// in its place. Returns a null value if the reference cannot be parsed.
Value BytecodeReader::parseOperand(EncodingReader &reader) {
  std::vector<Value> &scopeValues = valueScopes.back().values;

  Value *slot = nullptr;
  if (failed(parseEntry(reader, scopeValues, slot, "value")))
    return Value();

  // Use the existing value when there is one.
  if (*slot)
    return *slot;

  // Otherwise stand in a forward reference until the value is defined.
  *slot = createForwardRef();
  return *slot;
}
|
|
|
|
/// Define `newValues` in the current value scope, assigning them consecutive
/// IDs starting at the next free ID of the region being processed. A slot
/// that already holds a forward reference has its uses redirected to the new
/// value, and the forward reference operation is moved to the block of
/// operations available for reuse. Fails if the IDs would run past the values
/// reserved in the scope.
LogicalResult BytecodeReader::defineValues(EncodingReader &reader,
                                           ValueRange newValues) {
  ValueScope &scope = valueScopes.back();
  std::vector<Value> &scopeValues = scope.values;

  unsigned &nextID = scope.nextValueIDs.back();
  unsigned endID = nextID + newValues.size();
  if (endID > scopeValues.size()) {
    return reader.emitError(
        "value index range was outside of the expected range for "
        "the parent region, got [",
        nextID, ", ", endID, "), but the maximum index was ",
        scopeValues.size() - 1);
  }

  // Store each value in its slot, resolving a forward reference if one was
  // created for that slot.
  for (Value definedValue : newValues) {
    Value previous = std::exchange(scopeValues[nextID], definedValue);
    ++nextID;
    if (!previous)
      continue;

    // IDs are handed out incrementally while parsing, so the only way for a
    // slot to already hold a value is through a forward reference.
    Operation *placeholderOp = previous.getDefiningOp();
    assert(placeholderOp && placeholderOp->getBlock() == &forwardRefOps &&
           "value index was already defined?");

    previous.replaceAllUsesWith(definedValue);
    placeholderOp->moveBefore(&openForwardRefOps, openForwardRefOps.end());
  }
  return success();
}
|
|
|
|
/// Return a value that stands in for a definition that has not been read yet.
/// A forward reference operation that is no longer in use is recycled when
/// one is available; otherwise a new one is created.
Value BytecodeReader::createForwardRef() {
  if (openForwardRefOps.empty()) {
    // Nothing to recycle: create a new placeholder operation.
    forwardRefOps.push_back(Operation::create(forwardRefOpState));
  } else {
    // Recycle the most recently retired placeholder.
    Operation *recycledOp = &openForwardRefOps.back();
    recycledOp->moveBefore(&forwardRefOps, forwardRefOps.end());
  }

  // Either way the placeholder is now the last operation of the block.
  return forwardRefOps.back().getResult(0);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Entry Points
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Return true if the contents of `buffer` start with the bytecode magic
/// number.
bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
  StringRef contents = buffer.getBuffer();
  return contents.startswith("ML\xefR");
}
|
|
|
|
/// Read the bytecode held in `buffer` into `block`, using the given parser
/// configuration. An error located at the buffer's identifier is emitted, and
/// failure returned, if the buffer does not hold bytecode.
LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
                                     const ParserConfig &config) {
  // Diagnostics are anchored at the start of the input file.
  Location sourceFileLoc =
      FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
                          /*line=*/0, /*column=*/0);

  if (isBytecode(buffer)) {
    BytecodeReader reader(sourceFileLoc, config);
    return reader.read(buffer, block);
  }
  return emitError(sourceFileLoc,
                   "input buffer is not an MLIR bytecode file");
}
|