[mlir][spirv] Fix nullptr dereference in UnifyAliasedResource

Fixes: https://github.com/llvm/llvm-project/issues/62368

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149376
This commit is contained in:
Jakub Kuderski 2023-04-28 11:39:22 -04:00
parent d636bcb6ae
commit 797594a043
2 changed files with 23 additions and 4 deletions

View File

@ -220,6 +220,9 @@ ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
}
bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
if (!op)
return false;
if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
auto canonicalOp = getCanonicalResource(varOp);
return canonicalOp && varOp != canonicalOp;
@ -566,16 +569,15 @@ public:
private:
spirv::GetTargetEnvFn getTargetEnvFn;
};
} // namespace
void UnifyAliasedResourcePass::runOnOperation() {
spirv::ModuleOp moduleOp = getOperation();
MLIRContext *context = &getContext();
if (getTargetEnvFn) {
// This pass is only needed for targeting WebGPU, Metal, or layering Vulkan
// on Metal via MoltenVK, where we need to translate SPIR-V into WGSL or
// MSL. The translation has limitations.
// This pass is only needed for targeting WebGPU, Metal, or layering
// Vulkan on Metal via MoltenVK, where we need to translate SPIR-V into
// WGSL or MSL. The translation has limitations.
spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
bool isVulkanOnAppleDevices =
@ -614,6 +616,7 @@ void UnifyAliasedResourcePass::runOnOperation() {
resources.front()->removeAttr("aliased");
}
}
} // namespace
std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) {

View File

@ -506,3 +506,19 @@ spirv.module Logical GLSL450 {
// CHECK: %[[CC:.+]] = spirv.CompositeConstruct %[[BC0]], %[[BC1]] : (vector<2xf32>, vector<2xf32>) -> vector<4xf32>
// CHECK: spirv.ReturnValue %[[CC]]
// -----
// Make sure we do not crash on function arguments.
spirv.module Logical GLSL450 {
spirv.func @main(%arg0: !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>) "None" {
%cst0_i32 = spirv.Constant 0 : i32
%0 = spirv.AccessChain %arg0[%cst0_i32, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
spirv.Return
}
}
// CHECK-LABEL: spirv.module
// CHECK-LABEL: spirv.func @main
// CHECK-SAME: (%{{.+}}: !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>) "None"
// CHECK: spirv.Return