[mlir][spirv] Add support for sampled image type

co-authored-by: Alan Liu <alanliu.yf@gmail.com>

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D96169
This commit is contained in:
Weiwei Li 2021-02-09 13:47:12 -05:00 committed by Lei Zhang
parent 7e7cfce0b6
commit 2ef24139fc
11 changed files with 207 additions and 14 deletions

View File

@ -274,6 +274,7 @@ spirv-type ::= array-type
| image-type
| pointer-type
| runtime-array-type
| sampled-image-type
| struct-type
```
@ -363,6 +364,22 @@ For example,
!spv.rtarray<i32, stride=4>
!spv.rtarray<vector<4 x f32>>
```
### Sampled image type
This corresponds to SPIR-V [sampled image type][SampledImageType]. Its syntax is
```
sampled-image-type ::= `!spv.sampled_image<!spv.image<` element-type `,` dim `,` depth-info `,`
arrayed-info `,` sampling-info `,`
sampler-use-info `,` format `>>`
```
For example,
```mlir
!spv.sampled_image<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>
!spv.sampled_image<!spv.image<i32, Rect, DepthUnknown, Arrayed, MultiSampled, NeedSampler, R8ui>>
```
### Struct type
@ -1382,6 +1399,7 @@ dialect.
[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage
[PointerType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypePointer
[RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray
[SampledImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeSampledImage
[MlirDialectConversion]: ../DialectConversion.md
[StructType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Structure
[SpirvTools]: https://github.com/KhronosGroup/SPIRV-Tools

View File

@ -3158,6 +3158,7 @@ def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
def SPV_OC_OpTypeMatrix : I32EnumAttrCase<"OpTypeMatrix", 24>;
def SPV_OC_OpTypeImage : I32EnumAttrCase<"OpTypeImage", 25>;
def SPV_OC_OpTypeSampledImage : I32EnumAttrCase<"OpTypeSampledImage", 27>;
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>;
@ -3317,18 +3318,19 @@ def SPV_OpcodeAttr :
SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix, SPV_OC_OpTypeImage,
SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpTypeForwardPointer,
SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue,
SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOp, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic,
SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix,
SPV_OC_OpTypeImage, SPV_OC_OpTypeSampledImage, SPV_OC_OpTypeArray,
SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
SPV_OC_OpTypeFunction, SPV_OC_OpTypeForwardPointer, SPV_OC_OpConstantTrue,
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOp,
SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd,
SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic,
SPV_OC_OpVectorInsertDynamic, SPV_OC_OpVectorShuffle,
SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,

View File

@ -32,6 +32,7 @@ struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct SampledImageTypeStorage;
struct StructTypeStorage;
} // namespace detail
@ -233,6 +234,28 @@ public:
Optional<StorageClass> storage = llvm::None);
};
// SPIR-V sampled image type
class SampledImageType
: public Type::TypeBase<SampledImageType, SPIRVType,
detail::SampledImageTypeStorage> {
public:
using Base::Base;
static SampledImageType get(Type imageType);
static SampledImageType getChecked(Type imageType, Location location);
static LogicalResult verifyConstructionInvariants(Location Loc,
Type imageType);
Type getImageType() const;
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = llvm::None);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<spirv::StorageClass> storage = llvm::None);
};
/// SPIR-V struct type. Two kinds of struct types are supported:
/// - Literal: a literal struct type is uniqued by its fields (types + offset
/// info + decoration info).

View File

@ -116,7 +116,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
void SPIRVDialect::initialize() {
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
PointerType, RuntimeArrayType, StructType>();
PointerType, RuntimeArrayType, SampledImageType, StructType>();
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
@ -232,6 +232,23 @@ static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
return type;
}
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
Type type;
llvm::SMLoc typeLoc = parser.getCurrentLocation();
if (parser.parseType(type))
return Type();
if (!type.isa<ImageType>()) {
parser.emitError(typeLoc,
"sampled image must be composed using image type, got ")
<< type;
return Type();
}
return type;
}
/// Parses an optional `, stride = N` assembly segment. If no parsing failure
/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
/// missing.
@ -530,6 +547,21 @@ static Type parseImageType(SPIRVDialect const &dialect,
return ImageType::get(value.getValue());
}
// sampledImage-type :: = `!spv.sampledImage<` image-type `>`
static Type parseSampledImageType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
return Type();
Type parsedType = parseAndVerifySampledImageType(dialect, parser);
if (!parsedType)
return Type();
if (parser.parseGreater())
return Type();
return SampledImageType::get(parsedType);
}
// Parse decorations associated with a member.
static ParseResult parseStructMemberDecorations(
SPIRVDialect const &dialect, DialectAsmParser &parser,
@ -707,6 +739,7 @@ static Type parseStructType(SPIRVDialect const &dialect,
// | image-type
// | pointer-type
// | runtime-array-type
// | sampled-image-type
// | struct-type
Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
@ -723,6 +756,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parsePointerType(*this, parser);
if (keyword == "rtarray")
return parseRuntimeArrayType(*this, parser);
if (keyword == "sampled_image")
return parseSampledImageType(*this, parser);
if (keyword == "struct")
return parseStructType(*this, parser);
if (keyword == "matrix")
@ -763,6 +798,10 @@ static void print(ImageType type, DialectAsmPrinter &os) {
<< stringifyImageFormat(type.getImageFormat()) << ">";
}
static void print(SampledImageType type, DialectAsmPrinter &os) {
os << "sampled_image<" << type.getImageType() << ">";
}
static void print(StructType type, DialectAsmPrinter &os) {
thread_local llvm::SetVector<StringRef> structContext;
@ -825,7 +864,7 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
ImageType, StructType, MatrixType>(
ImageType, SampledImageType, StructType, MatrixType>(
[&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}

View File

@ -668,6 +668,8 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
compositeType.getExtensions(extensions, storage);
} else if (auto imageType = dyn_cast<ImageType>()) {
imageType.getExtensions(extensions, storage);
} else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
sampledImageType.getExtensions(extensions, storage);
} else if (auto matrixType = dyn_cast<MatrixType>()) {
matrixType.getExtensions(extensions, storage);
} else if (auto ptrType = dyn_cast<PointerType>()) {
@ -686,6 +688,8 @@ void SPIRVType::getCapabilities(
compositeType.getCapabilities(capabilities, storage);
} else if (auto imageType = dyn_cast<ImageType>()) {
imageType.getCapabilities(capabilities, storage);
} else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
sampledImageType.getCapabilities(capabilities, storage);
} else if (auto matrixType = dyn_cast<MatrixType>()) {
matrixType.getCapabilities(capabilities, storage);
} else if (auto ptrType = dyn_cast<PointerType>()) {
@ -703,6 +707,56 @@ Optional<int64_t> SPIRVType::getSizeInBytes() {
return llvm::None;
}
//===----------------------------------------------------------------------===//
// SampledImageType
//===----------------------------------------------------------------------===//
struct spirv::detail::SampledImageTypeStorage : public TypeStorage {
using KeyTy = Type;
SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<SampledImageTypeStorage>())
SampledImageTypeStorage(key);
}
Type imageType;
};
SampledImageType SampledImageType::get(Type imageType) {
return Base::get(imageType.getContext(), imageType);
}
SampledImageType SampledImageType::getChecked(Type imageType,
Location location) {
return Base::getChecked(location, imageType);
}
Type SampledImageType::getImageType() const { return getImpl()->imageType; }
LogicalResult SampledImageType::verifyConstructionInvariants(Location loc,
Type imageType) {
if (!imageType.isa<ImageType>())
return emitError(loc, "expected image type");
return success();
}
void SampledImageType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
getImageType().cast<ImageType>().getExtensions(extensions, storage);
}
void SampledImageType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage) {
getImageType().cast<ImageType>().getCapabilities(capabilities, storage);
}
//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//

View File

@ -158,6 +158,7 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeArray:
case spirv::Opcode::OpTypeFunction:
case spirv::Opcode::OpTypeImage:
case spirv::Opcode::OpTypeSampledImage:
case spirv::Opcode::OpTypeRuntimeArray:
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:

View File

@ -715,6 +715,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processFunctionType(operands);
case spirv::Opcode::OpTypeImage:
return processImageType(operands);
case spirv::Opcode::OpTypeSampledImage:
return processSampledImageType(operands);
case spirv::Opcode::OpTypeRuntimeArray:
return processRuntimeArrayType(operands);
case spirv::Opcode::OpTypeStruct:
@ -1054,6 +1056,21 @@ spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
return success();
}
LogicalResult
spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
Type elementTy = getType(operands[1]);
if (!elementTy)
return emitError(unknownLoc,
"OpTypeSampledImage references undefined <id>: ")
<< operands[1];
typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
return success();
}
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//

View File

@ -275,6 +275,8 @@ private:
LogicalResult processImageType(ArrayRef<uint32_t> operands);
LogicalResult processSampledImageType(ArrayRef<uint32_t> operands);
LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
LogicalResult processStructType(ArrayRef<uint32_t> operands);

View File

@ -511,6 +511,17 @@ LogicalResult Serializer::prepareBasicType(
return processTypeDecoration(loc, runtimeArrayType, resultID);
}
if (auto sampledImageType = type.dyn_cast<spirv::SampledImageType>()) {
typeEnum = spirv::Opcode::OpTypeSampledImage;
uint32_t imageTypeID = 0;
if (failed(
processType(loc, sampledImageType.getImageType(), imageTypeID))) {
return failure();
}
operands.push_back(imageTypeID);
return success();
}
if (auto structType = type.dyn_cast<spirv::StructType>()) {
if (structType.isIdentified()) {
(void)processName(resultID, structType.getIdentifier());

View File

@ -226,6 +226,20 @@ func private @image_parameters_nocomma_5(!spv.image<f32, Dim1D, NoDepth, NonArra
// -----
//===----------------------------------------------------------------------===//
// SampledImageType
//===----------------------------------------------------------------------===//
// CHECK: func private @sampled_image_type(!spv.sampled_image<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>)
func private @sampled_image_type(!spv.sampled_image<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>) -> ()
// -----
// expected-error @+1 {{sampled image must be composed using image type, got 'f32'}}
func private @samped_image_type_invaid_type(!spv.sampled_image<f32>) -> ()
// -----
//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,12 @@
// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK: !spv.ptr<!spv.sampled_image<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, UniformConstant>
spv.globalVariable @var0 bind(0, 1) : !spv.ptr<!spv.sampled_image<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, UniformConstant>
// CHECK: !spv.ptr<!spv.sampled_image<!spv.image<si32, SubpassData, DepthUnknown, Arrayed, MultiSampled, NoSampler, Unknown>>, UniformConstant>
spv.globalVariable @var1 bind(0, 0) : !spv.ptr<!spv.sampled_image<!spv.image<si32, SubpassData, DepthUnknown, Arrayed, MultiSampled, NoSampler, Unknown>>, UniformConstant>
// CHECK: !spv.ptr<!spv.sampled_image<!spv.image<i32, Rect, DepthUnknown, Arrayed, MultiSampled, NeedSampler, R8ui>>, UniformConstant>
spv.globalVariable @var2 bind(0, 0) : !spv.ptr<!spv.sampled_image<!spv.image<i32, Rect, DepthUnknown, Arrayed, MultiSampled, NeedSampler, R8ui>>, UniformConstant>
}