diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 56813c403a67..d63837c585c9 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3040,9 +3040,6 @@ def SPV_IntVec4 : SPV_Vec4; def SPV_IOrUIVec4 : SPV_Vec4; def SPV_Int32Vec4 : SPV_Vec4; -// TODO(antiagainst): Use a more appropriate way to model optional operands -class SPV_Optional : Variadic; - // TODO(ravishankarm): From 1.4, this should also include Composite type. def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index bb39d6fbf9f8..864c9a563d32 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -240,7 +240,7 @@ def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [ ); let results = (outs - SPV_Optional:$result + Optional:$result ); let autogenSerialization = 0; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td index b3ba0f6d0b76..36b0879669b9 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td @@ -30,7 +30,7 @@ class SPV_GroupNonUniformArithmeticOp:$value, - SPV_Optional:$cluster_size + Optional:$cluster_size ); let results = (outs diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 71bf7bdceb3b..c8932652bdfa 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -469,7 +469,7 @@ def SPV_VariableOp : SPV_Op<"Variable", []> { let arguments = (ins SPV_StorageClassAttr:$storage_class, - SPV_Optional:$initializer + Optional:$initializer ); let results = (outs diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp index c4dcefc38d6e..e9acab21fc62 100644 --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -155,7 +155,7 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite( groupOperation = rewriter.create( \ loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \ spirv::GroupOperation::Reduce, inputElement, \ - /*cluster_size=*/ArrayRef()); \ + /*cluster_size=*/nullptr); \ } break switch (*binaryOpKind) { CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 9c4670932d73..5ee6fb5c05af 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -2291,6 +2291,10 @@ Deserializer::processOp(ArrayRef operands) { << operands[0]; } + // Use null type to mean no result type. + if (isVoidType(resultType)) + resultType = nullptr; + auto resultID = operands[1]; auto functionID = operands[2]; @@ -2306,18 +2310,12 @@ Deserializer::processOp(ArrayRef operands) { arguments.push_back(value); } - SmallVector resultTypes; - if (!isVoidType(resultType)) { - resultTypes.push_back(resultType); - } - auto opFunctionCall = opBuilder.create( - unknownLoc, resultTypes, opBuilder.getSymbolRefAttr(functionName), + unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName), arguments); - if (!resultTypes.empty()) { + if (resultType) valueMap[resultID] = opFunctionCall.getResult(0); - } return success(); } diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir index 4951171bbca0..97ee02d45f24 100644 --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -202,7 +202,7 @@ func @caller() { spv.module Logical GLSL450 { spv.func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () "None" { - // expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}} + // expected-error @+1 {{result group starting at #0 requires 0 or 1 element, but found 2}} %0:2 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32) spv.Return } diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index ce7cda4fe3cf..5854a74509cd 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -548,7 +548,7 @@ def map_spec_operand_to_ods_argument(operand): if quantifier == '': arg_type = 'SPV_Type' elif quantifier == '?': - arg_type = 'SPV_Optional' + arg_type = 'Optional' else: arg_type = 'Variadic' elif kind == 'IdMemorySemantics' or kind == 'IdScope':