llvm-capstone/mlir/test/lib/IR/TestBytecodeCallbacks.cpp
Adrian Kuegel 50665511c7 [mlir] Apply ClangTidy fixes (NFC)
Remove redundant returns at end of function.
2023-08-04 08:44:20 +02:00

366 lines
16 KiB
C++

//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/raw_ostream.h"
#include <list>
using namespace mlir;
using namespace llvm;
namespace {
/// Command-line parser for `test::TestDialectVersion` values written as
/// "<major>.<minor>", e.g. "2.0".
class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
public:
  TestDialectVersionParser(cl::Option &O)
      : cl::parser<test::TestDialectVersion>(O) {}

  /// Parses `arg` into `v`. Following the llvm::cl convention, returns true on
  /// error (after reporting it through `O.error`) and false on success.
  bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg,
             test::TestDialectVersion &v) {
    // Both components must be base-10 integers. A missing '.' leaves the minor
    // component empty, which getAsSignedInteger rejects.
    auto [majorStr, minorStr] = arg.split('.');
    long long majorVer = 0;
    long long minorVer = 0;
    if (getAsSignedInteger(majorStr, 10, majorVer) ||
        getAsSignedInteger(minorStr, 10, minorVer))
      return O.error("Invalid argument '" + arg + "'");
    v = test::TestDialectVersion(majorVer, minorVer);
    return false;
  }

  /// Prints `v` in the same "<major>.<minor>" form accepted by `parse`.
  static void print(raw_ostream &os, const test::TestDialectVersion &v) {
    os << v.major_ << "." << v.minor_;
  }
};
/// This is a test pass which uses callbacks to encode attributes and types in a
/// custom fashion. The `callback-test` option selects the scenario to run; each
/// scenario round-trips the module through bytecode and prints the re-parsed
/// module to stdout so the result can be FileCheck'ed.
struct TestBytecodeCallbackPass
    : public PassWrapper<TestBytecodeCallbackPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass)

  StringRef getArgument() const final { return "test-bytecode-callback"; }
  StringRef getDescription() const final {
    return "Test encoding of a dialect type/attributes with a custom callback";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<test::TestDialect>();
  }
  TestBytecodeCallbackPass() = default;
  // Intentionally copies no state. NOTE(review): presumably option values are
  // transferred by the pass infrastructure when the pass is cloned — confirm.
  TestBytecodeCallbackPass(const TestBytecodeCallbackPass &) {}

  /// Dispatches to the scenario selected by `callback-test`.
  void runOnOperation() override {
    Operation *op = getOperation();
    int kind = testKind;
    switch (kind) {
    case 0:
      return runTest0(op);
    case 1:
      return runTest1(op);
    case 2:
      return runTest2(op);
    case 3:
      return runTest3(op);
    case 4:
      return runTest4(op);
    case 5:
      return runTest5(op);
    default:
      // `callback-test` is user-controlled, so an unknown value is reported as
      // a diagnostic instead of reaching llvm_unreachable (UB in release).
      op->emitError() << "unhandled test kind for TestBytecodeCallbacks pass: "
                      << kind;
      signalPassFailure();
      return;
    }
  }

  mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser>
      targetVersion{*this, "test-dialect-version",
                    llvm::cl::desc(
                        "Specifies the test dialect version to emit and parse"),
                    cl::init(test::TestDialectVersion())};

  mlir::Pass::Option<int> testKind{
      *this, "callback-test",
      llvm::cl::desc("Specifies the test kind to execute"), cl::init(0)};

private:
  /// Writes `op` to bytecode with `writeConfig`, parses it back with
  /// `parseConfig`, and prints the re-parsed module to stdout. Emits an error
  /// on `op` and signals pass failure if either direction fails.
  void doRoundtripWithConfigs(Operation *op,
                              const BytecodeWriterConfig &writeConfig,
                              const ParserConfig &parseConfig) {
    std::string bytecode;
    llvm::raw_string_ostream os(bytecode);
    if (failed(writeBytecodeToFile(op, os, writeConfig))) {
      op->emitError() << "failed to write bytecode\n";
      signalPassFailure();
      return;
    }
    // `str()` flushes the stream, so the buffer is guaranteed to hold every
    // byte written above before it is re-parsed.
    auto newModuleOp = parseSourceString(StringRef(os.str()), parseConfig);
    if (!newModuleOp.get()) {
      op->emitError() << "failed to read bytecode\n";
      signalPassFailure();
      return;
    }
    // Print the module to the output stream, so that we can filecheck the
    // result.
    newModuleOp->print(llvm::outs());
  }

  // Test0: let's assume that versions older than 2.0 were relying on a special
  // integer attribute of a deprecated dialect called "funky". Assume that its
  // encoding was made by two varInts, the first was the ID (999) and the second
  // contained width and signedness info. We can emit it using a callback
  // writing a custom encoding for the "funky" dialect group, and parse it back
  // with a custom parser reading the same encoding in the same dialect group.
  // Note that the ID 999 does not correspond to a valid integer type in the
  // current encodings of builtin types.
  void runTest0(Operation *op) {
    // The context is owned exclusively by this function and must outlive the
    // module parsed inside doRoundtripWithConfigs.
    auto newCtx = std::make_unique<MLIRContext>();
    test::TestDialectVersion targetEmissionVersion = targetVersion;
    BytecodeWriterConfig writeConfig;
    writeConfig.attachTypeCallback(
        [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
            DialectBytecodeWriter &writer) -> LogicalResult {
          // Versions 2.0 and newer use the default encoding: do not override.
          if (targetEmissionVersion.major_ >= 2)
            return failure();
          // For versions older than 2.0, override the encoding of IntegerType
          // with the legacy "funky" one: ID 999, then width and signedness
          // packed into a single varint (width in the upper bits, signedness
          // in the low two bits).
          if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) {
            llvm::outs() << "Overriding IntegerType encoding...\n";
            dialectGroupName = StringLiteral("funky");
            writer.writeVarInt(/* IntegerType */ 999);
            writer.writeVarInt((type.getWidth() << 2) | type.getSignedness());
            return success();
          }
          return failure();
        });
    newCtx->appendDialectRegistry(op->getContext()->getDialectRegistry());
    newCtx->allowUnregisteredDialects();
    ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true);
    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
        [&](DialectBytecodeReader &reader, StringRef dialectName,
            Type &entry) -> LogicalResult {
          // Get test dialect version from the version map.
          auto versionOr = reader.getDialectVersion("test");
          assert(succeeded(versionOr) && "expected reader to be able to access "
                                         "the version for test dialect");
          const auto *version =
              reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
          // TODO: once back-deployment is formally supported,
          // `targetEmissionVersion` will be encoded in the bytecode file, and
          // exposed through the versionMap. Right now though this is not yet
          // supported. For the purpose of the test, just use
          // `targetEmissionVersion`.
          (void)version;
          if (targetEmissionVersion.major_ >= 2)
            return success();
          // `dialectName` is the name of the group we have the opportunity to
          // override. In this case, override only the dialect group "funky",
          // which does not exist in memory.
          if (dialectName != StringLiteral("funky"))
            return success();
          uint64_t encoding;
          if (failed(reader.readVarInt(encoding)) || encoding != 999)
            return success();
          llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
          // Unpack the varint written by the writer callback above.
          uint64_t widthAndSignedness;
          if (succeeded(reader.readVarInt(widthAndSignedness))) {
            uint64_t width = widthAndSignedness >> 2;
            auto signedness = static_cast<IntegerType::SignednessSemantics>(
                widthAndSignedness & 0x3);
            entry = IntegerType::get(reader.getContext(), width, signedness);
          }
          // Returning success() while leaving `entry` null falls through to the
          // rest of the parsing code path.
          return success();
        });
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  // Test1: When writing bytecode, we override the encoding of TestI32Type with
  // the encoding of builtin IntegerType. We can natively parse this without
  // the use of a callback, relying on the existing builtin reader mechanism.
  void runTest1(Operation *op) {
    auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
    BytecodeDialectInterface *iface =
        builtin->getRegisteredInterface<BytecodeDialectInterface>();
    BytecodeWriterConfig writeConfig;
    writeConfig.attachTypeCallback(
        [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
            DialectBytecodeWriter &writer) -> LogicalResult {
          // Emit TestI32Type using the builtin dialect encoding of i32.
          if (llvm::isa<test::TestI32Type>(entryValue)) {
            llvm::outs() << "Overriding TestI32Type encoding...\n";
            auto builtinI32Type =
                IntegerType::get(op->getContext(), 32,
                                 IntegerType::SignednessSemantics::Signless);
            // Specify that this type will need to be written as part of the
            // builtin group. This will override the default dialect group of
            // the type (test).
            dialectGroupName = StringLiteral("builtin");
            if (succeeded(iface->writeType(builtinI32Type, writer)))
              return success();
          }
          return failure();
        });
    // We natively parse the type as a builtin, so no callback needed.
    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  // Test2: When writing bytecode, we write standard builtin IntegerTypes. At
  // parsing, we use the encoding of IntegerType to intercept all i32. Then,
  // instead of creating i32s, we assemble TestI32Type and return it.
  void runTest2(Operation *op) {
    auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
    BytecodeDialectInterface *iface =
        builtin->getRegisteredInterface<BytecodeDialectInterface>();
    BytecodeWriterConfig writeConfig;
    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
        [&](DialectBytecodeReader &reader, StringRef dialectName,
            Type &entry) -> LogicalResult {
          if (dialectName != StringLiteral("builtin"))
            return success();
          Type builtinType = iface->readType(reader);
          if (auto integerType =
                  llvm::dyn_cast_or_null<IntegerType>(builtinType)) {
            if (integerType.getWidth() == 32 && integerType.isSignless()) {
              llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
              entry = test::TestI32Type::get(reader.getContext());
            }
          }
          // A null `entry` falls through to the default parsing path.
          return success();
        });
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  // Test3: When writing bytecode, we override the encoding of
  // TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We
  // can natively parse this without the use of a callback, relying on the
  // existing builtin reader mechanism.
  void runTest3(Operation *op) {
    auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
    BytecodeDialectInterface *iface =
        builtin->getRegisteredInterface<BytecodeDialectInterface>();
    auto i32Type = IntegerType::get(op->getContext(), 32,
                                    IntegerType::SignednessSemantics::Signless);
    BytecodeWriterConfig writeConfig;
    writeConfig.attachAttributeCallback(
        [&](Attribute entryValue, std::optional<StringRef> &dialectGroupName,
            DialectBytecodeWriter &writer) -> LogicalResult {
          // Emit TestAttrParamsAttr as a builtin dense<2xi32> attribute.
          if (auto testParamAttrs =
                  llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue)) {
            llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
            // Specify that this attribute will need to be written as part of
            // the builtin group. This will override the default dialect group
            // of the attribute (test).
            dialectGroupName = StringLiteral("builtin");
            auto denseAttr = DenseIntElementsAttr::get(
                RankedTensorType::get({2}, i32Type),
                {testParamAttrs.getV0(), testParamAttrs.getV1()});
            if (succeeded(iface->writeAttribute(denseAttr, writer)))
              return success();
          }
          return failure();
        });
    // We natively parse the attribute as a builtin, so no callback needed.
    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  // Test4: When writing bytecode, we write standard builtin
  // DenseIntElementsAttr. At parsing, we use the encoding of
  // DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of
  // <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble
  // TestAttrParamsAttr and return it.
  void runTest4(Operation *op) {
    auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
    BytecodeDialectInterface *iface =
        builtin->getRegisteredInterface<BytecodeDialectInterface>();
    auto i32Type = IntegerType::get(op->getContext(), 32,
                                    IntegerType::SignednessSemantics::Signless);
    BytecodeWriterConfig writeConfig;
    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
    parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
        [&](DialectBytecodeReader &reader, StringRef dialectName,
            Attribute &entry) -> LogicalResult {
          // Override only the case where the builtin reader yields a dense
          // <2xi32> attribute, and fall through on all other cases so the
          // normal codepaths still parse every other attribute.
          // NOTE(review): unlike runTest2, this does not filter on
          // `dialectName`, so the builtin reader is also tried on non-builtin
          // groups — confirm this is intended.
          Attribute builtinAttr = iface->readAttribute(reader);
          if (auto denseAttr =
                  llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
            if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
                denseAttr.getElementType() == i32Type) {
              llvm::outs()
                  << "Overriding parsing of TestAttrParamsAttr encoding...\n";
              int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt();
              int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt();
              entry =
                  test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
            }
          }
          return success();
        });
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }

  // Test5: When writing bytecode, we want TestDialect to use nothing but the
  // builtin types and attributes and take full control of the encoding,
  // returning failure if any type or attribute is not part of builtin.
  void runTest5(Operation *op) {
    auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
    BytecodeDialectInterface *iface =
        builtin->getRegisteredInterface<BytecodeDialectInterface>();
    BytecodeWriterConfig writeConfig;
    writeConfig.attachAttributeCallback(
        [&](Attribute attr, std::optional<StringRef> &dialectGroupName,
            DialectBytecodeWriter &writer) -> LogicalResult {
          return iface->writeAttribute(attr, writer);
        });
    writeConfig.attachTypeCallback(
        [&](Type type, std::optional<StringRef> &dialectGroupName,
            DialectBytecodeWriter &writer) -> LogicalResult {
          return iface->writeType(type, writer);
        });
    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
    parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
        [&](DialectBytecodeReader &reader, StringRef dialectName,
            Attribute &entry) -> LogicalResult {
          Attribute builtinAttr = iface->readAttribute(reader);
          if (!builtinAttr)
            return failure();
          entry = builtinAttr;
          return success();
        });
    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
        [&](DialectBytecodeReader &reader, StringRef dialectName,
            Type &entry) -> LogicalResult {
          Type builtinType = iface->readType(reader);
          if (!builtinType)
            return failure();
          entry = builtinType;
          return success();
        });
    doRoundtripWithConfigs(op, writeConfig, parseConfig);
  }
};
} // namespace
namespace mlir {
void registerTestBytecodeCallbackPasses() {
PassRegistration<TestBytecodeCallbackPass>();
}
} // namespace mlir