[mlir][spirv] Support coop matrix in spirv.CompositeConstruct (#66399)

Also improve the documentation (code and website).
This commit is contained in:
Jakub Kuderski 2023-09-14 16:57:59 -04:00 committed by GitHub
parent 571e4f233b
commit 12175bcbce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 24 deletions

View File

@ -53,7 +53,15 @@ def SPIRV_CompositeConstructOp : SPIRV_Op<"CompositeConstruct", [Pure]> {
#### Example:
```mlir
%0 = spirv.CompositeConstruct %1, %2, %3 : vector<3xf32>
%a = spirv.CompositeConstruct %1, %2, %3 : vector<3xf32>
%b = spirv.CompositeConstruct %a, %1 : (vector<3xf32>, f32) -> vector<4xf32>
%c = spirv.CompositeConstruct %1 :
(f32) -> !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
%d = spirv.CompositeConstruct %a, %4, %5 :
(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) ->
!spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
```
}];

View File

@ -29,6 +29,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
@ -363,31 +364,35 @@ LogicalResult spirv::AddressOfOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::CompositeConstructOp::verify() {
auto cType = llvm::cast<spirv::CompositeType>(getType());
operand_range constituents = this->getConstituents();
if (auto coopType = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(cType)) {
// There are 4 cases with varying verification rules:
// 1. Cooperative Matrices (1 constituent)
// 2. Structs (1 constituent for each member)
// 3. Arrays (1 constituent for each array element)
// 4. Vectors (1 constituent (sub-)element for each vector element)
auto coopElementType =
llvm::TypeSwitch<Type, Type>(getType())
.Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
spirv::JointMatrixINTELType>(
[](auto coopType) { return coopType.getElementType(); })
.Default([](Type) { return nullptr; });
// Case 1. -- matrices.
if (coopElementType) {
if (constituents.size() != 1)
return emitOpError("has incorrect number of operands: expected ")
<< "1, but provided " << constituents.size();
if (coopType.getElementType() != constituents.front().getType())
if (coopElementType != constituents.front().getType())
return emitOpError("operand type mismatch: expected operand type ")
<< coopType.getElementType() << ", but provided "
<< constituents.front().getType();
return success();
}
if (auto jointType = llvm::dyn_cast<spirv::JointMatrixINTELType>(cType)) {
if (constituents.size() != 1)
return emitOpError("has incorrect number of operands: expected ")
<< "1, but provided " << constituents.size();
if (jointType.getElementType() != constituents.front().getType())
return emitOpError("operand type mismatch: expected operand type ")
<< jointType.getElementType() << ", but provided "
<< coopElementType << ", but provided "
<< constituents.front().getType();
return success();
}
// Case 2./3./4. -- number of constituents matches the number of elements.
auto cType = llvm::cast<spirv::CompositeType>(getType());
if (constituents.size() == cType.getNumElements()) {
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
if (constituents[index].getType() != cType.getElementType(index)) {
@ -399,8 +404,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
return success();
}
// If not constructing a cooperative matrix type, then we must be constructing
// a vector type.
// Case 4. -- check that all constituents add up tp the expected vector type.
auto resultType = llvm::dyn_cast<VectorType>(cType);
if (!resultType)
return emitOpError(

View File

@ -4,22 +4,20 @@
// spirv.CompositeConstruct
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @composite_construct_vector
func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
// CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
%0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
return %0: vector<3xf32>
}
// -----
// CHECK-LABEL: func @composite_construct_struct
func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
// CHECK: spirv.CompositeConstruct
%0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
return %0: !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
}
// -----
// CHECK-LABEL: func @composite_construct_mixed_scalar_vector
func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
// CHECK: spirv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32>
@ -27,9 +25,15 @@ func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2
return %0: vector<4xf32>
}
// -----
// CHECK-LABEL: func @composite_construct_coopmatrix_khr
func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
// CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
%0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
}
func.func @composite_construct_NV.coopmatrix(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
// CHECK-LABEL: func @composite_construct_coopmatrix_nv
func.func @composite_construct_coopmatrix_nv(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
// CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
%0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
@ -53,6 +57,24 @@ func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg
// -----
func.func @composite_construct_khr_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) ->
!spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
// expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
%0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
}
// -----
func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32) ->
!spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> {
// expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
%0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
}
// -----
func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
// expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
%0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>