diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 536cc136d0c4..e4d4ee56da06 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4327,6 +4327,14 @@ def SPIRV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; def SPIRV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; def SPIRV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>; def SPIRV_OC_OpGroupBroadcast : I32EnumAttrCase<"OpGroupBroadcast", 263>; +def SPIRV_OC_OpGroupIAdd : I32EnumAttrCase<"OpGroupIAdd", 264>; +def SPIRV_OC_OpGroupFAdd : I32EnumAttrCase<"OpGroupFAdd", 265>; +def SPIRV_OC_OpGroupFMin : I32EnumAttrCase<"OpGroupFMin", 266>; +def SPIRV_OC_OpGroupUMin : I32EnumAttrCase<"OpGroupUMin", 267>; +def SPIRV_OC_OpGroupSMin : I32EnumAttrCase<"OpGroupSMin", 268>; +def SPIRV_OC_OpGroupFMax : I32EnumAttrCase<"OpGroupFMax", 269>; +def SPIRV_OC_OpGroupUMax : I32EnumAttrCase<"OpGroupUMax", 270>; +def SPIRV_OC_OpGroupSMax : I32EnumAttrCase<"OpGroupSMax", 271>; def SPIRV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>; def SPIRV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>; def SPIRV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>; @@ -4356,6 +4364,8 @@ def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockRead def SPIRV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>; def SPIRV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>; def SPIRV_OC_OpAtomicFAddEXT : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>; +def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>; +def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>; def SPIRV_OC_OpTypeJointMatrixINTEL : I32EnumAttrCase<"OpTypeJointMatrixINTEL", 6119>; def SPIRV_OC_OpJointMatrixLoadINTEL : I32EnumAttrCase<"OpJointMatrixLoadINTEL", 6120>; @@ -4365,58 +4375,69 @@ def SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatr def SPIRV_OpcodeAttr : SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [ - SPIRV_OC_OpNop, SPIRV_OC_OpUndef, SPIRV_OC_OpSourceContinued, SPIRV_OC_OpSource, - SPIRV_OC_OpSourceExtension, SPIRV_OC_OpName, SPIRV_OC_OpMemberName, SPIRV_OC_OpString, - SPIRV_OC_OpLine, SPIRV_OC_OpExtension, SPIRV_OC_OpExtInstImport, SPIRV_OC_OpExtInst, + SPIRV_OC_OpNop, SPIRV_OC_OpUndef, SPIRV_OC_OpSourceContinued, + SPIRV_OC_OpSource, SPIRV_OC_OpSourceExtension, SPIRV_OC_OpName, + SPIRV_OC_OpMemberName, SPIRV_OC_OpString, SPIRV_OC_OpLine, + SPIRV_OC_OpExtension, SPIRV_OC_OpExtInstImport, SPIRV_OC_OpExtInst, SPIRV_OC_OpMemoryModel, SPIRV_OC_OpEntryPoint, SPIRV_OC_OpExecutionMode, - SPIRV_OC_OpCapability, SPIRV_OC_OpTypeVoid, SPIRV_OC_OpTypeBool, SPIRV_OC_OpTypeInt, - SPIRV_OC_OpTypeFloat, SPIRV_OC_OpTypeVector, SPIRV_OC_OpTypeMatrix, - SPIRV_OC_OpTypeImage, SPIRV_OC_OpTypeSampledImage, SPIRV_OC_OpTypeArray, - SPIRV_OC_OpTypeRuntimeArray, SPIRV_OC_OpTypeStruct, SPIRV_OC_OpTypePointer, - SPIRV_OC_OpTypeFunction, SPIRV_OC_OpTypeForwardPointer, SPIRV_OC_OpConstantTrue, - SPIRV_OC_OpConstantFalse, SPIRV_OC_OpConstant, SPIRV_OC_OpConstantComposite, - SPIRV_OC_OpConstantNull, SPIRV_OC_OpSpecConstantTrue, SPIRV_OC_OpSpecConstantFalse, - SPIRV_OC_OpSpecConstant, SPIRV_OC_OpSpecConstantComposite, SPIRV_OC_OpSpecConstantOp, - SPIRV_OC_OpFunction, SPIRV_OC_OpFunctionParameter, SPIRV_OC_OpFunctionEnd, - SPIRV_OC_OpFunctionCall, SPIRV_OC_OpVariable, SPIRV_OC_OpLoad, SPIRV_OC_OpStore, - SPIRV_OC_OpCopyMemory, SPIRV_OC_OpAccessChain, SPIRV_OC_OpPtrAccessChain, - SPIRV_OC_OpInBoundsPtrAccessChain, SPIRV_OC_OpDecorate, SPIRV_OC_OpMemberDecorate, - SPIRV_OC_OpVectorExtractDynamic, SPIRV_OC_OpVectorInsertDynamic, - SPIRV_OC_OpVectorShuffle, SPIRV_OC_OpCompositeConstruct, SPIRV_OC_OpCompositeExtract, + SPIRV_OC_OpCapability, SPIRV_OC_OpTypeVoid, SPIRV_OC_OpTypeBool, + SPIRV_OC_OpTypeInt, SPIRV_OC_OpTypeFloat, SPIRV_OC_OpTypeVector, + SPIRV_OC_OpTypeMatrix, SPIRV_OC_OpTypeImage, SPIRV_OC_OpTypeSampledImage, + SPIRV_OC_OpTypeArray, SPIRV_OC_OpTypeRuntimeArray, SPIRV_OC_OpTypeStruct, + SPIRV_OC_OpTypePointer, SPIRV_OC_OpTypeFunction, SPIRV_OC_OpTypeForwardPointer, + SPIRV_OC_OpConstantTrue, SPIRV_OC_OpConstantFalse, SPIRV_OC_OpConstant, + SPIRV_OC_OpConstantComposite, SPIRV_OC_OpConstantNull, + SPIRV_OC_OpSpecConstantTrue, SPIRV_OC_OpSpecConstantFalse, + SPIRV_OC_OpSpecConstant, SPIRV_OC_OpSpecConstantComposite, + SPIRV_OC_OpSpecConstantOp, SPIRV_OC_OpFunction, SPIRV_OC_OpFunctionParameter, + SPIRV_OC_OpFunctionEnd, SPIRV_OC_OpFunctionCall, SPIRV_OC_OpVariable, + SPIRV_OC_OpLoad, SPIRV_OC_OpStore, SPIRV_OC_OpCopyMemory, + SPIRV_OC_OpAccessChain, SPIRV_OC_OpPtrAccessChain, + SPIRV_OC_OpInBoundsPtrAccessChain, SPIRV_OC_OpDecorate, + SPIRV_OC_OpMemberDecorate, SPIRV_OC_OpVectorExtractDynamic, + SPIRV_OC_OpVectorInsertDynamic, SPIRV_OC_OpVectorShuffle, + SPIRV_OC_OpCompositeConstruct, SPIRV_OC_OpCompositeExtract, SPIRV_OC_OpCompositeInsert, SPIRV_OC_OpTranspose, SPIRV_OC_OpImageDrefGather, SPIRV_OC_OpImage, SPIRV_OC_OpImageQuerySize, SPIRV_OC_OpConvertFToU, SPIRV_OC_OpConvertFToS, SPIRV_OC_OpConvertSToF, SPIRV_OC_OpConvertUToF, - SPIRV_OC_OpUConvert, SPIRV_OC_OpSConvert, SPIRV_OC_OpFConvert, SPIRV_OC_OpPtrCastToGeneric, - SPIRV_OC_OpGenericCastToPtr, SPIRV_OC_OpGenericCastToPtrExplicit, SPIRV_OC_OpBitcast, - SPIRV_OC_OpSNegate, SPIRV_OC_OpFNegate, SPIRV_OC_OpIAdd, SPIRV_OC_OpFAdd, - SPIRV_OC_OpISub, SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv, - SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem, SPIRV_OC_OpSMod, - SPIRV_OC_OpFRem, SPIRV_OC_OpFMod, SPIRV_OC_OpVectorTimesScalar, - SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpIAddCarry, - SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, SPIRV_OC_OpIsNan, - SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, SPIRV_OC_OpLogicalEqual, - SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr, SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot, - SPIRV_OC_OpSelect, SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan, - SPIRV_OC_OpSGreaterThan, SPIRV_OC_OpUGreaterThanEqual, SPIRV_OC_OpSGreaterThanEqual, - SPIRV_OC_OpULessThan, SPIRV_OC_OpSLessThan, SPIRV_OC_OpULessThanEqual, - SPIRV_OC_OpSLessThanEqual, SPIRV_OC_OpFOrdEqual, SPIRV_OC_OpFUnordEqual, - SPIRV_OC_OpFOrdNotEqual, SPIRV_OC_OpFUnordNotEqual, SPIRV_OC_OpFOrdLessThan, - SPIRV_OC_OpFUnordLessThan, SPIRV_OC_OpFOrdGreaterThan, SPIRV_OC_OpFUnordGreaterThan, - SPIRV_OC_OpFOrdLessThanEqual, SPIRV_OC_OpFUnordLessThanEqual, - SPIRV_OC_OpFOrdGreaterThanEqual, SPIRV_OC_OpFUnordGreaterThanEqual, - SPIRV_OC_OpShiftRightLogical, SPIRV_OC_OpShiftRightArithmetic, - SPIRV_OC_OpShiftLeftLogical, SPIRV_OC_OpBitwiseOr, SPIRV_OC_OpBitwiseXor, - SPIRV_OC_OpBitwiseAnd, SPIRV_OC_OpNot, SPIRV_OC_OpBitFieldInsert, - SPIRV_OC_OpBitFieldSExtract, SPIRV_OC_OpBitFieldUExtract, SPIRV_OC_OpBitReverse, - SPIRV_OC_OpBitCount, SPIRV_OC_OpControlBarrier, SPIRV_OC_OpMemoryBarrier, - SPIRV_OC_OpAtomicExchange, SPIRV_OC_OpAtomicCompareExchange, - SPIRV_OC_OpAtomicCompareExchangeWeak, SPIRV_OC_OpAtomicIIncrement, - SPIRV_OC_OpAtomicIDecrement, SPIRV_OC_OpAtomicIAdd, SPIRV_OC_OpAtomicISub, - SPIRV_OC_OpAtomicSMin, SPIRV_OC_OpAtomicUMin, SPIRV_OC_OpAtomicSMax, - SPIRV_OC_OpAtomicUMax, SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor, - SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge, SPIRV_OC_OpLabel, - SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional, SPIRV_OC_OpReturn, - SPIRV_OC_OpReturnValue, SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, + SPIRV_OC_OpUConvert, SPIRV_OC_OpSConvert, SPIRV_OC_OpFConvert, + SPIRV_OC_OpPtrCastToGeneric, SPIRV_OC_OpGenericCastToPtr, + SPIRV_OC_OpGenericCastToPtrExplicit, SPIRV_OC_OpBitcast, SPIRV_OC_OpSNegate, + SPIRV_OC_OpFNegate, SPIRV_OC_OpIAdd, SPIRV_OC_OpFAdd, SPIRV_OC_OpISub, + SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv, + SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem, + SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod, + SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, + SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow, + SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, SPIRV_OC_OpIsNan, + SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, + SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr, + SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot, SPIRV_OC_OpSelect, + SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan, + SPIRV_OC_OpSGreaterThan, SPIRV_OC_OpUGreaterThanEqual, + SPIRV_OC_OpSGreaterThanEqual, SPIRV_OC_OpULessThan, SPIRV_OC_OpSLessThan, + SPIRV_OC_OpULessThanEqual, SPIRV_OC_OpSLessThanEqual, SPIRV_OC_OpFOrdEqual, + SPIRV_OC_OpFUnordEqual, SPIRV_OC_OpFOrdNotEqual, SPIRV_OC_OpFUnordNotEqual, + SPIRV_OC_OpFOrdLessThan, SPIRV_OC_OpFUnordLessThan, SPIRV_OC_OpFOrdGreaterThan, + SPIRV_OC_OpFUnordGreaterThan, SPIRV_OC_OpFOrdLessThanEqual, + SPIRV_OC_OpFUnordLessThanEqual, SPIRV_OC_OpFOrdGreaterThanEqual, + SPIRV_OC_OpFUnordGreaterThanEqual, SPIRV_OC_OpShiftRightLogical, + SPIRV_OC_OpShiftRightArithmetic, SPIRV_OC_OpShiftLeftLogical, + SPIRV_OC_OpBitwiseOr, SPIRV_OC_OpBitwiseXor, SPIRV_OC_OpBitwiseAnd, + SPIRV_OC_OpNot, SPIRV_OC_OpBitFieldInsert, SPIRV_OC_OpBitFieldSExtract, + SPIRV_OC_OpBitFieldUExtract, SPIRV_OC_OpBitReverse, SPIRV_OC_OpBitCount, + SPIRV_OC_OpControlBarrier, SPIRV_OC_OpMemoryBarrier, SPIRV_OC_OpAtomicExchange, + SPIRV_OC_OpAtomicCompareExchange, SPIRV_OC_OpAtomicCompareExchangeWeak, + SPIRV_OC_OpAtomicIIncrement, SPIRV_OC_OpAtomicIDecrement, + SPIRV_OC_OpAtomicIAdd, SPIRV_OC_OpAtomicISub, SPIRV_OC_OpAtomicSMin, + SPIRV_OC_OpAtomicUMin, SPIRV_OC_OpAtomicSMax, SPIRV_OC_OpAtomicUMax, + SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor, + SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge, + SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional, + SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue, SPIRV_OC_OpUnreachable, + SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd, SPIRV_OC_OpGroupFAdd, + SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin, SPIRV_OC_OpGroupSMin, + SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax, SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed, SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot, SPIRV_OC_OpGroupNonUniformShuffle, SPIRV_OC_OpGroupNonUniformShuffleXor, @@ -4430,7 +4451,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV, SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL, - SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, + SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR, + SPIRV_OC_OpGroupFMulKHR, SPIRV_OC_OpTypeJointMatrixINTEL, SPIRV_OC_OpJointMatrixLoadINTEL, SPIRV_OC_OpJointMatrixStoreINTEL, SPIRV_OC_OpJointMatrixMadINTEL, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td index 8c43107a0dc9..0d2f416947c5 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td @@ -17,6 +17,69 @@ // ----- +def SPIRV_GroupFMulKHROp : SPIRV_KhrVendorOp<"GroupFMul", [Pure, + AllTypesMatch<["x", "result"]>]> { + let summary = [{ + A floating-point multiplication group operation specified for all values of + 'X' specified by invocations in the group. + }]; + + let description = [{ + Behavior is undefined if not all invocations of this module within + 'Execution' reach this point of execution. + + Behavior is undefined unless all invocations within 'Execution' execute the + same dynamic instance of this instruction. + + 'Result Type' must be a scalar or vector of floating-point type. + + 'Execution' is a Scope. It must be either Workgroup or Subgroup. + + The identity I for 'Operation' is 1. + + The type of 'X' must be the same as 'Result Type'. + + + + + ``` + scope ::= `"Workgroup"` | `"Subgroup"` + operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` + op ::= ssa-id `=` `spirv.KHR.GroupFMul` scope operation ssa-use + `:` float-type + ```mlir + + #### Example: + + ``` + %0 = spirv.KHR.GroupFMul %value : f32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_GroupUniformArithmeticKHR]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + SPIRV_GroupOperationAttr:$group_operation, + SPIRV_ScalarOrVectorOf:$x + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + + let assemblyFormat = [{ + $execution_scope $group_operation operands attr-dict `:` type($x) + }]; +} + +// ----- + def SPIRV_GroupBroadcastOp : SPIRV_Op<"GroupBroadcast", [Pure, AllTypesMatch<["value", "result"]>]> { @@ -93,56 +156,564 @@ def SPIRV_GroupBroadcastOp : SPIRV_Op<"GroupBroadcast", // ----- -def SPIRV_KHRSubgroupBallotOp : SPIRV_KhrVendorOp<"SubgroupBallot", []> { - let summary = "See extension SPV_KHR_shader_ballot"; +def SPIRV_GroupFAddOp : SPIRV_Op<"GroupFAdd", [Pure, + AllTypesMatch<["x", "result"]>]> { + let summary = [{ + A floating-point add group operation specified for all values of X + specified by invocations in the group. + }]; let description = [{ - Computes a bitfield value combining the Predicate value from all invocations - in the current Subgroup that execute the same dynamic instance of this - instruction. The bit is set to one if the corresponding invocation is active - and the predicate is evaluated to true; otherwise, it is set to zero. + Behavior is undefined if not all invocations of this module within + Execution reach this point of execution. - Predicate must be a Boolean type. + Behavior is undefined unless all invocations within Execution execute + the same dynamic instance of this instruction. - Result Type must be a 4 component vector of 32 bit integer types. + Result Type must be a scalar or vector of floating-point type. - Result is a set of bitfields where the first invocation is represented in bit - 0 of the first vector component and the last (up to SubgroupSize) is the - higher bit number of the last bitmask needed to represent all bits of the - subgroup invocations. + Execution is a Scope. It must be either Workgroup or Subgroup. + + The identity I for Operation is 0. + + The type of X must be the same as Result Type. ``` - subgroup-ballot-op ::= ssa-id `=` `spirv.KHR.SubgroupBallot` - ssa-use `:` `vector` `<` 4 `x` `i32` `>` - ``` + scope ::= `"Workgroup"` | `"Subgroup"` + operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` + op ::= ssa-id `=` `spirv.GroupFAdd` scope operation ssa-use + `:` float-type + ```mlir #### Example: - ```mlir - %0 = spirv.KHR.SubgroupBallot %predicate : vector<4xi32> + ``` + %0 = spirv.GroupFAdd %value : f32 ``` }]; let availability = [ MinVersion, MaxVersion, - Extension<[SPV_KHR_shader_ballot]>, - Capability<[SPIRV_C_SubgroupBallotKHR]> + Extension<[]>, + Capability<[SPIRV_C_Groups]> ]; let arguments = (ins - SPIRV_Bool:$predicate + SPIRV_ScopeAttr:$execution_scope, + SPIRV_GroupOperationAttr:$group_operation, + SPIRV_ScalarOrVectorOf:$x ); let results = (outs - SPIRV_Int32Vec4:$result + SPIRV_ScalarOrVectorOf:$result ); - let hasVerifier = 0; + let assemblyFormat = [{ + $execution_scope $group_operation operands attr-dict `:` type($x) + }]; +} - let assemblyFormat = "$predicate attr-dict `:` type($result)"; +// ----- + +def SPIRV_GroupFMaxOp : SPIRV_Op<"GroupFMax", [Pure, + AllTypesMatch<["x", "result"]>]> { + let summary = [{ + A floating-point maximum group operation specified for all values of X + specified by invocations in the group. + }]; + + let description = [{ + Behavior is undefined if not all invocations of this module within + Execution reach this point of execution. + + Behavior is undefined unless all invocations within Execution execute + the same dynamic instance of this instruction. + + Result Type must be a scalar or vector of floating-point type. + + Execution is a Scope. It must be either Workgroup or Subgroup. + + The identity I for Operation is -INF. + + The type of X must be the same as Result Type. + + + + ``` + scope ::= `"Workgroup"` | `"Subgroup"` + operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` + op ::= ssa-id `=` `spirv.GroupFMax` scope operation ssa-use + `:` float-type + ```mlir + + #### Example: + + ``` + %0 = spirv.GroupFMax %value : f32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_Groups]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + SPIRV_GroupOperationAttr:$group_operation, + SPIRV_ScalarOrVectorOf:$x + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + + let assemblyFormat = [{ + $execution_scope $group_operation operands attr-dict `:` type($x) + }]; +} + +// ----- + +def SPIRV_GroupFMinOp : SPIRV_Op<"GroupFMin", [Pure, + AllTypesMatch<["x", "result"]>]> { + let summary = [{ + A floating-point minimum group operation specified for all values of X + specified by invocations in the group. + }]; + + let description = [{ + Behavior is undefined if not all invocations of this module within + Execution reach this point of execution. + + Behavior is undefined unless all invocations within Execution execute + the same dynamic instance of this instruction. + + Result Type must be a scalar or vector of floating-point type. + + Execution is a Scope. It must be either Workgroup or Subgroup. + + The identity I for Operation is +INF. + + The type of X must be the same as Result Type. + + + + ``` + scope ::= `"Workgroup"` | `"Subgroup"` + operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` + op ::= ssa-id `=` `spirv.GroupFMin` scope operation ssa-use + `:` float-type + ```mlir + + #### Example: + + ``` + %0 = spirv.GroupFMin %value : f32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_Groups]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + SPIRV_GroupOperationAttr:$group_operation, + SPIRV_ScalarOrVectorOf:$x + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + + let assemblyFormat = [{ + $execution_scope $group_operation operands attr-dict `:` type($x) + }]; +} + +// ----- + +def SPIRV_GroupIAddOp : SPIRV_Op<"GroupIAdd", [Pure, + AllTypesMatch<["x", "result"]>]> { + let summary = [{ + An integer add group operation specified for all values of X specified + by invocations in the group. + }]; + + let description = [{ + Behavior is undefined if not all invocations of this module within + Execution reach this point of execution. + + Behavior is undefined unless all invocations within Execution execute + the same dynamic instance of this instruction. + + Result Type must be a scalar or vector of integer type. + + Execution is a Scope. It must be either Workgroup or Subgroup. + + The identity I for Operation is 0. + + The type of X must be the same as Result Type. + + + + ``` + scope ::= `"Workgroup"` | `"Subgroup"` + operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` + op ::= ssa-id `=` `spirv.GroupIAdd` scope operation ssa-use + `:` integer-type + ```mlir + + #### Example: + + ``` + %0 = spirv.GroupIAdd %value : i32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_Groups]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + SPIRV_GroupOperationAttr:$group_operation, + SPIRV_ScalarOrVectorOf:$x + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + + let assemblyFormat = [{ + $execution_scope $group_operation operands attr-dict `:` type($x) + }]; +} + +// ----- + +def SPIRV_GroupIMulKHROp : SPIRV_KhrVendorOp<"GroupIMul", [Pure, + AllTypesMatch<["x", "result"]>]> { + let summary = [{ + An integer multiplication group operation specified for all values of 'X' + specified by invocations in the group. + }]; + + let description = [{ + Behavior is undefined if not all invocations of this module within + 'Execution' reach this point of execution. + + Behavior is undefined unless all invocations within 'Execution' execute the + same dynamic instance of this instruction. + + 'Result Type' must be a scalar or vector of integer type. + + 'Execution' is a Scope. It must be either Workgroup or Subgroup. + + The identity I for 'Operation' is 1. + + The type of 'X' must be the same as 'Result Type'. + + + + + ``` + scope ::= `"Workgroup"` | `"Subgroup"` + operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` + op ::= ssa-id `=` `spirv.KHR.GroupIMul` scope operation ssa-use + `:` integer-type + ```mlir + + #### Example: + + ``` + %0 = spirv.KHR.GroupIMul %value : i32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_GroupUniformArithmeticKHR]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + SPIRV_GroupOperationAttr:$group_operation, + SPIRV_ScalarOrVectorOf:$x + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + + let assemblyFormat = [{ + $execution_scope $group_operation operands attr-dict `:` type($x) + }]; +} + +// ----- + +def SPIRV_GroupSMaxOp : SPIRV_Op<"GroupSMax", [Pure, + AllTypesMatch<["x", "result"]>]> { + let summary = [{ + A signed integer maximum group operation specified for all values of X + specified by invocations in the group. + }]; + + let description = [{ + Behavior is undefined if not all invocations of this module within + Execution reach this point of execution. + + Behavior is undefined unless all invocations within Execution execute + the same dynamic instance of this instruction. + + Result Type must be a scalar or vector of integer type. + + Execution is a Scope. It must be either Workgroup or Subgroup. + + The identity I for Operation is INT_MIN when X is 32 bits wide and + LONG_MIN when X is 64 bits wide. + + The type of X must be the same as Result Type. + + + + ``` + scope ::= `"Workgroup"` | `"Subgroup"` + operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` + op ::= ssa-id `=` `spirv.GroupSMax` scope operation ssa-use + `:` integer-type + ```mlir + + #### Example: + + ``` + %0 = spirv.GroupSMax %value : i32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_Groups]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + SPIRV_GroupOperationAttr:$group_operation, + SPIRV_ScalarOrVectorOf:$x + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + + let assemblyFormat = [{ + $execution_scope $group_operation operands attr-dict `:` type($x) + }]; +} + +// ----- + +def SPIRV_GroupSMinOp : SPIRV_Op<"GroupSMin", [Pure, + AllTypesMatch<["x", "result"]>]> { + let summary = [{ + A signed integer minimum group operation specified for all values of X + specified by invocations in the group. + }]; + + let description = [{ + Behavior is undefined if not all invocations of this module within + Execution reach this point of execution. + + Behavior is undefined unless all invocations within Execution execute + the same dynamic instance of this instruction. + + Result Type must be a scalar or vector of integer type. + + Execution is a Scope. It must be either Workgroup or Subgroup. + + The identity I for Operation is INT_MAX when X is 32 bits wide and + LONG_MAX when X is 64 bits wide. + + The type of X must be the same as Result Type. + + + + ``` + scope ::= `"Workgroup"` | `"Subgroup"` + operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` + op ::= ssa-id `=` `spirv.GroupSMin` scope operation ssa-use + `:` integer-type + ```mlir + + #### Example: + + ``` + %0 = spirv.GroupSMin %value : i32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_Groups]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + SPIRV_GroupOperationAttr:$group_operation, + SPIRV_ScalarOrVectorOf:$x + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + + let assemblyFormat = [{ + $execution_scope $group_operation operands attr-dict `:` type($x) + }]; +} + +// ----- + +def SPIRV_GroupUMaxOp : SPIRV_Op<"GroupUMax", [Pure, + AllTypesMatch<["x", "result"]>]> { + let summary = [{ + An unsigned integer maximum group operation specified for all values of + X specified by invocations in the group. + }]; + + let description = [{ + Behavior is undefined if not all invocations of this module within + Execution reach this point of execution. + + Behavior is undefined unless all invocations within Execution execute + the same dynamic instance of this instruction. + + Result Type must be a scalar or vector of integer type. + + Execution is a Scope. It must be either Workgroup or Subgroup. + + The identity I for Operation is 0. + + The type of X must be the same as Result Type. + + + + ``` + scope ::= `"Workgroup"` | `"Subgroup"` + operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` + op ::= ssa-id `=` `spirv.GroupUMax` scope operation ssa-use + `:` integer-type + ```mlir + + #### Example: + + ``` + %0 = spirv.GroupUMax %value : i32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_Groups]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + SPIRV_GroupOperationAttr:$group_operation, + SPIRV_ScalarOrVectorOf:$x + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + + let assemblyFormat = [{ + $execution_scope $group_operation operands attr-dict `:` type($x) + }]; +} + +// ----- + +def SPIRV_GroupUMinOp : SPIRV_Op<"GroupUMin", [Pure, + AllTypesMatch<["x", "result"]>]> { + let summary = [{ + An unsigned integer minimum group operation specified for all values of + X specified by invocations in the group. + }]; + + let description = [{ + Behavior is undefined if not all invocations of this module within + Execution reach this point of execution. + + Behavior is undefined unless all invocations within Execution execute + the same dynamic instance of this instruction. + + Result Type must be a scalar or vector of integer type. + + Execution is a Scope. It must be either Workgroup or Subgroup. + + The identity I for Operation is UINT_MAX when X is 32 bits wide and + ULONG_MAX when X is 64 bits wide. + + The type of X must be the same as Result Type. + + + + ``` + scope ::= `"Workgroup"` | `"Subgroup"` + operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` + op ::= ssa-id `=` `spirv.GroupUMin` scope operation ssa-use + `:` integer-type + ```mlir + + #### Example: + + ``` + %0 = spirv.GroupUMin %value : i32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_Groups]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + SPIRV_GroupOperationAttr:$group_operation, + SPIRV_ScalarOrVectorOf:$x + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + + let assemblyFormat = [{ + $execution_scope $group_operation operands attr-dict `:` type($x) + }]; } // ----- @@ -247,4 +818,58 @@ def SPIRV_INTELSubgroupBlockWriteOp : SPIRV_IntelVendorOp<"SubgroupBlockWrite", // ----- +def SPIRV_KHRSubgroupBallotOp : SPIRV_KhrVendorOp<"SubgroupBallot", []> { + let summary = "See extension SPV_KHR_shader_ballot"; + + let description = [{ + Computes a bitfield value combining the Predicate value from all invocations + in the current Subgroup that execute the same dynamic instance of this + instruction. The bit is set to one if the corresponding invocation is active + and the predicate is evaluated to true; otherwise, it is set to zero. + + Predicate must be a Boolean type. + + Result Type must be a 4 component vector of 32 bit integer types. + + Result is a set of bitfields where the first invocation is represented in bit + 0 of the first vector component and the last (up to SubgroupSize) is the + higher bit number of the last bitmask needed to represent all bits of the + subgroup invocations. + + + + ``` + subgroup-ballot-op ::= ssa-id `=` `spirv.KHR.SubgroupBallot` + ssa-use `:` `vector` `<` 4 `x` `i32` `>` + ``` + + #### Example: + + ```mlir + %0 = spirv.KHR.SubgroupBallot %predicate : vector<4xi32> + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_KHR_shader_ballot]>, + Capability<[SPIRV_C_SubgroupBallotKHR]> + ]; + + let arguments = (ins + SPIRV_Bool:$predicate + ); + + let results = (outs + SPIRV_Int32Vec4:$result + ); + + let hasVerifier = 0; + + let assemblyFormat = "$predicate attr-dict `:` type($result)"; +} + +// ----- + #endif // MLIR_DIALECT_SPIRV_IR_GROUP_OPS diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 606562292ebc..9a509796e5c1 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -4770,6 +4770,39 @@ LogicalResult spirv::VectorTimesScalarOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// Group ops +//===----------------------------------------------------------------------===// + +template +static LogicalResult verifyGroupOp(Op op) { + spirv::Scope scope = op.getExecutionScope(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); + + return success(); +} + +LogicalResult spirv::GroupIAddOp::verify() { return verifyGroupOp(*this); } + +LogicalResult spirv::GroupFAddOp::verify() { return verifyGroupOp(*this); } + +LogicalResult spirv::GroupFMinOp::verify() { return verifyGroupOp(*this); } + +LogicalResult spirv::GroupUMinOp::verify() { return verifyGroupOp(*this); } + +LogicalResult spirv::GroupSMinOp::verify() { return verifyGroupOp(*this); } + +LogicalResult spirv::GroupFMaxOp::verify() { return verifyGroupOp(*this); } + +LogicalResult spirv::GroupUMaxOp::verify() { return verifyGroupOp(*this); } + +LogicalResult spirv::GroupSMaxOp::verify() { return verifyGroupOp(*this); } + +LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); } + +LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); } + // TableGen'erated operation interfaces for querying versions, extensions, and // capabilities. #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc" diff --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir index 061d3f0d798b..741081a37d8a 100644 --- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir @@ -112,3 +112,87 @@ func.func @subgroup_block_write_intel_vector(%ptr : !spirv.ptr return } + +// ----- + +//===----------------------------------------------------------------------===// +// Group ops +//===----------------------------------------------------------------------===// + +func.func @group_iadd(%value: i32) -> i32 { + // CHECK: spirv.GroupIAdd %{{.*}} : i32 + %0 = spirv.GroupIAdd %value : i32 + return %0: i32 +} + +// ----- + +func.func @group_fadd(%value: f32) -> f32 { + // CHECK: spirv.GroupFAdd %{{.*}} : f32 + %0 = spirv.GroupFAdd %value : f32 + return %0: f32 +} + +// ----- + +func.func @group_fmin(%value: f32) -> f32 { + // CHECK: spirv.GroupFMin %{{.*}} : f32 + %0 = spirv.GroupFMin %value : f32 + return %0: f32 +} + +// ----- + +func.func @group_umin(%value: i32) -> i32 { + // CHECK: spirv.GroupUMin %{{.*}} : i32 + %0 = spirv.GroupUMin %value : i32 + return %0: i32 +} + +// ----- + +func.func @group_smin(%value: i32) -> i32 { + // CHECK: spirv.GroupSMin %{{.*}} : i32 + %0 = spirv.GroupSMin %value : i32 + return %0: i32 +} + +// ----- + +func.func @group_fmax(%value: f32) -> f32 { + // CHECK: spirv.GroupFMax %{{.*}} : f32 + %0 = spirv.GroupFMax %value : f32 + return %0: f32 +} + +// ----- + +func.func @group_umax(%value: i32) -> i32 { + // CHECK: spirv.GroupUMax %{{.*}} : i32 + %0 = spirv.GroupUMax %value : i32 + return %0: i32 +} + +// ----- + +func.func @group_smax(%value: i32) -> i32 { + // CHECK: spirv.GroupSMax %{{.*}} : i32 + %0 = spirv.GroupSMax %value : i32 + return %0: i32 +} + +// ----- + +func.func @group_imul(%value: i32) -> i32 { + // CHECK: spirv.KHR.GroupIMul %{{.*}} : i32 + %0 = spirv.KHR.GroupIMul %value : i32 + return %0: i32 +} + +// ----- + +func.func @group_fmul(%value: f32) -> f32 { + // CHECK: spirv.KHR.GroupFMul %{{.*}} : f32 + %0 = spirv.KHR.GroupFMul %value : f32 + return %0: f32 +} diff --git a/mlir/test/Target/SPIRV/group-ops.mlir b/mlir/test/Target/SPIRV/group-ops.mlir index 5a17e89463f7..dc07f8c8ef61 100644 --- a/mlir/test/Target/SPIRV/group-ops.mlir +++ b/mlir/test/Target/SPIRV/group-ops.mlir @@ -43,4 +43,65 @@ spirv.module Logical GLSL450 requires #spirv.vce { spirv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : vector<3xi32> spirv.Return } + // CHECK-LABEL: @group_iadd + spirv.func @group_iadd(%value: i32) -> i32 "None" { + // CHECK: spirv.GroupIAdd %{{.*}} : i32 + %0 = spirv.GroupIAdd %value : i32 + spirv.ReturnValue %0: i32 + } + // CHECK-LABEL: @group_fadd + spirv.func @group_fadd(%value: f32) -> f32 "None" { + // CHECK: spirv.GroupFAdd %{{.*}} : f32 + %0 = spirv.GroupFAdd %value : f32 + spirv.ReturnValue %0: f32 + } + // CHECK-LABEL: @group_fmin + spirv.func @group_fmin(%value: f32) -> f32 "None" { + // CHECK: spirv.GroupFMin %{{.*}} : f32 + %0 = spirv.GroupFMin %value : f32 + spirv.ReturnValue %0: f32 + } + // CHECK-LABEL: @group_umin + spirv.func @group_umin(%value: i32) -> i32 "None" { + // CHECK: spirv.GroupUMin %{{.*}} : i32 + %0 = spirv.GroupUMin %value : i32 + spirv.ReturnValue %0: i32 + } + // CHECK-LABEL: @group_smin + spirv.func @group_smin(%value: i32) -> i32 "None" { + // CHECK: spirv.GroupSMin %{{.*}} : i32 + %0 = spirv.GroupSMin %value : i32 + spirv.ReturnValue %0: i32 + } + // CHECK-LABEL: @group_fmax + spirv.func @group_fmax(%value: f32) -> f32 "None" { + // CHECK: spirv.GroupFMax %{{.*}} : f32 + %0 = spirv.GroupFMax %value : f32 + spirv.ReturnValue %0: f32 + } + // CHECK-LABEL: @group_umax + spirv.func @group_umax(%value: i32) -> i32 "None" { + // CHECK: spirv.GroupUMax %{{.*}} : i32 + %0 = spirv.GroupUMax %value : i32 + spirv.ReturnValue %0: i32 + } + // CHECK-LABEL: @group_smax + spirv.func @group_smax(%value: i32) -> i32 "None" { + // CHECK: spirv.GroupSMax %{{.*}} : i32 + %0 = spirv.GroupSMax %value : i32 + spirv.ReturnValue %0: i32 + } + // CHECK-LABEL: @group_imul + spirv.func @group_imul(%value: i32) -> i32 "None" { + // CHECK: spirv.KHR.GroupIMul %{{.*}} : i32 + %0 = spirv.KHR.GroupIMul %value : i32 + spirv.ReturnValue %0: i32 + } + // CHECK-LABEL: @group_fmul + spirv.func @group_fmul(%value: f32) -> f32 "None" { + // CHECK: spirv.KHR.GroupFMul %{{.*}} : f32 + %0 = spirv.KHR.GroupFMul %value : f32 + spirv.ReturnValue %0: f32 + } + } diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index ce9e8c66c63c..94ab267ae311 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -568,7 +568,8 @@ def update_td_opcodes(path, instructions, filter_list): assert len(content) == 3 # Extend opcode list with existing list - existing_opcodes = [k[11:] for k in re.findall('def SPIRV_OC_\w+', content[1])] + prefix = 'def SPIRV_OC_' + existing_opcodes = [k[len(prefix):] for k in re.findall(prefix + '\w+', content[1])] filter_list.extend(existing_opcodes) filter_list = list(set(filter_list))