[mlir][transform] Allow passing various library files to interpreter. (#67120)

The transfrom interpreter accepts an argument to a "library" file with
named sequences. This patch exteneds this functionality such that (1)
several such individual files are accepted and (2) folders can be passed
in, in which all `*.mlir` files are loaded.
This commit is contained in:
Ingo Müller 2023-10-06 12:52:49 +02:00 committed by GitHub
parent a233a49b60
commit 6a2071cc6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 200 additions and 53 deletions

View File

@ -33,7 +33,7 @@ namespace detail {
/// Template-free implementation of TransformInterpreterPassBase::initialize.
LogicalResult interpreterBaseInitializeImpl(
MLIRContext *context, StringRef transformFileName,
StringRef transformLibraryFileName,
ArrayRef<std::string> transformLibraryPaths,
std::shared_ptr<OwningOpRef<ModuleOp>> &module,
std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
@ -48,7 +48,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
const RaggedArray<MappedValue> &extraMappings,
const TransformOptions &options,
const Pass::Option<std::string> &transformFileName,
const Pass::Option<std::string> &transformLibraryFileName,
const Pass::ListOption<std::string> &transformLibraryPaths,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
StringRef binaryName);
@ -62,11 +62,12 @@ LogicalResult interpreterBaseRunOnOperationImpl(
/// transform script. If empty, `debugTransformRootTag` is considered or the
/// pass root operation must contain a single top-level transform op that
/// will be interpreted.
/// - transformLibraryFileName: if non-empty, the module in this file will be
/// - transformLibraryPaths: if non-empty, the modules in these files will be
/// merged into the main transform script run by the interpreter before
/// execution. This allows to provide definitions for external functions
/// used in the main script. Other public symbols in the library module may
/// lead to collisions with public symbols in the main script.
/// used in the main script. Other public symbols in the library modules may
/// lead to collisions with public symbols in the main script and among each
/// other.
/// - debugPayloadRootTag: if non-empty, the value of the attribute named
/// `kTransformDialectTagAttrName` indicating the single op that is
/// considered the payload root of the transform interpreter; otherwise, the
@ -118,16 +119,26 @@ public:
REQUIRE_PASS_OPTION(transformFileName);
REQUIRE_PASS_OPTION(debugPayloadRootTag);
REQUIRE_PASS_OPTION(debugTransformRootTag);
REQUIRE_PASS_OPTION(transformLibraryFileName);
#undef REQUIRE_PASS_OPTION
#define REQUIRE_PASS_LIST_OPTION(NAME) \
static_assert( \
std::is_same_v< \
std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>, \
Pass::ListOption<std::string>>, \
"required " #NAME " string pass option is missing")
REQUIRE_PASS_LIST_OPTION(transformLibraryPaths);
#undef REQUIRE_PASS_LIST_OPTION
StringRef transformFileName =
static_cast<Concrete *>(this)->transformFileName;
StringRef transformLibraryFileName =
static_cast<Concrete *>(this)->transformLibraryFileName;
ArrayRef<std::string> transformLibraryPaths =
static_cast<Concrete *>(this)->transformLibraryPaths;
return detail::interpreterBaseInitializeImpl(
context, transformFileName, transformLibraryFileName,
context, transformFileName, transformLibraryPaths,
sharedTransformModule, transformLibraryModule,
[this](OpBuilder &builder, Location loc) {
return static_cast<Concrete *>(this)->constructTransformModule(
@ -162,7 +173,7 @@ public:
op, pass->getArgument(), sharedTransformModule,
transformLibraryModule,
/*extraMappings=*/{}, options, pass->transformFileName,
pass->transformLibraryFileName, pass->debugPayloadRootTag,
pass->transformLibraryPaths, pass->debugPayloadRootTag,
pass->debugTransformRootTag, binaryName)) ||
failed(pass->runAfterInterpreter(op))) {
return pass->signalPassFailure();

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
@ -194,7 +195,7 @@ saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
Operation *transform, StringRef passName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
const Pass::Option<std::string> &transformLibraryFileName,
const Pass::ListOption<std::string> &transformLibraryPaths,
StringRef binaryName) {
using llvm::sys::fs::TempFile;
Operation *root = getRootOperation(target);
@ -231,7 +232,7 @@ static void performOptionalDebugActions(
Operation *target, Operation *transform, StringRef passName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
const Pass::Option<std::string> &transformLibraryFileName,
const Pass::ListOption<std::string> &transformLibraryPaths,
StringRef binaryName) {
MLIRContext *context = target->getContext();
@ -284,7 +285,7 @@ static void performOptionalDebugActions(
DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
saveReproToTempFile(llvm::dbgs(), target, transform, passName,
debugPayloadRootTag, debugTransformRootTag,
transformLibraryFileName, binaryName);
transformLibraryPaths, binaryName);
});
// Remove temporary attributes if they were set.
@ -534,7 +535,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
const RaggedArray<MappedValue> &extraMappings,
const TransformOptions &options,
const Pass::Option<std::string> &transformFileName,
const Pass::Option<std::string> &transformLibraryFileName,
const Pass::ListOption<std::string> &transformLibraryPaths,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
StringRef binaryName) {
@ -597,7 +598,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
if (failed(
mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
transformLibraryModule->get()->clone())))
return failure();
return emitError(transformRoot->getLoc(),
"failed to merge library symbols into transform root");
}
// Step 4
@ -606,7 +608,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
// repro to stderr and/or a file.
performOptionalDebugActions(target, transformRoot, passName,
debugPayloadRootTag, debugTransformRootTag,
transformLibraryFileName, binaryName);
transformLibraryPaths, binaryName);
// Step 5
// ------
@ -615,55 +617,148 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
extraMappings, options);
}
/// Expands the given list of `paths` to a list of `.mlir` files.
///
/// Each entry in `paths` may either be a regular file, in which case it ends up
/// in the result list, or a directory, in which case all (regular) `.mlir`
/// files in that directory are added. Any other file types lead to a failure.
static LogicalResult
expandPathsToMLIRFiles(ArrayRef<std::string> &paths, MLIRContext *const context,
SmallVectorImpl<std::string> &fileNames) {
for (const std::string &path : paths) {
auto loc = FileLineColLoc::get(context, path, 0, 0);
if (llvm::sys::fs::is_regular_file(path)) {
LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
fileNames.push_back(path);
continue;
}
if (!llvm::sys::fs::is_directory(path)) {
return emitError(loc)
<< "'" << path << "' is neither a file nor a directory";
}
LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
std::error_code ec;
for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
it != itEnd && !ec; it.increment(ec)) {
const std::string &fileName = it->path();
if (it->type() != llvm::sys::fs::file_type::regular_file) {
LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
<< "'\n");
continue;
}
if (!StringRef(fileName).endswith(".mlir")) {
LLVM_DEBUG(DBGS() << " Skipping '" << fileName
<< "' because it does not end with '.mlir'\n");
continue;
}
LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
fileNames.push_back(fileName);
}
if (ec)
return emitError(loc) << "error while opening files in '" << path
<< "': " << ec.message();
}
return success();
}
LogicalResult transform::detail::interpreterBaseInitializeImpl(
MLIRContext *context, StringRef transformFileName,
StringRef transformLibraryFileName,
ArrayRef<std::string> transformLibraryPaths,
std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
moduleBuilder) {
OwningOpRef<ModuleOp> parsedTransformModule;
if (failed(parseTransformModuleFromFile(context, transformFileName,
parsedTransformModule)))
return failure();
if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
auto unknownLoc = UnknownLoc::get(context);
// Parse module from file.
OwningOpRef<ModuleOp> moduleFromFile;
{
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
if (failed(parseTransformModuleFromFile(context, transformFileName,
moduleFromFile)))
return emitError(loc) << "failed to parse transform module";
if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
return emitError(loc) << "failed to verify transform module";
}
// Assemble list of library files.
SmallVector<std::string> libraryFileNames;
if (failed(expandPathsToMLIRFiles(transformLibraryPaths, context,
libraryFileNames)))
return failure();
OwningOpRef<ModuleOp> parsedLibraryModule;
if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
parsedLibraryModule)))
return failure();
if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule)))
return failure();
// Parse modules from library files.
SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
for (const std::string &libraryFileName : libraryFileNames) {
OwningOpRef<ModuleOp> parsedLibrary;
auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
if (failed(parseTransformModuleFromFile(context, libraryFileName,
parsedLibrary)))
return emitError(loc) << "failed to parse transform library module";
if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
return emitError(loc) << "failed to verify transform library module";
parsedLibraries.push_back(std::move(parsedLibrary));
}
if (parsedTransformModule) {
sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
std::move(parsedTransformModule));
// Build shared transform module.
if (moduleFromFile) {
sharedTransformModule =
std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile));
} else if (moduleBuilder) {
// TODO: better location story.
auto location = UnknownLoc::get(context);
auto loc = FileLineColLoc::get(context, "<shared-transform-module>", 0, 0);
auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
ModuleOp::create(location, "__transform"));
ModuleOp::create(unknownLoc, "__transform"));
OpBuilder b(context);
b.setInsertionPointToEnd(localModule->get().getBody());
if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
if (std::optional<LogicalResult> result = moduleBuilder(b, loc)) {
if (failed(*result))
return failure();
return (*localModule)->emitError()
<< "failed to create shared transform module";
sharedTransformModule = std::move(localModule);
}
}
if (!parsedLibraryModule || !*parsedLibraryModule)
if (parsedLibraries.empty())
return success();
// Merge parsed libraries into one module.
auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
OwningOpRef<ModuleOp> mergedParsedLibraries =
ModuleOp::create(loc, "__transform");
{
mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
UnitAttr::get(context));
IRRewriter rewriter(context);
// TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
if (failed(mergeSymbolsInto(mergedParsedLibraries.get(),
std::move(parsedLibrary))))
return mergedParsedLibraries->emitError()
<< "failed to verify merged transform module";
}
}
// Use parsed libaries to resolve symbols in shared transform module or return
// as separate library module.
if (sharedTransformModule && *sharedTransformModule) {
if (failed(mergeSymbolsInto(sharedTransformModule->get(),
std::move(parsedLibraryModule))))
return failure();
std::move(mergedParsedLibraries))))
return (*sharedTransformModule)->emitError()
<< "failed to merge symbols from library files "
"into shared transform module";
} else {
transformLibraryModule =
std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibraryModule));
transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
std::move(mergedParsedLibraries));
}
return success();
}

View File

@ -2,7 +2,7 @@
// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-name=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-paths=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
// RUN: -test-transform-dialect-erase-schedule -cse \
// RUN: | FileCheck %s

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
// RUN: --verify-diagnostics
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
// RUN: --verify-diagnostics
// The external transform script has a declaration to the named sequence @foo,

View File

@ -0,0 +1,28 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library})" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir,%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir})" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library}, test-transform-dialect-interpreter)" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// The definition of the @foo named sequence is provided in another file. It
// will be included because of the pass option. Repeated application of the
// same pass, with or without the library option, should not be a problem.
// Note that the same diagnostic produced twice at the same location only
// needs to be matched once.
// expected-remark @below {{message}}
module attributes {transform.with_named_sequence} {
// CHECK: transform.named_sequence @print_message
transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly})
transform.named_sequence @reference_other_module(!transform.any_op {transform.readonly})
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
include @reference_other_module failures(propagate) (%arg0) : (!transform.any_op) -> ()
}
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
// RUN: --verify-diagnostics --split-input-file
// The definition of the @print_message named sequence is provided in another file. It
@ -8,6 +8,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has a mismatching signature}}
transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
// expected-error @below {{failed to merge library symbols into transform root}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.op<"builtin.module">):
include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
@ -32,6 +33,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
// expected-error @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
@ -47,6 +49,7 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
// expected-error @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter)" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// The definition of the @print_message named sequence is provided in another

View File

@ -0,0 +1,10 @@
// RUN: mlir-opt %s
module attributes {transform.with_named_sequence} {
transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly})
transform.named_sequence @reference_other_module(%arg0: !transform.any_op) {
transform.include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
transform.yield
}
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-paths=%p/match_matmul_common.mlir' --verify-diagnostics
module attributes { transform.with_named_sequence } {
transform.named_sequence @_match_matmul_like(

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-paths=%p/match_matmul_common.mlir' --verify-diagnostics
module attributes { transform.with_named_sequence } {
transform.named_sequence @_match_matmul_like(

View File

@ -161,7 +161,7 @@ public:
if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
getOperation(), getArgument(), getSharedTransformModule(),
getTransformLibraryModule(), extraMapping, options,
transformFileName, transformLibraryFileName, debugPayloadRootTag,
transformFileName, transformLibraryPaths, debugPayloadRootTag,
debugTransformRootTag, getBinaryName())))
return signalPassFailure();
}
@ -216,9 +216,9 @@ public:
"the given value as container IR for top-level transform ops. This "
"allows user control on what transformation to apply. If empty, "
"select the container of the top-level transform op.")};
Option<std::string> transformLibraryFileName{
*this, "transform-library-file-name", llvm::cl::init(""),
llvm::cl::desc("Optional name of a file with a module that should be "
ListOption<std::string> transformLibraryPaths{
*this, "transform-library-paths", llvm::cl::ZeroOrMore,
llvm::cl::desc("Optional paths to files with modules that should be "
"merged into the transform module to provide the "
"definitions of external named sequences.")};