[mlir][spirv] Support spec constants as GlobalVar initializer (#75660)

Changes include:

- spirv serialization and deserialization needs handling in cases when
GlobalVariableOp initializer is defined using spirv SpecConstant or
SpecConstantComposite op, currently even though it allows SpecConst, it
only looked up in for GlobalVariable Map to find initializer symbol
reference, change is fixing this and extending the support to
SpecConstantComposite as an initializer.
- Adds tests to make sure GlobalVariable can be initialized using
specialized constants.

---------

Co-authored-by: Lei Zhang <antiagainst@gmail.com>
This commit is contained in:
Dimple Prajapati 2024-01-05 16:27:30 -08:00 committed by GitHub
parent 651a42ff65
commit 5e54319b7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 71 additions and 14 deletions

View File

@ -1162,10 +1162,11 @@ LogicalResult spirv::GlobalVariableOp::verify() {
// TODO: Currently only variable initialization with specialization
// constants and other variables is supported. They could be normal
// constants in the module scope as well.
if (!initOp ||
!isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
spirv::SpecConstantCompositeOp>(initOp)) {
return emitOpError("initializer must be result of a "
"spirv.SpecConstant or spirv.GlobalVariable op");
"spirv.SpecConstant or spirv.GlobalVariable or "
"spirv.SpecConstantCompositeOp op");
}
}

View File

@ -637,14 +637,22 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
// Initializer.
FlatSymbolRefAttr initializer = nullptr;
if (wordIndex < operands.size()) {
auto initializerOp = getGlobalVariable(operands[wordIndex]);
if (!initializerOp) {
Operation *op = nullptr;
if (auto initOp = getGlobalVariable(operands[wordIndex]))
op = initOp;
else if (auto initOp = getSpecConstant(operands[wordIndex]))
op = initOp;
else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
op = initOp;
else
return emitError(unknownLoc, "unknown <id> ")
<< operands[wordIndex] << "used as initializer";
}
initializer = SymbolRefAttr::get(op);
wordIndex++;
initializer = SymbolRefAttr::get(initializerOp.getOperation());
}
if (wordIndex != operands.size()) {
return emitError(unknownLoc,

View File

@ -383,20 +383,31 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
// Encode initialization.
if (auto initializer = varOp.getInitializer()) {
auto initializerID = getVariableID(*initializer);
if (!initializerID) {
StringRef initAttrName = varOp.getInitializerAttrName().getValue();
if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
uint32_t initializerID = 0;
auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
varOp->getParentOp(), initRef.getAttr());
// Check if initializer is GlobalVariable or SpecConstant* cases.
if (isa<spirv::GlobalVariableOp>(initOp))
initializerID = getVariableID(*initSymbolName);
else
initializerID = getSpecConstID(*initSymbolName);
if (!initializerID)
return emitError(varOp.getLoc(),
"invalid usage of undefined variable as initializer");
}
operands.push_back(initializerID);
elidedAttrs.push_back("initializer");
elidedAttrs.push_back(initAttrName);
}
if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
return failure();
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
elidedAttrs.push_back("initializer");
elidedAttrs.push_back(initAttrName);
// Encode decorations.
for (auto attr : varOp->getAttrs()) {

View File

@ -349,6 +349,19 @@ spirv.SpecConstant @sc = 4.0 : f32
// CHECK: spirv.GlobalVariable @var initializer(@sc)
spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<f32, Private>
// -----
// Allow SpecConstantComposite as initializer
spirv.module Logical GLSL450 {
spirv.SpecConstant @sc1 = 1 : i8
spirv.SpecConstant @sc2 = 2 : i8
spirv.SpecConstant @sc3 = 3 : i8
spirv.SpecConstantComposite @scc (@sc1, @sc2, @sc3) : !spirv.array<3 x i8>
// CHECK: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
}
// -----
spirv.module Logical GLSL450 {
@ -410,7 +423,7 @@ spirv.module Logical GLSL450 {
// -----
spirv.module Logical GLSL450 {
// expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable op}}
// expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable or spirv.SpecConstantCompositeOp op}}
spirv.GlobalVariable @var0 initializer(@var1) : !spirv.ptr<f32, Private>
}

View File

@ -23,6 +23,30 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: spirv.SpecConstant @sc = 1 : i8
// CHECK-NEXT: spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
spirv.SpecConstant @sc = 1 : i8
spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
}
// -----
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
// CHECK-NEXT: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
spirv.SpecConstant @sc0 = 1 : i8
spirv.SpecConstant @sc1 = 2 : i8
spirv.SpecConstant @sc2 = 3 : i8
spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
}
// -----
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.GlobalVariable @globalInvocationID built_in("GlobalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
spirv.func @foo() "None" {