mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-23 13:50:11 +00:00
[mlir][python bindings] generate all the enums
This PR implements python enum bindings for *all* the enums - this includes `I*Attrs` (including positional/bit) and `Dialect/EnumAttr`.
There are a few parts to this:
1. CMake: a small addition to `declare_mlir_dialect_python_bindings` and `declare_mlir_dialect_extension_python_bindings` to generate the enum, a boolean arg `GEN_ENUM_BINDINGS` to make it opt-in (even though it works for basically all of the dialects), and an optional `GEN_ENUM_BINDINGS_TD_FILE` for handling corner cases.
2. EnumPythonBindingGen.cpp: there are two weedy aspects here that took investigation:
1. If an enum attribute is not a `Dialect/EnumAttr` then the `EnumAttrInfo` record is canonical, as far as both the cases of the enum **and the `AttrDefName`**. On the otherhand, if an enum is a `Dialect/EnumAttr` then the `EnumAttr` record has the correct `AttrDefName` ("load bearing", i.e., populates `ods.ir.AttributeBuilder('<NAME>')`) but its `enum` field contains the cases, which is an instance of `EnumAttrInfo`. The solution is to generate an one enum class for both `Dialect/EnumAttr` and "independent" `EnumAttrInfo` but to make that class interopable with two builder registrations that both do the right thing (see next sub-bullet).
2. Because we don't have a good connection to cpp `EnumAttr`, i.e., only the `enum class` getters are exposed (like `DimensionAttr::get(Dimension value)`), we have to resort to parsing e.g., `Attribute.parse(f'#gpu<dim {x}>')`. This means that the set of supported `assemblyFormat`s (for the enum) is fixed at compile of MLIR (currently 2, the only 2 I saw). There might be some things that could be done here but they would require quite a bit more C API work to support generically (e.g., casting ints to enum cases and binding all the getters or going generically through the `symbolize*` methods, like `symbolizeDimension(uint32_t)` or `symbolizeDimension(StringRef)`).
A few small changes:
1. In addition, since this patch registers default builders for attributes where people might've had their own builders already written, I added a `replace` param to `AttributeBuilder.insert` (`False` by default).
2. `makePythonEnumCaseName` can't handle all the different ways in which people write their enum cases, e.g., `llvm.CConv.Intel_OCL_BI`, which gets turned into `INTEL_O_C_L_B_I` (because `llvm::convertToSnakeFromCamelCase` doesn't look for runs of caps). So I dropped it. On the otherhand regularization does need to done because some enums have `None` as a case (and others might have other python keywords).
3. I turned on `llvm` dialect generation here in order to test `nvvm.WGMMAScaleIn`, which is an enum with [[ d7e26b5620/mlir/include/mlir/IR/EnumAttr.td (L22-L25)
| no explicit discriminator ]] for the `neg` case.
Note, dialects that didn't get a `GEN_ENUM_BINDINGS` don't have any enums to generate.
Let me know if I should add more tests (the three trivial ones I added exercise both the supported `assemblyFormat`s and `replace=True`).
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D157934
This commit is contained in:
parent
be91bd0121
commit
92233062c1
@ -272,6 +272,11 @@ endfunction()
|
||||
# SOURCES: Same as declare_mlir_python_sources().
|
||||
# SOURCES_GLOB: Same as declare_mlir_python_sources().
|
||||
# DEPENDS: Additional dependency targets.
|
||||
# GEN_ENUM_BINDINGS: Generate enum bindings.
|
||||
# GEN_ENUM_BINDINGS_TD_FILE: Optional Tablegen file to generate enums for (relative to ROOT_DIR).
|
||||
# This file is where the *EnumAttrs are defined, not where the *Enums are defined.
|
||||
# **WARNING**: This arg will shortly be removed when the just-below TODO is satisfied. Use at your
|
||||
# risk.
|
||||
#
|
||||
# TODO: Right now `TD_FILE` can't be the actual dialect tablegen file, since we
|
||||
# use its path to determine where to place the generated python file. If
|
||||
@ -279,9 +284,9 @@ endfunction()
|
||||
# need for the separate "wrapper" .td files
|
||||
function(declare_mlir_dialect_python_bindings)
|
||||
cmake_parse_arguments(ARG
|
||||
""
|
||||
"GEN_ENUM_BINDINGS"
|
||||
"ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME"
|
||||
"SOURCES;SOURCES_GLOB;DEPENDS"
|
||||
"SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE"
|
||||
${ARGN})
|
||||
# Sources.
|
||||
set(_dialect_target "${ARG_ADD_TO_PARENT}.${ARG_DIALECT_NAME}")
|
||||
@ -306,11 +311,22 @@ function(declare_mlir_dialect_python_bindings)
|
||||
)
|
||||
add_public_tablegen_target(${tblgen_target})
|
||||
|
||||
set(_sources ${dialect_filename})
|
||||
if(ARG_GEN_ENUM_BINDINGS OR ARG_GEN_ENUM_BINDINGS_TD_FILE)
|
||||
if(ARG_GEN_ENUM_BINDINGS_TD_FILE)
|
||||
set(td_file "${ARG_ROOT_DIR}/${ARG_GEN_ENUM_BINDINGS_TD_FILE}")
|
||||
set(LLVM_TARGET_DEFINITIONS ${td_file})
|
||||
endif()
|
||||
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
|
||||
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
|
||||
list(APPEND _sources ${enum_filename})
|
||||
endif()
|
||||
|
||||
# Generated.
|
||||
declare_mlir_python_sources("${_dialect_target}.ops_gen"
|
||||
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
|
||||
ADD_TO_PARENT "${_dialect_target}"
|
||||
SOURCES "${dialect_filename}"
|
||||
SOURCES ${_sources}
|
||||
)
|
||||
endif()
|
||||
endfunction()
|
||||
@ -331,11 +347,16 @@ endfunction()
|
||||
# SOURCES: Same as declare_mlir_python_sources().
|
||||
# SOURCES_GLOB: Same as declare_mlir_python_sources().
|
||||
# DEPENDS: Additional dependency targets.
|
||||
# GEN_ENUM_BINDINGS: Generate enum bindings.
|
||||
# GEN_ENUM_BINDINGS_TD_FILE: Optional Tablegen file to generate enums for (relative to ROOT_DIR).
|
||||
# This file is where the *Attrs are defined, not where the *Enums are defined.
|
||||
# **WARNING**: This arg will shortly be removed when the TODO for
|
||||
# declare_mlir_dialect_python_bindings is satisfied. Use at your risk.
|
||||
function(declare_mlir_dialect_extension_python_bindings)
|
||||
cmake_parse_arguments(ARG
|
||||
""
|
||||
"GEN_ENUM_BINDINGS"
|
||||
"ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME"
|
||||
"SOURCES;SOURCES_GLOB;DEPENDS"
|
||||
"SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE"
|
||||
${ARGN})
|
||||
# Source files.
|
||||
set(_extension_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}")
|
||||
@ -362,10 +383,21 @@ function(declare_mlir_dialect_extension_python_bindings)
|
||||
add_dependencies(${tblgen_target} ${ARG_DEPENDS})
|
||||
endif()
|
||||
|
||||
set(_sources ${output_filename})
|
||||
if(ARG_GEN_ENUM_BINDINGS OR ARG_GEN_ENUM_BINDINGS_TD_FILE)
|
||||
if(ARG_GEN_ENUM_BINDINGS_TD_FILE)
|
||||
set(td_file "${ARG_ROOT_DIR}/${ARG_GEN_ENUM_BINDINGS_TD_FILE}")
|
||||
set(LLVM_TARGET_DEFINITIONS ${td_file})
|
||||
endif()
|
||||
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
|
||||
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
|
||||
list(APPEND _sources ${enum_filename})
|
||||
endif()
|
||||
|
||||
declare_mlir_python_sources("${_extension_target}.ops_gen"
|
||||
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
|
||||
ADD_TO_PARENT "${_extension_target}"
|
||||
SOURCES "${output_filename}"
|
||||
SOURCES ${_sources}
|
||||
)
|
||||
endif()
|
||||
endfunction()
|
||||
|
@ -1,4 +1,4 @@
|
||||
//===- LinalgBase.td - Linalg dialect base support ---------*- tablegen -*-===//
|
||||
//===- LinalgEnums.td - Linalg dialect base support ---------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -58,10 +58,11 @@ public:
|
||||
void loadDialectModule(llvm::StringRef dialectNamespace);
|
||||
|
||||
/// Adds a user-friendly Attribute builder.
|
||||
/// Raises an exception if the mapping already exists.
|
||||
/// Raises an exception if the mapping already exists and replace == false.
|
||||
/// This is intended to be called by implementation code.
|
||||
void registerAttributeBuilder(const std::string &attributeKind,
|
||||
pybind11::function pyFunc);
|
||||
pybind11::function pyFunc,
|
||||
bool replace = false);
|
||||
|
||||
/// Adds a user-friendly type caster. Raises an exception if the mapping
|
||||
/// already exists and replace == false. This is intended to be called by
|
||||
|
@ -242,19 +242,23 @@ struct PyAttrBuilderMap {
|
||||
static py::function dundeGetItemNamed(const std::string &attributeKind) {
|
||||
auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
|
||||
if (!builder)
|
||||
throw py::key_error();
|
||||
throw py::key_error(attributeKind);
|
||||
return *builder;
|
||||
}
|
||||
static void dundeSetItemNamed(const std::string &attributeKind,
|
||||
py::function func) {
|
||||
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func));
|
||||
py::function func, bool replace) {
|
||||
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
|
||||
replace);
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
|
||||
.def_static("contains", &PyAttrBuilderMap::dunderContains)
|
||||
.def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
|
||||
.def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed);
|
||||
.def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed,
|
||||
"attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
|
||||
"Register an attribute builder for building MLIR "
|
||||
"attributes from python values.");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -63,11 +63,13 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
|
||||
}
|
||||
|
||||
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
|
||||
py::function pyFunc) {
|
||||
py::function pyFunc, bool replace) {
|
||||
py::object &found = attributeBuilderMap[attributeKind];
|
||||
if (found) {
|
||||
if (found && !found.is_none() && !replace) {
|
||||
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
|
||||
attributeKind + "' is already registered")
|
||||
attributeKind +
|
||||
"' is already registered with func: " +
|
||||
py::str(found).operator std::string())
|
||||
.str());
|
||||
}
|
||||
found = std::move(pyFunc);
|
||||
|
@ -52,7 +52,8 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/AMDGPUOps.td
|
||||
SOURCES
|
||||
dialects/amdgpu.py
|
||||
DIALECT_NAME amdgpu)
|
||||
DIALECT_NAME amdgpu
|
||||
GEN_ENUM_BINDINGS)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -68,7 +69,10 @@ declare_mlir_dialect_python_bindings(
|
||||
SOURCES
|
||||
dialects/bufferization.py
|
||||
dialects/_bufferization_ops_ext.py
|
||||
DIALECT_NAME bufferization)
|
||||
DIALECT_NAME bufferization
|
||||
GEN_ENUM_BINDINGS_TD_FILE
|
||||
"../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
|
||||
)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -109,7 +113,8 @@ declare_mlir_dialect_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/GPUOps.td
|
||||
SOURCES_GLOB dialects/gpu/*.py
|
||||
DIALECT_NAME gpu)
|
||||
DIALECT_NAME gpu
|
||||
GEN_ENUM_BINDINGS)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -120,7 +125,17 @@ declare_mlir_dialect_python_bindings(
|
||||
SOURCES_GLOB
|
||||
dialects/linalg/*.py
|
||||
DIALECT_NAME linalg
|
||||
DEPENDS LinalgOdsGen)
|
||||
DEPENDS LinalgOdsGen
|
||||
GEN_ENUM_BINDINGS)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/LLVMOps.td
|
||||
SOURCES
|
||||
dialects/llvm.py
|
||||
DIALECT_NAME llvm
|
||||
GEN_ENUM_BINDINGS)
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -140,16 +155,10 @@ declare_mlir_dialect_python_bindings(
|
||||
dialects/_transform_ops_ext.py
|
||||
dialects/transform/__init__.py
|
||||
_mlir_libs/_mlir/dialects/transform/__init__.pyi
|
||||
DIALECT_NAME transform)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td")
|
||||
mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings)
|
||||
add_public_tablegen_target(MLIRTransformDialectPyEnumGen)
|
||||
declare_mlir_python_sources(
|
||||
MLIRPythonSources.Dialects.transform.enum_gen
|
||||
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects.transform
|
||||
SOURCES "dialects/_transform_enum_gen.py")
|
||||
DIALECT_NAME transform
|
||||
GEN_ENUM_BINDINGS_TD_FILE
|
||||
"../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
|
||||
)
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -161,15 +170,6 @@ declare_mlir_dialect_extension_python_bindings(
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME bufferization_transform)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/BufferizationTransformOps.td")
|
||||
mlir_tablegen("dialects/_bufferization_transform_enum_gen.py" -gen-python-enum-bindings)
|
||||
add_public_tablegen_target(MLIRBufferizationTransformDialectPyEnumGen)
|
||||
declare_mlir_python_sources(
|
||||
MLIRPythonSources.Dialects.bufferization_transform.enum_gen
|
||||
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects.bufferization_transform
|
||||
SOURCES "dialects/_bufferization_transform_enum_gen.py")
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
@ -208,7 +208,10 @@ declare_mlir_dialect_extension_python_bindings(
|
||||
dialects/_structured_transform_ops_ext.py
|
||||
dialects/transform/structured.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME structured_transform)
|
||||
EXTENSION_NAME structured_transform
|
||||
GEN_ENUM_BINDINGS_TD_FILE
|
||||
"../../include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
|
||||
)
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -227,16 +230,10 @@ declare_mlir_dialect_extension_python_bindings(
|
||||
SOURCES
|
||||
dialects/transform/vector.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME vector_transform)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/VectorTransformOps.td")
|
||||
mlir_tablegen("dialects/_vector_transform_enum_gen.py" -gen-python-enum-bindings)
|
||||
add_public_tablegen_target(MLIRVectorTransformPyEnumGen)
|
||||
declare_mlir_python_sources(
|
||||
MLIRPythonSources.Dialects.vector_transform.enum_gen
|
||||
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects.vector_transform
|
||||
SOURCES "dialects/_vector_transform_enum_gen.py" )
|
||||
EXTENSION_NAME vector_transform
|
||||
GEN_ENUM_BINDINGS_TD_FILE
|
||||
"../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
|
||||
)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -252,7 +249,8 @@ declare_mlir_dialect_python_bindings(
|
||||
SOURCES
|
||||
dialects/arith.py
|
||||
dialects/_arith_ops_ext.py
|
||||
DIALECT_NAME arith)
|
||||
DIALECT_NAME arith
|
||||
GEN_ENUM_BINDINGS)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -278,7 +276,8 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/NVGPUOps.td
|
||||
SOURCES
|
||||
dialects/nvgpu.py
|
||||
DIALECT_NAME nvgpu)
|
||||
DIALECT_NAME nvgpu
|
||||
GEN_ENUM_BINDINGS)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -286,7 +285,8 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/NVVMOps.td
|
||||
SOURCES
|
||||
dialects/nvvm.py
|
||||
DIALECT_NAME nvvm)
|
||||
DIALECT_NAME nvvm
|
||||
GEN_ENUM_BINDINGS)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -300,6 +300,7 @@ declare_mlir_python_sources(
|
||||
MLIRPythonSources.Dialects.quant
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
GEN_ENUM_BINDINGS
|
||||
SOURCES
|
||||
dialects/quant.py
|
||||
_mlir_libs/_mlir/dialects/quant.pyi)
|
||||
@ -335,7 +336,10 @@ declare_mlir_dialect_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/SparseTensorOps.td
|
||||
SOURCES dialects/sparse_tensor.py
|
||||
DIALECT_NAME sparse_tensor)
|
||||
DIALECT_NAME sparse_tensor
|
||||
GEN_ENUM_BINDINGS_TD_FILE
|
||||
"../../include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
|
||||
)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -351,14 +355,16 @@ declare_mlir_dialect_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/TosaOps.td
|
||||
SOURCES dialects/tosa.py
|
||||
DIALECT_NAME tosa)
|
||||
DIALECT_NAME tosa
|
||||
)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/VectorOps.td
|
||||
SOURCES dialects/vector.py
|
||||
DIALECT_NAME vector)
|
||||
DIALECT_NAME vector
|
||||
GEN_ENUM_BINDINGS)
|
||||
|
||||
################################################################################
|
||||
# Python extensions.
|
||||
|
14
mlir/python/mlir/dialects/LLVMOps.td
Normal file
14
mlir/python/mlir/dialects/LLVMOps.td
Normal file
@ -0,0 +1,14 @@
|
||||
//===-- LlvmOps.td - Entry point for llvm bind ---------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef PYTHON_BINDINGS_LLVM_OPS
|
||||
#define PYTHON_BINDINGS_LLVM_OPS
|
||||
|
||||
include "mlir/Dialect/LLVMIR/LLVMOps.td"
|
||||
|
||||
#endif
|
@ -3,3 +3,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._amdgpu_ops_gen import *
|
||||
from ._amdgpu_enum_gen import *
|
||||
|
@ -3,3 +3,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._arith_ops_gen import *
|
||||
from ._arith_enum_gen import *
|
||||
|
@ -3,3 +3,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._bufferization_ops_gen import *
|
||||
from ._bufferization_enum_gen import *
|
||||
|
@ -3,3 +3,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._gpu_ops_gen import *
|
||||
from .._gpu_enum_gen import *
|
||||
|
@ -9,6 +9,7 @@ from ..._mlir_libs._mlirDialectsLinalg import *
|
||||
# definitions following these steps:
|
||||
# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
|
||||
from .._linalg_ops_gen import *
|
||||
from .._linalg_enum_gen import *
|
||||
|
||||
# These are the ground truth functions defined as:
|
||||
# ```
|
||||
|
6
mlir/python/mlir/dialects/llvm.py
Normal file
6
mlir/python/mlir/dialects/llvm.py
Normal file
@ -0,0 +1,6 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._llvm_ops_gen import *
|
||||
from ._llvm_enum_gen import *
|
@ -3,3 +3,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._nvgpu_ops_gen import *
|
||||
from ._nvgpu_enum_gen import *
|
||||
|
@ -3,3 +3,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._nvvm_ops_gen import *
|
||||
from ._nvvm_enum_gen import *
|
||||
|
@ -3,5 +3,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._sparse_tensor_ops_gen import *
|
||||
from ._sparse_tensor_enum_gen import *
|
||||
from .._mlir_libs._mlirDialectsSparseTensor import *
|
||||
from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses
|
||||
|
@ -2,5 +2,4 @@
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._bufferization_transform_enum_gen import *
|
||||
from .._bufferization_transform_ops_gen import *
|
||||
|
@ -3,3 +3,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._structured_transform_ops_gen import *
|
||||
from .._structured_transform_enum_gen import *
|
||||
|
@ -3,3 +3,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._vector_ops_gen import *
|
||||
from ._vector_enum_gen import *
|
||||
|
@ -8,9 +8,9 @@ from ._mlir_libs._mlir import register_type_caster
|
||||
|
||||
|
||||
# Convenience decorator for registering user-friendly Attribute builders.
|
||||
def register_attribute_builder(kind):
|
||||
def register_attribute_builder(kind, replace=False):
|
||||
def decorator_builder(func):
|
||||
AttrBuilder.insert(kind, func)
|
||||
AttrBuilder.insert(kind, func, replace=replace)
|
||||
return func
|
||||
|
||||
return decorator_builder
|
||||
|
@ -2,56 +2,108 @@
|
||||
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
|
||||
def Test_Dialect : Dialect {
|
||||
let name = "TestDialect";
|
||||
let cppNamespace = "::test";
|
||||
}
|
||||
|
||||
// CHECK: Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
// CHECK: from enum import Enum
|
||||
// CHECK: from enum import IntEnum, auto, IntFlag
|
||||
// CHECK: from ._ods_common import _cext as _ods_cext
|
||||
// CHECK: from ..ir import register_attribute_builder
|
||||
// CHECK: _ods_ir = _ods_cext.ir
|
||||
|
||||
def One : I32EnumAttrCase<"CaseOne", 1, "one">;
|
||||
def Two : I32EnumAttrCase<"CaseTwo", 2, "two">;
|
||||
def NegOne : I32EnumAttrCase<"CaseNegOne", -1, "negone">;
|
||||
|
||||
def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two]>;
|
||||
// CHECK: def _register_attribute_builder(kind):
|
||||
// CHECK: def decorator_builder(func):
|
||||
// CHECK: _ods_ir.AttrBuilder.insert(kind, func)
|
||||
// CHECK: return func
|
||||
// CHECK: return decorator_builder
|
||||
|
||||
// CHECK-LABEL: class MyEnum(Enum):
|
||||
def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]>;
|
||||
// CHECK-LABEL: class MyEnum(IntEnum):
|
||||
// CHECK: """An example 32-bit enum"""
|
||||
|
||||
// CHECK: CASE_ONE = 1
|
||||
// CHECK: CASE_TWO = 2
|
||||
// CHECK: CaseOne = 1
|
||||
// CHECK: CaseTwo = 2
|
||||
// CHECK: CaseNegOne = auto()
|
||||
|
||||
// CHECK: def _as_int(self):
|
||||
// CHECK: if self is MyEnum.CASE_ONE:
|
||||
// CHECK: return 1
|
||||
// CHECK: if self is MyEnum.CASE_TWO:
|
||||
// CHECK: return 2
|
||||
// CHECK: assert False, "Unknown MyEnum enum entry."
|
||||
// CHECK: def __str__(self):
|
||||
// CHECK: if self is MyEnum.CaseOne:
|
||||
// CHECK: return "one"
|
||||
// CHECK: if self is MyEnum.CaseTwo:
|
||||
// CHECK: return "two"
|
||||
// CHECK: if self is MyEnum.CaseNegOne:
|
||||
// CHECK: return "negone"
|
||||
// CHECK: raise ValueError("Unknown MyEnum enum entry.")
|
||||
|
||||
// CHECK: @register_attribute_builder("MyEnum")
|
||||
// CHECK: def _myenum(x, context):
|
||||
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
def TestMyEnum_Attr : EnumAttr<Test_Dialect, MyEnum, "enum">;
|
||||
|
||||
def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">;
|
||||
def Two64 : I64EnumAttrCase<"CaseTwo64", 2, "two">;
|
||||
|
||||
def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>;
|
||||
// CHECK: @_register_attribute_builder("MyEnum")
|
||||
// CHECK: def _my_enum(x, context):
|
||||
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), x._as_int())
|
||||
|
||||
// CHECK-LABEL: class MyEnum64(Enum):
|
||||
// CHECK-LABEL: class MyEnum64(IntEnum):
|
||||
// CHECK: """An example 64-bit enum"""
|
||||
|
||||
// CHECK: CASE_ONE64 = 1
|
||||
// CHECK: CASE_TWO64 = 2
|
||||
// CHECK: CaseOne64 = 1
|
||||
// CHECK: CaseTwo64 = 2
|
||||
|
||||
// CHECK: def _as_int(self):
|
||||
// CHECK: if self is MyEnum64.CASE_ONE64:
|
||||
// CHECK: return 1
|
||||
// CHECK: if self is MyEnum64.CASE_TWO64:
|
||||
// CHECK: return 2
|
||||
// CHECK: assert False, "Unknown MyEnum64 enum entry."
|
||||
// CHECK: def __str__(self):
|
||||
// CHECK: if self is MyEnum64.CaseOne64:
|
||||
// CHECK: return "one"
|
||||
// CHECK: if self is MyEnum64.CaseTwo64:
|
||||
// CHECK: return "two"
|
||||
// CHECK: raise ValueError("Unknown MyEnum64 enum entry.")
|
||||
|
||||
// CHECK: @_register_attribute_builder("MyEnum64")
|
||||
// CHECK: def _my_enum64(x, context):
|
||||
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), x._as_int())
|
||||
// CHECK: @register_attribute_builder("MyEnum64")
|
||||
// CHECK: def _myenum64(x, context):
|
||||
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
|
||||
|
||||
def TestBitEnum
|
||||
: I32BitEnumAttr<"TestBitEnum", "", [
|
||||
I32BitEnumAttrCaseBit<"User", 0, "user">,
|
||||
I32BitEnumAttrCaseBit<"Group", 1, "group">,
|
||||
I32BitEnumAttrCaseBit<"Other", 2, "other">,
|
||||
]> {
|
||||
let genSpecializedAttr = 0;
|
||||
let separator = " | ";
|
||||
}
|
||||
|
||||
def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
|
||||
|
||||
// CHECK-LABEL: class TestBitEnum(IntFlag):
|
||||
|
||||
// CHECK: User = 1
|
||||
// CHECK: Group = 2
|
||||
// CHECK: Other = 4
|
||||
|
||||
// CHECK: def __iter__(self):
|
||||
// CHECK: return iter([case for case in type(self) if (self & case) is case])
|
||||
// CHECK: def __len__(self):
|
||||
// CHECK: return bin(self).count("1")
|
||||
|
||||
// CHECK: def __str__(self):
|
||||
// CHECK: if len(self) > 1:
|
||||
// CHECK: return " | ".join(map(str, self))
|
||||
// CHECK: if self is TestBitEnum.User:
|
||||
// CHECK: return "user"
|
||||
// CHECK: if self is TestBitEnum.Group:
|
||||
// CHECK: return "group"
|
||||
// CHECK: if self is TestBitEnum.Other:
|
||||
// CHECK: return "other"
|
||||
// CHECK: raise ValueError("Unknown TestBitEnum enum entry.")
|
||||
|
||||
// CHECK: @register_attribute_builder("TestBitEnum")
|
||||
// CHECK: def _testbitenum(x, context):
|
||||
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
// CHECK: @register_attribute_builder("TestBitEnum_Attr")
|
||||
// CHECK: def _testbitenum_attr(x, context):
|
||||
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)
|
||||
|
||||
// CHECK: @register_attribute_builder("TestMyEnum_Attr")
|
||||
// CHECK: def _testmyenum_attr(x, context):
|
||||
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<enum {str(x)}>', context=context)
|
||||
|
@ -19,3 +19,17 @@ def testConstantOps():
|
||||
arith.ConstantOp(value=42.42, result=F32Type.get())
|
||||
# CHECK: %cst = arith.constant 4.242000e+01 : f32
|
||||
print(module)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testFastMathFlags
|
||||
@run
|
||||
def testFastMathFlags():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
a = arith.ConstantOp(value=42.42, result=F32Type.get())
|
||||
r = arith.AddFOp(
|
||||
a, a, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
|
||||
)
|
||||
# CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
|
||||
print(r)
|
||||
|
@ -1,22 +1,32 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from mlir.ir import *
|
||||
import mlir.dialects.gpu
|
||||
import mlir.dialects.gpu as gpu
|
||||
import mlir.dialects.gpu.passes
|
||||
from mlir.passmanager import *
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
|
||||
|
||||
def testGPUPass():
|
||||
with Context() as context:
|
||||
PassManager.parse("any(gpu-kernel-outlining)")
|
||||
print("SUCCESS")
|
||||
with Context(), Location.unknown():
|
||||
f()
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: testGPUPass
|
||||
# CHECK: SUCCESS
|
||||
run(testGPUPass)
|
||||
@run
|
||||
def testGPUPass():
|
||||
PassManager.parse("any(gpu-kernel-outlining)")
|
||||
print("SUCCESS")
|
||||
|
||||
|
||||
# CHECK-LABEL: testMMAElementWiseAttr
|
||||
@run
|
||||
def testMMAElementWiseAttr():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
gpu.BlockDimOp(gpu.Dimension.y)
|
||||
# CHECK: %0 = gpu.block_dim y
|
||||
print(module)
|
||||
pass
|
||||
|
25
mlir/test/python/dialects/llvm.py
Normal file
25
mlir/test/python/dialects/llvm.py
Normal file
@ -0,0 +1,25 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
# This is just a smoke test that the dialect is functional.
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import llvm
|
||||
|
||||
|
||||
def constructAndPrintInModule(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
f()
|
||||
print(module)
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: testSmoke
|
||||
@constructAndPrintInModule
|
||||
def testSmoke():
|
||||
mat64f32_t = Type.parse(
|
||||
"!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>"
|
||||
)
|
||||
result = llvm.UndefOp(mat64f32_t)
|
||||
# CHECK: %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
|
@ -3,6 +3,8 @@
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import nvvm
|
||||
from mlir.dialects import llvm
|
||||
from mlir.dialects import func
|
||||
|
||||
|
||||
def constructAndPrintInModule(f):
|
||||
@ -18,5 +20,30 @@ def constructAndPrintInModule(f):
|
||||
# CHECK-LABEL: testSmoke
|
||||
@constructAndPrintInModule
|
||||
def testSmoke():
|
||||
# CHECK: nvvm.cp.async.wait.group 5
|
||||
nvvm.CpAsyncWaitGroupOp(5)
|
||||
i64 = IntegerType.get_signless(64)
|
||||
mat64f32_t = Type.parse(
|
||||
"!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>"
|
||||
)
|
||||
shape_attr = Attribute.parse("#nvvm.shape<m = 64, n = 32, k = 16>")
|
||||
# CHECK-LABEL: func @wgmma_f32_f16_f16(%arg0: i64, %arg1: i64)
|
||||
@func.FuncOp.from_py_func(i64, i64)
|
||||
def wgmma_f32_f16_f16(desc_a, desc_b):
|
||||
# CHECK: nvvm.cp.async.wait.group 5
|
||||
nvvm.CpAsyncWaitGroupOp(5)
|
||||
# CHECK: %0 = llvm.mlir.undef : [[MAT_T:.*]]
|
||||
result = llvm.UndefOp(mat64f32_t)
|
||||
# CHECK: %1 = nvvm.wgmma.mma_async %arg0, %arg1, <m = 64, n = 32, k = 16>, D[%0, <zero>], A[<f16>, <neg>, <col>], B[<f16>, <neg>, <col>] : [[MAT_T]] -> [[MAT_T]]
|
||||
result1 = nvvm.WgmmaMmaAsyncOp(
|
||||
results_=mat64f32_t,
|
||||
inouts=result,
|
||||
descriptorA=desc_a,
|
||||
descriptorB=desc_b,
|
||||
shape=shape_attr,
|
||||
typeA=nvvm.WGMMATypes.f16,
|
||||
typeB=nvvm.WGMMATypes.f16,
|
||||
scaleD=nvvm.WGMMAScaleOut.zero,
|
||||
scaleA=nvvm.WGMMAScaleIn.neg,
|
||||
scaleB=nvvm.WGMMAScaleIn.neg,
|
||||
layoutA=nvvm.MMALayout.col,
|
||||
layoutB=nvvm.MMALayout.col,
|
||||
)
|
||||
|
@ -36,7 +36,7 @@ def testTypes():
|
||||
@run
|
||||
def testSequenceOp():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[transform.AnyOpType.get()],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
@ -52,15 +52,15 @@ def testSequenceOp():
|
||||
@run
|
||||
def testNestedSequenceOp():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
nested = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget
|
||||
transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
|
||||
)
|
||||
with InsertionPoint(nested.body):
|
||||
doubly_nested = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[transform.AnyOpType.get()],
|
||||
nested.bodyTarget,
|
||||
)
|
||||
@ -84,7 +84,7 @@ def testNestedSequenceOp():
|
||||
@run
|
||||
def testSequenceOpWithExtras():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.AnyOpType.get(),
|
||||
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
|
||||
@ -99,14 +99,14 @@ def testSequenceOpWithExtras():
|
||||
@run
|
||||
def testNestedSequenceOpWithExtras():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.AnyOpType.get(),
|
||||
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
nested = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
sequence.bodyTarget,
|
||||
sequence.bodyExtraArgs,
|
||||
@ -125,7 +125,7 @@ def testTransformPDLOps():
|
||||
withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
|
||||
with InsertionPoint(withPdl.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[transform.AnyOpType.get()],
|
||||
withPdl.bodyTarget,
|
||||
)
|
||||
@ -148,7 +148,7 @@ def testTransformPDLOps():
|
||||
@run
|
||||
def testGetParentOp():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
transform.GetParentOp(
|
||||
@ -164,7 +164,7 @@ def testGetParentOp():
|
||||
@run
|
||||
def testMergeHandlesOp():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
transform.MergeHandlesOp([sequence.bodyTarget])
|
||||
@ -178,7 +178,7 @@ def testMergeHandlesOp():
|
||||
@run
|
||||
def testApplyPatternsOpCompact():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
|
||||
@ -193,7 +193,7 @@ def testApplyPatternsOpCompact():
|
||||
@run
|
||||
def testApplyPatternsOpWithType():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [],
|
||||
transform.FailurePropagationMode.Propagate, [],
|
||||
transform.OperationType.get('test.dummy')
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
@ -211,7 +211,7 @@ def testReplicateOp():
|
||||
with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
|
||||
with InsertionPoint(with_pdl.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget
|
||||
transform.FailurePropagationMode.Propagate, [], with_pdl.bodyTarget
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
m1 = transform_pdl.PDLMatchOp(
|
||||
|
@ -3,6 +3,7 @@
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import transform
|
||||
from mlir.dialects.transform import bufferization
|
||||
from mlir.dialects.bufferization import LayoutMapOption
|
||||
|
||||
|
||||
def run(f):
|
||||
@ -18,7 +19,7 @@ def run(f):
|
||||
@run
|
||||
def testEmptyTensorToAllocTensorOpCompact():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.OperationType.get("tensor.empty"),
|
||||
)
|
||||
@ -33,7 +34,7 @@ def testEmptyTensorToAllocTensorOpCompact():
|
||||
@run
|
||||
def testEmptyTensorToAllocTensorOpTyped():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.OperationType.get("tensor.empty"),
|
||||
)
|
||||
@ -51,7 +52,7 @@ def testEmptyTensorToAllocTensorOpTyped():
|
||||
@run
|
||||
def testOneShotBufferizeOpCompact():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
bufferization.OneShotBufferizeOp(sequence.bodyTarget)
|
||||
@ -64,7 +65,7 @@ def testOneShotBufferizeOpCompact():
|
||||
@run
|
||||
def testOneShotBufferizeOpTyped():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
bufferization.OneShotBufferizeOp(
|
||||
@ -80,7 +81,7 @@ def testOneShotBufferizeOpTyped():
|
||||
@run
|
||||
def testOneShotBufferizeOpAttributes():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
bufferization.OneShotBufferizeOp(
|
||||
@ -89,7 +90,7 @@ def testOneShotBufferizeOpAttributes():
|
||||
allow_unknown_ops=True,
|
||||
bufferize_function_boundaries=True,
|
||||
create_deallocs=False,
|
||||
function_boundary_type_conversion=bufferization.LayoutMapOption.IDENTITY_LAYOUT_MAP,
|
||||
function_boundary_type_conversion=LayoutMapOption.IdentityLayoutMap,
|
||||
memcpy_op="linalg.copy",
|
||||
print_conflicts=True,
|
||||
test_analysis_only=True,
|
||||
|
@ -10,7 +10,7 @@ def run(f):
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
|
@ -19,7 +19,7 @@ def run(f):
|
||||
@run
|
||||
def getParentLoop():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
loop.GetParentForOp(
|
||||
@ -34,7 +34,7 @@ def getParentLoop():
|
||||
@run
|
||||
def loopOutline():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.OperationType.get("scf.for"),
|
||||
)
|
||||
@ -54,7 +54,7 @@ def loopOutline():
|
||||
@run
|
||||
def loopPeel():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.OperationType.get("scf.for"),
|
||||
)
|
||||
@ -68,7 +68,7 @@ def loopPeel():
|
||||
@run
|
||||
def loopPipeline():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.OperationType.get("scf.for"),
|
||||
)
|
||||
@ -86,7 +86,7 @@ def loopPipeline():
|
||||
@run
|
||||
def loopUnroll():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.OperationType.get("scf.for"),
|
||||
)
|
||||
|
@ -19,7 +19,7 @@ def run(f):
|
||||
@run
|
||||
def testMemRefMultiBufferOpCompact():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.OperationType.get("memref.alloc"),
|
||||
)
|
||||
@ -35,7 +35,7 @@ def testMemRefMultiBufferOpCompact():
|
||||
@run
|
||||
def testMemRefMultiBufferOpTyped():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.OperationType.get("memref.alloc"),
|
||||
)
|
||||
@ -53,7 +53,7 @@ def testMemRefMultiBufferOpTyped():
|
||||
@run
|
||||
def testMemRefMultiBufferOpAttributes():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.OperationType.get("memref.alloc"),
|
||||
)
|
||||
|
@ -21,7 +21,7 @@ def run(f):
|
||||
@run
|
||||
def testBufferizeToAllocationOpCompact():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.BufferizeToAllocationOp(sequence.bodyTarget)
|
||||
@ -34,7 +34,7 @@ def testBufferizeToAllocationOpCompact():
|
||||
@run
|
||||
def testBufferizeToAllocationOpArgs():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.BufferizeToAllocationOp(
|
||||
@ -57,7 +57,7 @@ def testBufferizeToAllocationOpArgs():
|
||||
@run
|
||||
def testDecompose():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.DecomposeOp(sequence.bodyTarget)
|
||||
@ -70,7 +70,7 @@ def testDecompose():
|
||||
@run
|
||||
def testFuseIntoContainingOpTypes():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
|
||||
@ -92,7 +92,7 @@ def testFuseIntoContainingOpTypes():
|
||||
@run
|
||||
def testFuseIntoContainingOpCompact():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
|
||||
@ -109,7 +109,7 @@ def testFuseIntoContainingOpCompact():
|
||||
@run
|
||||
def testGeneralize():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.GeneralizeOp(sequence.bodyTarget)
|
||||
@ -122,7 +122,7 @@ def testGeneralize():
|
||||
@run
|
||||
def testInterchange():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0])
|
||||
@ -136,7 +136,7 @@ def testInterchange():
|
||||
@run
|
||||
def testMapCopyToThreadsOpCompact():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.MapCopyToThreadsOp(
|
||||
@ -153,7 +153,7 @@ def testMapCopyToThreadsOpCompact():
|
||||
@run
|
||||
def testMapCopyToThreadsOpTypes():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.MapCopyToThreadsOp(
|
||||
@ -174,7 +174,7 @@ def testMapCopyToThreadsOpTypes():
|
||||
@run
|
||||
def testMatchOpNamesString():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy")
|
||||
@ -188,7 +188,7 @@ def testMatchOpNamesString():
|
||||
@run
|
||||
def testMatchOpNamesList():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
|
||||
@ -202,7 +202,7 @@ def testMatchOpNamesList():
|
||||
@run
|
||||
def testMaskedVectorizeStatic():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.MaskedVectorizeOp(sequence.bodyTarget, [16, 4])
|
||||
@ -216,7 +216,7 @@ def testMaskedVectorizeStatic():
|
||||
@run
|
||||
def testMaskedVectorizeArray():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
sizes = Attribute.parse("[16, 4]")
|
||||
@ -231,7 +231,7 @@ def testMaskedVectorizeArray():
|
||||
@run
|
||||
def testMaskedVectorizeMixed():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
|
||||
@ -248,7 +248,7 @@ def testMaskedVectorizeMixed():
|
||||
@run
|
||||
def testMaskedVectorizeScalable():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
|
||||
@ -265,7 +265,7 @@ def testMaskedVectorizeScalable():
|
||||
@run
|
||||
def testMaskedVectorizeArgs():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.MaskedVectorizeOp(
|
||||
@ -281,7 +281,7 @@ def testMaskedVectorizeArgs():
|
||||
@run
|
||||
def testMatchOpNamesTyped():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.MatchOp.match_op_names(
|
||||
@ -299,7 +299,7 @@ def testMatchOpNamesTyped():
|
||||
@run
|
||||
def testMultitileSizes():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.MultiTileSizesOp(
|
||||
@ -316,7 +316,7 @@ def testMultitileSizes():
|
||||
@run
|
||||
def testPad():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.PadOp(
|
||||
@ -343,7 +343,7 @@ def testPad():
|
||||
@run
|
||||
def testScalarize():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.ScalarizeOp(sequence.bodyTarget)
|
||||
@ -355,7 +355,7 @@ def testScalarize():
|
||||
@run
|
||||
def testSplit():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
|
||||
@ -369,7 +369,7 @@ def testSplit():
|
||||
@run
|
||||
def testTileCompact():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
|
||||
@ -383,7 +383,7 @@ def testTileCompact():
|
||||
@run
|
||||
def testTileAttributes():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
attr = DenseI64ArrayAttr.get([4, 8])
|
||||
ichange = DenseI64ArrayAttr.get([0, 1])
|
||||
@ -399,7 +399,7 @@ def testTileAttributes():
|
||||
@run
|
||||
def testTileZero():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.TileOp(
|
||||
@ -417,7 +417,7 @@ def testTileDynamic():
|
||||
with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
|
||||
with InsertionPoint(with_pdl.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget
|
||||
transform.FailurePropagationMode.Propagate, [], with_pdl.bodyTarget
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
m1 = transform_pdl.PDLMatchOp(
|
||||
@ -437,7 +437,7 @@ def testTileDynamic():
|
||||
@run
|
||||
def testTileExplicitLoopTypeSingle():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.TileOp(
|
||||
@ -452,7 +452,7 @@ def testTileExplicitLoopTypeSingle():
|
||||
@run
|
||||
def testTileExplicitLoopTypeAll():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
types = [
|
||||
transform.OperationType.get(x)
|
||||
@ -470,7 +470,7 @@ def testTileExplicitLoopTypeAll():
|
||||
@run
|
||||
def testTileToForallCompact():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.OperationType.get("linalg.matmul"),
|
||||
)
|
||||
@ -486,7 +486,7 @@ def testTileToForallCompact():
|
||||
@run
|
||||
def testTileToForallLoopsAndTileOpTypes():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.TileToForallOp(
|
||||
@ -505,7 +505,7 @@ def testTileToForallLoopsAndTileOpTypes():
|
||||
@run
|
||||
def testTileToForallTileSizes():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.TileToForallOp(sequence.bodyTarget, tile_sizes=[2, 3, 4])
|
||||
@ -518,7 +518,7 @@ def testTileToForallTileSizes():
|
||||
@run
|
||||
def testTileToForallMixedDynamic():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
|
||||
@ -532,7 +532,7 @@ def testTileToForallMixedDynamic():
|
||||
@run
|
||||
def testTileToForallPackedDynamic():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
|
||||
@ -546,7 +546,7 @@ def testTileToForallPackedDynamic():
|
||||
@run
|
||||
def testTileToForallMapping():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
|
||||
@ -562,7 +562,7 @@ def testTileToForallMapping():
|
||||
@run
|
||||
def testVectorize():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
|
||||
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
|
||||
@ -571,3 +571,53 @@ def testVectorize():
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: = transform.structured.vectorize
|
||||
# CHECK: {vectorize_padding}
|
||||
|
||||
|
||||
@run
|
||||
def testMatchInterfaceEnum():
|
||||
names = ArrayAttr.get([StringAttr.get("test.dummy")])
|
||||
result_type = transform.AnyOpType.get()
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
fused = structured.MatchOp.__base__(
|
||||
result_type,
|
||||
sequence.bodyTarget,
|
||||
ops=names,
|
||||
interface=structured.MatchInterfaceEnum.LinalgOp,
|
||||
)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testMatchInterfaceEnum
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: = transform.structured.match
|
||||
# CHECK: interface{LinalgOp}
|
||||
|
||||
|
||||
@run
|
||||
def testMatchInterfaceEnumReplaceAttributeBuilder():
|
||||
@register_attribute_builder("MatchInterfaceEnum", replace=True)
|
||||
def match_interface_enum(x, context):
|
||||
if x == "LinalgOp":
|
||||
y = 0
|
||||
elif x == "TilingInterface":
|
||||
y = 1
|
||||
return IntegerAttr.get(IntegerType.get_signless(32, context=context), y)
|
||||
|
||||
names = ArrayAttr.get([StringAttr.get("test.dummy")])
|
||||
result_type = transform.AnyOpType.get()
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
fused = structured.MatchOp.__base__(
|
||||
result_type,
|
||||
sequence.bodyTarget,
|
||||
ops=names,
|
||||
interface="TilingInterface",
|
||||
)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testMatchInterfaceEnumReplaceAttributeBuilder
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: = transform.structured.match
|
||||
# CHECK: interface{TilingInterface}
|
||||
|
@ -11,7 +11,7 @@ def run(f):
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
|
@ -10,7 +10,7 @@ def run_apply_patterns(f):
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE,
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
@ -72,12 +72,12 @@ def enum_configurable_patterns():
|
||||
# CHECK: transform.apply_patterns.vector.lower_contraction
|
||||
# CHECK-SAME: lowering_strategy = matmulintrinsics
|
||||
vector.ApplyLowerContractionPatternsOp(
|
||||
lowering_strategy=vector.VectorContractLowering.MATMUL
|
||||
lowering_strategy=vector.VectorContractLowering.Matmul
|
||||
)
|
||||
# CHECK: transform.apply_patterns.vector.lower_contraction
|
||||
# CHECK-SAME: lowering_strategy = parallelarith
|
||||
vector.ApplyLowerContractionPatternsOp(
|
||||
lowering_strategy=vector.VectorContractLowering.PARALLEL_ARITH
|
||||
lowering_strategy=vector.VectorContractLowering.ParallelArith
|
||||
)
|
||||
|
||||
# CHECK: transform.apply_patterns.vector.lower_multi_reduction
|
||||
@ -85,12 +85,12 @@ def enum_configurable_patterns():
|
||||
# CHECK: transform.apply_patterns.vector.lower_multi_reduction
|
||||
# This is the default mode, not printed.
|
||||
vector.ApplyLowerMultiReductionPatternsOp(
|
||||
lowering_strategy=vector.VectorMultiReductionLowering.INNER_PARALLEL
|
||||
lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel
|
||||
)
|
||||
# CHECK: transform.apply_patterns.vector.lower_multi_reduction
|
||||
# CHECK-SAME: lowering_strategy = innerreduction
|
||||
vector.ApplyLowerMultiReductionPatternsOp(
|
||||
lowering_strategy=vector.VectorMultiReductionLowering.INNER_REDUCTION
|
||||
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
|
||||
)
|
||||
|
||||
# CHECK: transform.apply_patterns.vector.lower_transpose
|
||||
@ -101,31 +101,31 @@ def enum_configurable_patterns():
|
||||
# CHECK-SAME: lowering_strategy = eltwise
|
||||
# CHECK-SAME: avx2_lowering_strategy = false
|
||||
vector.ApplyLowerTransposePatternsOp(
|
||||
lowering_strategy=vector.VectorTransposeLowering.ELT_WISE
|
||||
lowering_strategy=vector.VectorTransposeLowering.EltWise
|
||||
)
|
||||
# CHECK: transform.apply_patterns.vector.lower_transpose
|
||||
# CHECK-SAME: lowering_strategy = flat_transpose
|
||||
# CHECK-SAME: avx2_lowering_strategy = false
|
||||
vector.ApplyLowerTransposePatternsOp(
|
||||
lowering_strategy=vector.VectorTransposeLowering.FLAT
|
||||
lowering_strategy=vector.VectorTransposeLowering.Flat
|
||||
)
|
||||
# CHECK: transform.apply_patterns.vector.lower_transpose
|
||||
# CHECK-SAME: lowering_strategy = shuffle_1d
|
||||
# CHECK-SAME: avx2_lowering_strategy = false
|
||||
vector.ApplyLowerTransposePatternsOp(
|
||||
lowering_strategy=vector.VectorTransposeLowering.SHUFFLE1_D
|
||||
lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
|
||||
)
|
||||
# CHECK: transform.apply_patterns.vector.lower_transpose
|
||||
# CHECK-SAME: lowering_strategy = shuffle_16x16
|
||||
# CHECK-SAME: avx2_lowering_strategy = false
|
||||
vector.ApplyLowerTransposePatternsOp(
|
||||
lowering_strategy=vector.VectorTransposeLowering.SHUFFLE16X16
|
||||
lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
|
||||
)
|
||||
# CHECK: transform.apply_patterns.vector.lower_transpose
|
||||
# CHECK-SAME: lowering_strategy = flat_transpose
|
||||
# CHECK-SAME: avx2_lowering_strategy = true
|
||||
vector.ApplyLowerTransposePatternsOp(
|
||||
lowering_strategy=vector.VectorTransposeLowering.FLAT,
|
||||
lowering_strategy=vector.VectorTransposeLowering.Flat,
|
||||
avx2_lowering_strategy=True,
|
||||
)
|
||||
|
||||
@ -134,20 +134,20 @@ def enum_configurable_patterns():
|
||||
# CHECK: transform.apply_patterns.vector.split_transfer_full_partial
|
||||
# CHECK-SAME: split_transfer_strategy = none
|
||||
vector.ApplySplitTransferFullPartialPatternsOp(
|
||||
split_transfer_strategy=vector.VectorTransferSplit.NONE
|
||||
split_transfer_strategy=vector.VectorTransferSplit.None_
|
||||
)
|
||||
# CHECK: transform.apply_patterns.vector.split_transfer_full_partial
|
||||
# CHECK-SAME: split_transfer_strategy = "vector-transfer"
|
||||
vector.ApplySplitTransferFullPartialPatternsOp(
|
||||
split_transfer_strategy=vector.VectorTransferSplit.VECTOR_TRANSFER
|
||||
split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
|
||||
)
|
||||
# CHECK: transform.apply_patterns.vector.split_transfer_full_partial
|
||||
# This is the default mode, not printed.
|
||||
vector.ApplySplitTransferFullPartialPatternsOp(
|
||||
split_transfer_strategy=vector.VectorTransferSplit.LINALG_COPY
|
||||
split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
|
||||
)
|
||||
# CHECK: transform.apply_patterns.vector.split_transfer_full_partial
|
||||
# CHECK-SAME: split_transfer_strategy = "force-in-bounds"
|
||||
vector.ApplySplitTransferFullPartialPatternsOp(
|
||||
split_transfer_strategy=vector.VectorTransferSplit.FORCE_IN_BOUNDS
|
||||
split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
|
||||
)
|
||||
|
@ -64,3 +64,21 @@ def testTransferReadOp():
|
||||
# CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
|
||||
# CHECK-NOT: %[[MASK]]
|
||||
print(module)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testBitEnumCombiningKind
|
||||
@run
|
||||
def testBitEnumCombiningKind():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
f32 = F32Type.get()
|
||||
vector_type = VectorType.get([16], f32)
|
||||
|
||||
@func.FuncOp.from_py_func(vector_type)
|
||||
def reduction(arg):
|
||||
v = vector.ReductionOp(f32, vector.CombiningKind.ADD, arg)
|
||||
return v
|
||||
|
||||
# CHECK: func.func @reduction(%[[VEC:.*]]: vector<16xf32>) -> f32 {
|
||||
# CHECK: %0 = vector.reduction <add>, %[[VEC]] : vector<16xf32> into f32
|
||||
print(module)
|
||||
|
@ -10,10 +10,12 @@
|
||||
// generate the corresponding Python binding classes.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include "OpGenHelpers.h"
|
||||
|
||||
#include "mlir/TableGen/AttrOrTypeDef.h"
|
||||
#include "mlir/TableGen/Attribute.h"
|
||||
#include "mlir/TableGen/Dialect.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
@ -24,48 +26,61 @@ using namespace mlir::tblgen;
|
||||
constexpr const char *fileHeader = R"Py(
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from enum import Enum
|
||||
from enum import IntEnum, auto, IntFlag
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
from ..ir import register_attribute_builder
|
||||
_ods_ir = _ods_cext.ir
|
||||
|
||||
# Convenience decorator for registering user-friendly Attribute builders.
|
||||
def _register_attribute_builder(kind):
|
||||
def decorator_builder(func):
|
||||
_ods_ir.AttrBuilder.insert(kind, func)
|
||||
return func
|
||||
|
||||
return decorator_builder
|
||||
|
||||
)Py";
|
||||
|
||||
/// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
|
||||
static std::string makePythonEnumCaseName(StringRef name) {
|
||||
return StringRef(llvm::convertToSnakeFromCamelCase(name)).upper();
|
||||
if (isPythonReserved(name.str()))
|
||||
return (name + "_").str();
|
||||
return name.str();
|
||||
}
|
||||
|
||||
/// Emits the Python class for the given enum.
|
||||
static void emitEnumClass(StringRef enumName, StringRef description,
|
||||
ArrayRef<EnumAttrCase> cases, raw_ostream &os) {
|
||||
os << llvm::formatv("class {0}(Enum):\n", enumName);
|
||||
if (!description.empty())
|
||||
os << llvm::formatv(" \"\"\"{0}\"\"\"\n", description);
|
||||
static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
|
||||
os << llvm::formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
|
||||
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
|
||||
if (!enumAttr.getSummary().empty())
|
||||
os << llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
|
||||
os << "\n";
|
||||
|
||||
for (const EnumAttrCase &enumCase : cases) {
|
||||
os << llvm::formatv(" {0} = {1}\n",
|
||||
makePythonEnumCaseName(enumCase.getSymbol()),
|
||||
enumCase.getValue());
|
||||
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
|
||||
os << llvm::formatv(
|
||||
" {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()),
|
||||
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
|
||||
: "auto()");
|
||||
}
|
||||
|
||||
os << "\n";
|
||||
os << llvm::formatv(" def _as_int(self):\n");
|
||||
for (const EnumAttrCase &enumCase : cases) {
|
||||
os << llvm::formatv(" if self is {0}.{1}:\n", enumName,
|
||||
|
||||
if (enumAttr.isBitEnum()) {
|
||||
os << llvm::formatv(" def __iter__(self):\n"
|
||||
" return iter([case for case in type(self) if "
|
||||
"(self & case) is case])\n");
|
||||
os << llvm::formatv(" def __len__(self):\n"
|
||||
" return bin(self).count(\"1\")\n");
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
os << llvm::formatv(" def __str__(self):\n");
|
||||
if (enumAttr.isBitEnum())
|
||||
os << llvm::formatv(" if len(self) > 1:\n"
|
||||
" return \"{0}\".join(map(str, self))\n",
|
||||
enumAttr.getDef().getValueAsString("separator"));
|
||||
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
|
||||
os << llvm::formatv(" if self is {0}.{1}:\n",
|
||||
enumAttr.getEnumClassName(),
|
||||
makePythonEnumCaseName(enumCase.getSymbol()));
|
||||
os << llvm::formatv(" return {0}\n", enumCase.getValue());
|
||||
os << llvm::formatv(" return \"{0}\"\n", enumCase.getStr());
|
||||
}
|
||||
os << llvm::formatv(" assert False, \"Unknown {0} enum entry.\"\n\n\n",
|
||||
enumName);
|
||||
os << llvm::formatv(
|
||||
" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
|
||||
enumAttr.getEnumClassName());
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
/// Attempts to extract the bitwidth B from string "uintB_t" describing the
|
||||
@ -90,36 +105,68 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
|
||||
return true;
|
||||
}
|
||||
|
||||
os << llvm::formatv("@_register_attribute_builder(\"{0}\")\n",
|
||||
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
|
||||
enumAttr.getAttrDefName());
|
||||
os << llvm::formatv(
|
||||
"def _{0}(x, context):\n",
|
||||
llvm::convertToSnakeFromCamelCase(enumAttr.getAttrDefName()));
|
||||
os << llvm::formatv("def _{0}(x, context):\n",
|
||||
enumAttr.getAttrDefName().lower());
|
||||
os << llvm::formatv(
|
||||
" return "
|
||||
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
|
||||
"context=context), x._as_int())\n\n",
|
||||
"context=context), int(x))\n\n",
|
||||
bitwidth);
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Emits an attribute builder for the given dialect enum attribute to support
|
||||
/// automatic conversion between enum values and attributes in Python. Returns
|
||||
/// `false` on success, `true` on failure.
|
||||
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
|
||||
StringRef formatString,
|
||||
raw_ostream &os) {
|
||||
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
|
||||
os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
|
||||
os << llvm::formatv(" return "
|
||||
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
|
||||
formatString);
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Emits Python bindings for all enums in the record keeper. Returns
|
||||
/// `false` on success, `true` on failure.
|
||||
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
os << fileHeader;
|
||||
std::vector<llvm::Record *> defs =
|
||||
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
|
||||
for (const llvm::Record *def : defs) {
|
||||
EnumAttr enumAttr(*def);
|
||||
if (enumAttr.isBitEnum()) {
|
||||
llvm::errs() << "bit enums not supported\n";
|
||||
return true;
|
||||
}
|
||||
emitEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(),
|
||||
enumAttr.getAllCases(), os);
|
||||
for (auto &it :
|
||||
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
|
||||
EnumAttr enumAttr(*it);
|
||||
emitEnumClass(enumAttr, os);
|
||||
emitAttributeBuilder(enumAttr, os);
|
||||
}
|
||||
for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
|
||||
AttrOrTypeDef attr(&*it);
|
||||
if (!attr.getMnemonic()) {
|
||||
llvm::errs() << "enum case " << attr
|
||||
<< " needs mnemonic for python enum bindings generation";
|
||||
return true;
|
||||
}
|
||||
StringRef mnemonic = attr.getMnemonic().value();
|
||||
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
|
||||
StringRef dialect = attr.getDialect().getName();
|
||||
if (assemblyFormat == "`<` $value `>`") {
|
||||
emitDialectEnumAttributeBuilder(
|
||||
attr.getName(),
|
||||
llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
|
||||
} else if (assemblyFormat == "$value") {
|
||||
emitDialectEnumAttributeBuilder(
|
||||
attr.getName(),
|
||||
llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
|
||||
} else {
|
||||
llvm::errs()
|
||||
<< "unsupported assembly format for python enum bindings generation";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -11,6 +11,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "OpGenHelpers.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
@ -63,3 +64,19 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
|
||||
|
||||
return defs;
|
||||
}
|
||||
|
||||
bool mlir::tblgen::isPythonReserved(StringRef str) {
|
||||
static llvm::StringSet<> reserved({
|
||||
"False", "None", "True", "and", "as", "assert", "async",
|
||||
"await", "break", "class", "continue", "def", "del", "elif",
|
||||
"else", "except", "finally", "for", "from", "global", "if",
|
||||
"import", "in", "is", "lambda", "nonlocal", "not", "or",
|
||||
"pass", "raise", "return", "try", "while", "with", "yield",
|
||||
});
|
||||
// These aren't Python keywords but builtin functions that shouldn't/can't be
|
||||
// shadowed.
|
||||
reserved.insert("callable");
|
||||
reserved.insert("issubclass");
|
||||
reserved.insert("type");
|
||||
return reserved.contains(str);
|
||||
}
|
@ -24,6 +24,10 @@ namespace tblgen {
|
||||
std::vector<llvm::Record *>
|
||||
getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
|
||||
|
||||
/// Checks whether `str` is a Python keyword or would shadow builtin function.
|
||||
/// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
|
||||
bool isPythonReserved(llvm::StringRef str);
|
||||
|
||||
} // namespace tblgen
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -11,6 +11,8 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "OpGenHelpers.h"
|
||||
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
@ -278,18 +280,6 @@ static llvm::cl::opt<std::string> clDialectExtensionName(
|
||||
|
||||
using AttributeClasses = DenseMap<StringRef, StringRef>;
|
||||
|
||||
/// Checks whether `str` is a Python keyword or would shadow builtin function.
|
||||
static bool isPythonReserved(StringRef str) {
|
||||
static llvm::StringSet<> reserved(
|
||||
{"and", "as", "assert", "break", "callable", "class",
|
||||
"continue", "def", "del", "elif", "else", "except",
|
||||
"finally", "for", "from", "global", "if", "import",
|
||||
"in", "is", "lambda", "nonlocal", "not", "or",
|
||||
"pass", "raise", "return", "issubclass", "try", "type",
|
||||
"while", "with", "yield"});
|
||||
return reserved.contains(str);
|
||||
}
|
||||
|
||||
/// Checks whether `str` would shadow a generated variable or attribute
|
||||
/// part of the OpView API.
|
||||
static bool isODSReserved(StringRef str) {
|
||||
|
Loading…
Reference in New Issue
Block a user