llvm-capstone/mlir/test/python/dialects/transform_structured_ext.py
Alex Zinenko ce2e198bc2 [mlir] add decompose and generalize to structured transform ops
These ops complement the tiling/padding transformations by transforming
higher-level named structured operations such as depthwise convolutions into
lower-level and/or generic equivalents that are better handled by some
downstream transformations.

Differential Revision: https://reviews.llvm.org/D126698
2022-06-02 15:25:18 +02:00

141 lines
3.9 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured
def run(f):
    """Run test function `f` inside a fresh module and print the result.

    Enters a new `Context` and an unknown `Location`, creates an empty
    module, and calls `f` while an insertion point at the module body is
    active. A "TEST: <name>" banner is printed before `f` runs and the
    module is printed afterwards.

    Returns `f` unchanged, so `run` can be applied as a decorator.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            # The CHECK-LABEL directives in the tests anchor on this banner.
            print("\nTEST:", f.__name__)
            f()
        # Printed after leaving the insertion point, with the context still
        # active.
        print(module)
    return f
@run
def testDecompose():
    """Build a transform sequence holding a single structured.DecomposeOp."""
    sequence = transform.SequenceOp()
    with InsertionPoint(sequence.body):
        structured.DecomposeOp(sequence.bodyTarget)
        # The yield is created inside the sequence body, after the op.
        transform.YieldOp()
    # CHECK-LABEL: TEST: testDecompose
    # CHECK: transform.sequence
    # CHECK: transform.structured.decompose
@run
def testGeneralize():
    """Build a transform sequence holding a single structured.GeneralizeOp."""
    sequence = transform.SequenceOp()
    with InsertionPoint(sequence.body):
        structured.GeneralizeOp(sequence.bodyTarget)
        # The yield is created inside the sequence body, after the op.
        transform.YieldOp()
    # CHECK-LABEL: TEST: testGeneralize
    # CHECK: transform.sequence
    # CHECK: transform.structured.generalize
@run
def testInterchange():
    """Build a structured.InterchangeOp from a mixed attribute/int list.

    `iterator_interchange` is given one explicit 64-bit signless
    `IntegerAttr` and one plain Python int; the CHECK lines expect both to
    print as `[1, 0]`.
    """
    sequence = transform.SequenceOp()
    with InsertionPoint(sequence.body):
        structured.InterchangeOp(
            sequence.bodyTarget,
            iterator_interchange=[
                IntegerAttr.get(IntegerType.get_signless(64), 1), 0
            ])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testInterchange
    # CHECK: transform.sequence
    # CHECK: transform.structured.interchange
    # CHECK: iterator_interchange = [1, 0]
@run
def testPad():
    """Build a structured.PadOp with explicit padding attributes.

    Only `padding_values`, `padding_dimensions` and `transpose_paddings`
    are passed; the CHECK lines also expect `hoist_paddings` and
    `pack_paddings` to be printed as empty lists.
    """
    sequence = transform.SequenceOp()
    with InsertionPoint(sequence.body):
        structured.PadOp(
            sequence.bodyTarget,
            padding_values=[FloatAttr.get_f32(42.0)],
            padding_dimensions=[1],
            transpose_paddings=[[1, 0]])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testPad
    # CHECK: transform.sequence
    # CHECK: transform.structured.pad
    # CHECK-DAG: padding_values = [4.200000e+01 : f32]
    # CHECK-DAG: padding_dimensions = [1]
    # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
    # CHECK-DAG: hoist_paddings = []
    # CHECK-DAG: pack_paddings = []
@run
def testScalarize():
    """Build a transform sequence holding a single structured.ScalarizeOp."""
    sequence = transform.SequenceOp()
    with InsertionPoint(sequence.body):
        structured.ScalarizeOp(sequence.bodyTarget)
        # The yield is created inside the sequence body, after the op.
        transform.YieldOp()
    # CHECK-LABEL: TEST: testScalarize
    # CHECK: transform.structured.scalarize
@run
def testTileCompact():
    """Build a structured.TileOp from plain Python integer lists."""
    sequence = transform.SequenceOp()
    with InsertionPoint(sequence.body):
        structured.TileOp(
            sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileCompact
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
    # CHECK-DAG: interchange = [0, 1]
    # CHECK-DAG: sizes = [4, 8]
@run
def testTileAttributes():
    """Build a structured.TileOp from pre-built `ArrayAttr` arguments.

    Unlike testTileCompact, `sizes` and `interchange` are passed as arrays
    of 64-bit signless integer attributes rather than Python lists.
    """
    sequence = transform.SequenceOp()
    # Both arrays use the same element type, so build it once.
    i64 = IntegerType.get_signless(64)
    attr = ArrayAttr.get([IntegerAttr.get(i64, x) for x in [4, 8]])
    ichange = ArrayAttr.get([IntegerAttr.get(i64, x) for x in [0, 1]])
    with InsertionPoint(sequence.body):
        structured.TileOp(
            sequence.bodyTarget, sizes=attr, interchange=ichange)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileAttributes
    # CHECK: transform.sequence
    # CHECK: structured.tile
    # CHECK-DAG: interchange = [0, 1]
    # CHECK-DAG: sizes = [4, 8]
@run
def testTileZero():
    """Build a structured.TileOp whose `sizes` list contains zeros.

    The CHECK lines expect the op to print two results after the first
    (`:2`) even though four sizes are given, two of them zero.
    """
    sequence = transform.SequenceOp()
    with InsertionPoint(sequence.body):
        structured.TileOp(
            sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileZero
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
    # CHECK-DAG: interchange = [0, 1, 2, 3]
    # CHECK-DAG: sizes = [4, 0, 2, 0]
@run
def testVectorize():
    """Build a structured.VectorizeOp with `vectorize_padding` enabled."""
    sequence = transform.SequenceOp()
    with InsertionPoint(sequence.body):
        structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
        # The yield is created inside the sequence body, after the op.
        transform.YieldOp()
    # CHECK-LABEL: TEST: testVectorize
    # CHECK: transform.sequence
    # CHECK: = transform.structured.vectorize
    # CHECK: vectorize_padding = true