# RUN: %PYTHON %s | FileCheck %s from mlir.ir import * from mlir.dialects import transform from mlir.dialects import pdl from mlir.dialects.transform import structured def run(f): with Context(), Location.unknown(): module = Module.create() with InsertionPoint(module.body): print("\nTEST:", f.__name__) f() print(module) return f @run def testDecompose(): sequence = transform.SequenceOp() with InsertionPoint(sequence.body): structured.DecomposeOp(sequence.bodyTarget) transform.YieldOp() # CHECK-LABEL: TEST: testDecompose # CHECK: transform.sequence # CHECK: transform.structured.decompose @run def testGeneralize(): sequence = transform.SequenceOp() with InsertionPoint(sequence.body): structured.GeneralizeOp(sequence.bodyTarget) transform.YieldOp() # CHECK-LABEL: TEST: testGeneralize # CHECK: transform.sequence # CHECK: transform.structured.generalize @run def testInterchange(): sequence = transform.SequenceOp() with InsertionPoint(sequence.body): structured.InterchangeOp( sequence.bodyTarget, iterator_interchange=[ IntegerAttr.get(IntegerType.get_signless(64), 1), 0 ]) transform.YieldOp() # CHECK-LABEL: TEST: testInterchange # CHECK: transform.sequence # CHECK: transform.structured.interchange # CHECK: iterator_interchange = [1, 0] @run def testPad(): sequence = transform.SequenceOp() with InsertionPoint(sequence.body): structured.PadOp( sequence.bodyTarget, padding_values=[FloatAttr.get_f32(42.0)], padding_dimensions=[1], transpose_paddings=[[1, 0]]) transform.YieldOp() # CHECK-LABEL: TEST: testPad # CHECK: transform.sequence # CHECK: transform.structured.pad # CHECK-DAG: padding_values = [4.200000e+01 : f32] # CHECK-DAG: padding_dimensions = [1] # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]] # CHECK-DAG: hoist_paddings = [] # CHECK-DAG: pack_paddings = [] @run def testScalarize(): sequence = transform.SequenceOp() with InsertionPoint(sequence.body): structured.ScalarizeOp(sequence.bodyTarget) transform.YieldOp() # CHECK-LABEL: TEST: testScalarize # CHECK: transform.structured.scalarize @run def testTileCompact(): sequence = transform.SequenceOp() with InsertionPoint(sequence.body): structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1]) transform.YieldOp() # CHECK-LABEL: TEST: testTileCompact # CHECK: transform.sequence # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile # CHECK-DAG: interchange = [0, 1] # CHECK-DAG: sizes = [4, 8] @run def testTileAttributes(): sequence = transform.SequenceOp() attr = ArrayAttr.get( [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]]) ichange = ArrayAttr.get( [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]]) with InsertionPoint(sequence.body): structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange) transform.YieldOp() # CHECK-LABEL: TEST: testTileAttributes # CHECK: transform.sequence # CHECK: structured.tile # CHECK-DAG: interchange = [0, 1] # CHECK-DAG: sizes = [4, 8] @run def testTileZero(): sequence = transform.SequenceOp() with InsertionPoint(sequence.body): structured.TileOp( sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3]) transform.YieldOp() # CHECK-LABEL: TEST: testTileZero # CHECK: transform.sequence # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile # CHECK-DAG: interchange = [0, 1, 2, 3] # CHECK-DAG: sizes = [4, 0, 2, 0] @run def testVectorize(): sequence = transform.SequenceOp() with InsertionPoint(sequence.body): structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True) transform.YieldOp() # CHECK-LABEL: TEST: testVectorize # CHECK: transform.sequence # CHECK: = transform.structured.vectorize # CHECK: vectorize_padding = true