llvm-capstone/mlir/test/python/dialects/transform_structured_ext.py
Alex Zinenko 88c5027b93 [mlir] make multi-size tiling use transform parameters
Use the recently introduced transform dialect parameter mechanism to
perform controllable multi-size tiling with sizes computed at the
transformation time rather than at runtime.

This requires to generalize tile and split structured transform
operations to work with any transform dialect handle types, which is
desirable in itself to avoid unchecked overuse of PDL OperationType.

Reviewed By: shabalin

Differential Revision: https://reviews.llvm.org/D140980
2023-01-19 10:19:37 +00:00

211 lines
7.5 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured
def run(f):
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
print("\nTEST:", f.__name__)
f()
print(module)
return f
@run
def testDecompose():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.DecomposeOp(sequence.bodyTarget)
transform.YieldOp()
# CHECK-LABEL: TEST: testDecompose
# CHECK: transform.sequence
# CHECK: transform.structured.decompose
@run
def testGeneralize():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.GeneralizeOp(sequence.bodyTarget)
transform.YieldOp()
# CHECK-LABEL: TEST: testGeneralize
# CHECK: transform.sequence
# CHECK: transform.structured.generalize
@run
def testInterchange():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.InterchangeOp(
sequence.bodyTarget,
iterator_interchange=[1, 0])
transform.YieldOp()
# CHECK-LABEL: TEST: testInterchange
# CHECK: transform.sequence
# CHECK: transform.structured.interchange
# CHECK: iterator_interchange = [1, 0]
@run
def testMultitileSizes():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.MultiTileSizesOp(pdl.OperationType.get(),
sequence.bodyTarget,
dimension=1,
target_size=42)
transform.YieldOp()
# CHECK-LABEL: TEST: testMultitileSizes
# CHECK: transform.sequence
# CHECK: transform.structured.multitile_sizes
# CHECK-DAG: dimension = 1
# CHECK-DAG: target_size = 42
@run
def testPad():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.PadOp(
sequence.bodyTarget,
padding_values=[FloatAttr.get_f32(42.0)],
padding_dimensions=[1],
transpose_paddings=[[1, 0]])
transform.YieldOp()
# CHECK-LABEL: TEST: testPad
# CHECK: transform.sequence
# CHECK: transform.structured.pad
# CHECK-DAG: padding_values = [4.200000e+01 : f32]
# CHECK-DAG: padding_dimensions = [1]
# CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
# (hoist_paddings and pack_paddings have default values)
@run
def testScalarize():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.ScalarizeOp(sequence.bodyTarget)
transform.YieldOp()
# CHECK-LABEL: TEST: testScalarize
# CHECK: transform.structured.scalarize
@run
def testSplit():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
structured.SplitOp(
split.results[0], dimension=3, split_point=split.results[1])
transform.YieldOp()
# CHECK-LABEL: TEST: testSplit
# CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
# CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
@run
def testTileCompact():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.TileOp(sequence.bodyTarget,
sizes=[4, 8],
interchange=[0, 1])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileCompact
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
# CHECK: interchange = [0, 1]
@run
def testTileAttributes():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
attr = DenseI64ArrayAttr.get([4, 8])
ichange = DenseI64ArrayAttr.get([0, 1])
with InsertionPoint(sequence.body):
structured.TileOp(sequence.bodyTarget,
sizes=attr,
interchange=ichange)
transform.YieldOp()
# CHECK-LABEL: TEST: testTileAttributes
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
# CHECK: interchange = [0, 1]
@run
def testTileZero():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.TileOp(sequence.bodyTarget,
sizes=[4, 0, 2, 0],
interchange=[0, 1, 2, 3])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileZero
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
# CHECK: interchange = [0, 1, 2, 3]
@run
def testTileDynamic():
with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get())
with InsertionPoint(with_pdl.body):
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [],
with_pdl.bodyTarget)
with InsertionPoint(sequence.body):
m1 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
m2 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
structured.TileOp(sequence.bodyTarget,
sizes=[m1, 3, m2, 0])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileDynamic
# CHECK: %[[FIRST:.+]] = pdl_match
# CHECK: %[[SECOND:.+]] = pdl_match
# CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
@run
def testTileExplicitLoopTypeSingle():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[], transform.AnyOpType.get())
with InsertionPoint(sequence.body):
structured.TileOp(transform.OperationType.get("scf.for"),
sequence.bodyTarget,
sizes=[2, 3, 4])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
# CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
# CHECK-COUNT-3: !transform.op<"scf.for">
@run
def testTileExplicitLoopTypeAll():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[], transform.AnyOpType.get())
types = [
transform.OperationType.get(x)
for x in ["scf.for", "scf.parallel", "scf.foreach_thread"]
]
with InsertionPoint(sequence.body):
structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
# CHECK: = transform.structured.tile
# CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
# CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.foreach_thread">
@run
def testVectorize():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
transform.YieldOp()
# CHECK-LABEL: TEST: testVectorize
# CHECK: transform.sequence
# CHECK: = transform.structured.vectorize
# CHECK: {vectorize_padding}