# RUN: %PYTHON %s | FileCheck %s

import gc

from mlir.ir import *


def run(f):
    """Run one test function under a printed header and check for leaks.

    Prints a blank line followed by ``TEST: <function name>``, calls ``f``,
    forces a garbage collection and then asserts that no MLIR ``Context``
    objects are still alive.
    """
    print("\nTEST:", f.__name__)
    f()
    # Collect first so contexts kept alive only by reference cycles are
    # released before the live count is checked.
    gc.collect()
    assert Context._get_live_count() == 0


def add_dummy_value():
    """Create a ``custom.value`` op with one i32 result and return that Value.

    Every caller in this file invokes it inside an active Context, Location
    and InsertionPoint; it is used to obtain operand Values for the ops
    under test.
    """
    return Operation.create(
        "custom.value", results=[IntegerType.get_signless(32)]
    ).result


def testOdsBuildDefaultImplicitRegions():
    """Check how ``build_generic`` sizes regions from ``_ODS_REGIONS``.

    As the expected outputs below show, ``(2, True)`` means exactly two
    regions (both 1 and 3 are rejected), while ``(2, False)`` means at least
    two regions (3 is accepted, 1 is rejected, and 2 is the default).
    """

    class TestFixedRegionsOp(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (2, True)

    class TestVariadicRegionsOp(OpView):
        OPERATION_NAME = "custom.test_any_regions_op"
        _ODS_REGIONS = (2, False)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            op = TestFixedRegionsOp.build_generic(results=[], operands=[])
            # CHECK: NUM_REGIONS: 2
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Including a regions= that matches should be fine.
            op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
            # CHECK: NUM_REGIONS: 2
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Reject greater than.
            try:
                op = TestFixedRegionsOp.build_generic(
                    results=[], operands=[], regions=3
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
                print(f"ERROR:{e}")
            # Reject less than.
            try:
                op = TestFixedRegionsOp.build_generic(
                    results=[], operands=[], regions=1
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")

            # If no regions specified for a variadic region op, build the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
            # CHECK: DEFAULT_NUM_REGIONS: 2
            print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
            # Should also accept an explicit regions= that matches the minimum.
            op = TestVariadicRegionsOp.build_generic(
                results=[], operands=[], regions=2
            )
            # CHECK: EQ_NUM_REGIONS: 2
            print(f"EQ_NUM_REGIONS: {len(op.regions)}")
            # And accept greater than minimum.
            op = TestVariadicRegionsOp.build_generic(
                results=[], operands=[], regions=3
            )
            # CHECK: GT_NUM_REGIONS: 3
            print(f"GT_NUM_REGIONS: {len(op.regions)}")
            # Should reject less than minimum.
            try:
                op = TestVariadicRegionsOp.build_generic(
                    results=[], operands=[], regions=1
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")


run(testOdsBuildDefaultImplicitRegions)


def testOdsBuildDefaultNonVariadic():
    """Build an op that declares no operand/result segments.

    Operands and results are passed as flat lists, and the printed op must
    not carry any segment-size attributes.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
        # CHECK: %[[V0:.+]] = "custom.value"
        # CHECK: %[[V1:.+]] = "custom.value"
        # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
        # CHECK-NOT: operandSegmentSizes
        # CHECK-NOT: resultSegmentSizes
        # CHECK-SAME: : (i32, i32) -> (i8, i16)
        print(m)


run(testOdsBuildDefaultNonVariadic)


def testOdsBuildDefaultSizedVariadic():
    """Build an op with operand/result segment declarations.

    Judging by the inputs and expected errors below, a segment spec of 1 is
    a single required value, -1 is a variadic list (None builds an empty
    segment), and 0 is an optional value that may be None. The test checks
    the emitted segment-size attributes. It also checks the errors raised
    for None in required positions and for None items inside variadic lists.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_OPERAND_SEGMENTS = [1, -1, 0]
        _ODS_RESULT_SEGMENTS = [-1, 0, 1]

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            v2 = add_dummy_value()
            v3 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            t2 = IntegerType.get_signless(32)
            t3 = IntegerType.get_signless(64)

            # CHECK: %[[V0:.+]] = "custom.value"
            # CHECK: %[[V1:.+]] = "custom.value"
            # CHECK: %[[V2:.+]] = "custom.value"
            # CHECK: %[[V3:.+]] = "custom.value"
            # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
            # CHECK-SAME: operandSegmentSizes = array<i32: 1, 2, 1>
            # CHECK-SAME: resultSegmentSizes = array<i32: 2, 1, 1>
            # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
            op = TestOp.build_generic(
                results=[[t0, t1], t2, t3], operands=[v0, [v1, v2], v3]
            )

            # Now test with optional omitted.
            # CHECK: "custom.test_op"(%[[V0]])
            # CHECK-SAME: operandSegmentSizes = array<i32: 1, 0, 0>
            # CHECK-SAME: resultSegmentSizes = array<i32: 0, 0, 1>
            # CHECK-SAME: (i32) -> i64
            op = TestOp.build_generic(
                results=[None, None, t3], operands=[v0, None, None]
            )
            print(m)

            # And verify that errors are raised for None in a required operand.
            try:
                op = TestOp.build_generic(
                    results=[None, None, t3], operands=[None, None, None]
                )
            except ValueError as e:
                # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
                print(f"OPERAND_CAST_ERROR:{e}")

            # And verify that errors are raised for None in a required result.
            try:
                op = TestOp.build_generic(
                    results=[None, None, None], operands=[v0, None, None]
                )
            except ValueError as e:
                # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
                print(f"RESULT_CAST_ERROR:{e}")

            # Variadic lists with None elements should reject.
            try:
                op = TestOp.build_generic(
                    results=[None, None, t3], operands=[v0, [None], None]
                )
            except ValueError as e:
                # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
                print(f"OPERAND_LIST_CAST_ERROR:{e}")
            try:
                op = TestOp.build_generic(
                    results=[[None], None, t3], operands=[v0, None, None]
                )
            except ValueError as e:
                # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
                print(f"RESULT_LIST_CAST_ERROR:{e}")


run(testOdsBuildDefaultSizedVariadic)


def testOdsBuildDefaultCastError():
    """Check that None in a non-variadic op's operands or results is rejected."""

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            try:
                op = TestOp.build_generic(results=[t0, t1], operands=[None, v1])
            except ValueError as e:
                # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
                print(f"ERROR: {e}")
            try:
                op = TestOp.build_generic(results=[t0, None], operands=[v0, v1])
            except ValueError as e:
                # CHECK: Result 1 of operation "custom.test_op" must be a Type
                print(f"ERROR: {e}")


run(testOdsBuildDefaultCastError)