Fix handling of CopyObject in GetPtr and its call sites

This commit is contained in:
GregF 2017-07-18 14:42:51 -06:00 committed by Lei Zhang
parent e9e4393b1c
commit adb237f3bd
6 changed files with 159 additions and 76 deletions

View File

@ -47,6 +47,10 @@ ir::Instruction* AggressiveDCEPass::GetPtr(
*varId = ip->GetSingleWordInOperand(
op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx);
ir::Instruction* ptrInst = def_use_mgr_->GetDef(*varId);
while (ptrInst->opcode() == SpvOpCopyObject) {
*varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
ptrInst = def_use_mgr_->GetDef(*varId);
}
ir::Instruction* varInst = ptrInst;
while (varInst->opcode() != SpvOpVariable) {
if (IsNonPtrAccessChain(varInst->opcode())) {

View File

@ -14,22 +14,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iterator.h"
#include "local_access_chain_convert_pass.h"
static const int kSpvEntryPointFunctionId = 1;
static const int kSpvStorePtrId = 0;
static const int kSpvStoreValId = 1;
static const int kSpvLoadPtrId = 0;
static const int kSpvAccessChainPtrId = 0;
static const int kSpvTypePointerStorageClass = 0;
static const int kSpvTypePointerTypeId = 1;
static const int kSpvConstantValue = 0;
static const int kSpvTypeIntWidth = 0;
#include "iterator.h"
namespace spvtools {
namespace opt {
namespace {
const uint32_t kEntryPointFunctionIdInIdx = 1;
const uint32_t kStorePtrIdInIdx = 0;
const uint32_t kStoreValIdInIdx = 1;
const uint32_t kLoadPtrIdInIdx = 0;
const uint32_t kAccessChainPtrIdInIdx = 0;
const uint32_t kTypePointerStorageClassInIdx = 0;
const uint32_t kTypePointerTypeIdInIdx = 1;
const uint32_t kConstantValueInIdx = 0;
const uint32_t kTypeIntWidthInIdx = 0;
const uint32_t kCopyObjectOperandInIdx = 0;
} // anonymous namespace
bool LocalAccessChainConvertPass::IsNonPtrAccessChain(
const SpvOp opcode) const {
return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
@ -68,14 +74,27 @@ bool LocalAccessChainConvertPass::IsTargetType(
}
ir::Instruction* LocalAccessChainConvertPass::GetPtr(
ir::Instruction* ip,
uint32_t* varId) {
const uint32_t ptrId = ip->GetSingleWordInOperand(
ip->opcode() == SpvOpStore ? kSpvStorePtrId : kSpvLoadPtrId);
ir::Instruction* ptrInst = def_use_mgr_->GetDef(ptrId);
*varId = IsNonPtrAccessChain(ptrInst->opcode()) ?
ptrInst->GetSingleWordInOperand(kSpvAccessChainPtrId) :
ptrId;
ir::Instruction* ip, uint32_t* varId) {
const SpvOp op = ip->opcode();
assert(op == SpvOpStore || op == SpvOpLoad);
*varId = ip->GetSingleWordInOperand(
op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx);
ir::Instruction* ptrInst = def_use_mgr_->GetDef(*varId);
while (ptrInst->opcode() == SpvOpCopyObject) {
*varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
ptrInst = def_use_mgr_->GetDef(*varId);
}
ir::Instruction* varInst = ptrInst;
while (varInst->opcode() != SpvOpVariable) {
if (IsNonPtrAccessChain(varInst->opcode())) {
*varId = varInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
}
else {
assert(varInst->opcode() == SpvOpCopyObject);
*varId = varInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
}
varInst = def_use_mgr_->GetDef(*varId);
}
return ptrInst;
}
@ -89,13 +108,13 @@ bool LocalAccessChainConvertPass::IsTargetVar(uint32_t varId) {
return false;;
const uint32_t varTypeId = varInst->type_id();
const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
SpvStorageClassFunction) {
seen_non_target_vars_.insert(varId);
return false;
}
const uint32_t varPteTypeId =
varTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
ir::Instruction* varPteTypeInst = def_use_mgr_->GetDef(varPteTypeId);
if (!IsTargetType(varPteTypeInst)) {
seen_non_target_vars_.insert(varId);
@ -131,7 +150,7 @@ uint32_t LocalAccessChainConvertPass::GetPointeeTypeId(
const ir::Instruction* ptrInst) const {
const uint32_t ptrTypeId = ptrInst->type_id();
const ir::Instruction* ptrTypeInst = def_use_mgr_->GetDef(ptrTypeId);
return ptrTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
return ptrTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
}
void LocalAccessChainConvertPass::BuildAndAppendInst(
@ -152,7 +171,7 @@ uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
uint32_t* varPteTypeId,
std::vector<std::unique_ptr<ir::Instruction>>* newInsts) {
const uint32_t ldResultId = TakeNextId();
*varId = ptrInst->GetSingleWordInOperand(kSpvAccessChainPtrId);
*varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
const ir::Instruction* varInst = def_use_mgr_->GetDef(*varId);
assert(varInst->opcode() == SpvOpVariable);
*varPteTypeId = GetPointeeTypeId(varInst);
@ -168,7 +187,7 @@ void LocalAccessChainConvertPass::AppendConstantOperands(
ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t *iid) {
if (iidIdx > 0) {
const ir::Instruction* cInst = def_use_mgr_->GetDef(*iid);
uint32_t val = cInst->GetSingleWordInOperand(kSpvConstantValue);
uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx);
in_opnds->push_back(
{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
}
@ -246,13 +265,23 @@ void LocalAccessChainConvertPass::FindTargetVars(ir::Function* func) {
case SpvOpLoad: {
uint32_t varId;
ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
// For now, only convert non-ptr access chains
if (!IsNonPtrAccessChain(ptrInst->opcode()))
break;
// For now, only convert non-nested access chains
// TODO(): Convert nested access chains
if (!IsTargetVar(varId))
break;
// Rule out variables with non-non-ptr access chain refs
const SpvOp op = ptrInst->opcode();
if (!IsNonPtrAccessChain(op) && op != SpvOpVariable) {
seen_non_target_vars_.insert(varId);
seen_target_vars_.erase(varId);
break;
}
// Rule out variables with nested access chains
// TODO(): Convert nested access chains
if (IsNonPtrAccessChain(op) &&
ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx) != varId) {
seen_non_target_vars_.insert(varId);
seen_target_vars_.erase(varId);
break;
}
// Rule out variables accessed with non-constant indices
if (!IsConstantIndexAccessChain(ptrInst)) {
seen_non_target_vars_.insert(varId);
@ -299,7 +328,7 @@ bool LocalAccessChainConvertPass::ConvertLocalAccessChains(ir::Function* func) {
if (!IsTargetVar(varId))
break;
std::vector<std::unique_ptr<ir::Instruction>> newInsts;
uint32_t valId = ii->GetSingleWordInOperand(kSpvStoreValId);
uint32_t valId = ii->GetSingleWordInOperand(kStoreValIdInIdx);
GenAccessChainStoreReplacement(ptrInst, valId, &newInsts);
def_use_mgr_->KillInst(&*ii);
DeleteIfUseless(ptrInst);
@ -341,13 +370,13 @@ Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
// TODO(): Handle non-32-bit integer constants in access chains
for (const ir::Instruction& inst : module_->types_values())
if (inst.opcode() == SpvOpTypeInt &&
inst.GetSingleWordInOperand(kSpvTypeIntWidth) != 32)
inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
return Status::SuccessWithoutChange;
// Process all entry point functions.
bool modified = false;
for (auto& e : module_->entry_points()) {
ir::Function* fn =
id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)];
id2function_[e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)];
modified = ConvertLocalAccessChains(fn) || modified;
}

View File

@ -14,20 +14,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iterator.h"
#include "local_single_block_elim_pass.h"
static const int kSpvEntryPointFunctionId = 1;
static const int kSpvStorePtrId = 0;
static const int kSpvStoreValId = 1;
static const int kSpvLoadPtrId = 0;
static const int kSpvAccessChainPtrId = 0;
static const int kSpvTypePointerStorageClass = 0;
static const int kSpvTypePointerTypeId = 1;
#include "iterator.h"
namespace spvtools {
namespace opt {
namespace {
const uint32_t kEntryPointFunctionIdInIdx = 1;
const uint32_t kStorePtrIdInIdx = 0;
const uint32_t kStoreValIdInIdx = 1;
const uint32_t kLoadPtrIdInIdx = 0;
const uint32_t kAccessChainPtrIdInIdx = 0;
const uint32_t kTypePointerStorageClassInIdx = 0;
const uint32_t kTypePointerTypeIdInIdx = 1;
const uint32_t kCopyObjectOperandInIdx = 0;
} // anonymous namespace
bool LocalSingleBlockLoadStoreElimPass::IsNonPtrAccessChain(
const SpvOp opcode) const {
return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
@ -67,12 +73,24 @@ bool LocalSingleBlockLoadStoreElimPass::IsTargetType(
ir::Instruction* LocalSingleBlockLoadStoreElimPass::GetPtr(
ir::Instruction* ip, uint32_t* varId) {
const SpvOp op = ip->opcode();
assert(op == SpvOpStore || op == SpvOpLoad);
*varId = ip->GetSingleWordInOperand(
ip->opcode() == SpvOpStore ? kSpvStorePtrId : kSpvLoadPtrId);
op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx);
ir::Instruction* ptrInst = def_use_mgr_->GetDef(*varId);
while (ptrInst->opcode() == SpvOpCopyObject) {
*varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
ptrInst = def_use_mgr_->GetDef(*varId);
}
ir::Instruction* varInst = ptrInst;
while (IsNonPtrAccessChain(varInst->opcode())) {
*varId = varInst->GetSingleWordInOperand(kSpvAccessChainPtrId);
while (varInst->opcode() != SpvOpVariable) {
if (IsNonPtrAccessChain(varInst->opcode())) {
*varId = varInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
}
else {
assert(varInst->opcode() == SpvOpCopyObject);
*varId = varInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
}
varInst = def_use_mgr_->GetDef(*varId);
}
return ptrInst;
@ -87,13 +105,13 @@ bool LocalSingleBlockLoadStoreElimPass::IsTargetVar(uint32_t varId) {
assert(varInst->opcode() == SpvOpVariable);
const uint32_t varTypeId = varInst->type_id();
const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
SpvStorageClassFunction) {
seen_non_target_vars_.insert(varId);
return false;
}
const uint32_t varPteTypeId =
varTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
ir::Instruction* varPteTypeInst = def_use_mgr_->GetDef(varPteTypeId);
if (!IsTargetType(varPteTypeInst)) {
seen_non_target_vars_.insert(varId);
@ -137,7 +155,7 @@ bool LocalSingleBlockLoadStoreElimPass::IsLiveVar(uint32_t varId) const {
assert(varInst->opcode() == SpvOpVariable);
const uint32_t varTypeId = varInst->type_id();
const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
SpvStorageClassFunction)
return true;
// test if variable is loaded from
@ -202,12 +220,6 @@ void LocalSingleBlockLoadStoreElimPass::DCEInst(ir::Instruction* inst) {
bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim(
ir::Function* func) {
// Verify no CopyObject ops in function. This is a pre-SSA pass and
// is generally not useful for code already in CSSA form.
for (auto& blk : *func)
for (auto& inst : blk)
if (inst.opcode() == SpvOpCopyObject)
return false;
// Perform local store/load and load/load elimination on each block
bool modified = false;
for (auto bi = func->begin(); bi != func->end(); ++bi) {
@ -251,7 +263,7 @@ bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim(
if (ptrInst->opcode() == SpvOpVariable) {
auto si = var2store_.find(varId);
if (si != var2store_.end()) {
replId = si->second->GetSingleWordInOperand(kSpvStoreValId);
replId = si->second->GetSingleWordInOperand(kStoreValIdInIdx);
}
else {
auto li = var2load_.find(varId);
@ -324,7 +336,7 @@ Pass::Status LocalSingleBlockLoadStoreElimPass::ProcessImpl() {
// Call Mem2Reg on all remaining functions.
for (auto& e : module_->entry_points()) {
ir::Function* fn =
id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)];
id2function_[e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)];
modified = LocalSingleBlockLoadStoreElim(fn) || modified;
}
FinalizeNextId(module_);

View File

@ -20,20 +20,25 @@
#include "iterator.h"
#include "spirv/1.0/GLSL.std.450.h"
static const int kSpvEntryPointFunctionId = 1;
static const int kSpvStorePtrId = 0;
static const int kSpvStoreValId = 1;
static const int kSpvLoadPtrId = 0;
static const int kSpvAccessChainPtrId = 0;
static const int kSpvTypePointerStorageClass = 0;
static const int kSpvTypePointerTypeId = 1;
// Universal Limit of ResultID + 1
static const int kInvalidId = 0x400000;
namespace spvtools {
namespace opt {
namespace {
const uint32_t kEntryPointFunctionIdInIdx = 1;
const uint32_t kStorePtrIdInIdx = 0;
const uint32_t kStoreValIdInIdx = 1;
const uint32_t kLoadPtrIdInIdx = 0;
const uint32_t kAccessChainPtrIdInIdx = 0;
const uint32_t kTypePointerStorageClassInIdx = 0;
const uint32_t kTypePointerTypeIdInIdx = 1;
const uint32_t kCopyObjectOperandInIdx = 0;
} // anonymous namespace
bool LocalSingleStoreElimPass::IsNonPtrAccessChain(const SpvOp opcode) const {
return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
}
@ -72,12 +77,24 @@ bool LocalSingleStoreElimPass::IsTargetType(
ir::Instruction* LocalSingleStoreElimPass::GetPtr(
ir::Instruction* ip, uint32_t* varId) {
const SpvOp op = ip->opcode();
assert(op == SpvOpStore || op == SpvOpLoad);
*varId = ip->GetSingleWordInOperand(
ip->opcode() == SpvOpStore ? kSpvStorePtrId : kSpvLoadPtrId);
op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx);
ir::Instruction* ptrInst = def_use_mgr_->GetDef(*varId);
while (ptrInst->opcode() == SpvOpCopyObject) {
*varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
ptrInst = def_use_mgr_->GetDef(*varId);
}
ir::Instruction* varInst = ptrInst;
while (IsNonPtrAccessChain(varInst->opcode())) {
*varId = varInst->GetSingleWordInOperand(kSpvAccessChainPtrId);
while (varInst->opcode() != SpvOpVariable) {
if (IsNonPtrAccessChain(varInst->opcode())) {
*varId = varInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
}
else {
assert(varInst->opcode() == SpvOpCopyObject);
*varId = varInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
}
varInst = def_use_mgr_->GetDef(*varId);
}
return ptrInst;
@ -92,13 +109,13 @@ bool LocalSingleStoreElimPass::IsTargetVar(uint32_t varId) {
assert(varInst->opcode() == SpvOpVariable);
const uint32_t varTypeId = varInst->type_id();
const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
SpvStorageClassFunction) {
seen_non_target_vars_.insert(varId);
return false;
}
const uint32_t varPteTypeId =
varTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
ir::Instruction* varPteTypeInst = def_use_mgr_->GetDef(varPteTypeId);
if (!IsTargetType(varPteTypeInst)) {
seen_non_target_vars_.insert(varId);
@ -145,7 +162,7 @@ void LocalSingleStoreElimPass::SingleStoreAnalyze(ir::Function* func) {
non_ssa_vars_.insert(varId);
continue;
}
if (IsNonPtrAccessChain(ptrInst->opcode())) {
if (ptrInst->opcode() != SpvOpVariable) {
non_ssa_vars_.insert(varId);
ssa_var2store_.erase(varId);
continue;
@ -272,8 +289,6 @@ bool LocalSingleStoreElimPass::SingleStoreProcess(ir::Function* func) {
uint32_t varId;
ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
// Skip access chain loads
if (IsNonPtrAccessChain(ptrInst->opcode()))
continue;
if (ptrInst->opcode() != SpvOpVariable)
continue;
const auto vsi = ssa_var2store_.find(varId);
@ -285,7 +300,7 @@ bool LocalSingleStoreElimPass::SingleStoreProcess(ir::Function* func) {
if (!Dominates(store2blk_[vsi->second], store2idx_[vsi->second], &*bi, instIdx))
continue;
// Use store value as replacement id
uint32_t replId = vsi->second->GetSingleWordInOperand(kSpvStoreValId);
uint32_t replId = vsi->second->GetSingleWordInOperand(kStoreValIdInIdx);
// replace all instances of the load's id with the SSA value's id
ReplaceAndDeleteLoad(&*ii, replId);
modified = true;
@ -318,7 +333,7 @@ bool LocalSingleStoreElimPass::IsLiveVar(uint32_t varId) const {
assert(varInst->opcode() == SpvOpVariable);
const uint32_t varTypeId = varInst->type_id();
const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
SpvStorageClassFunction)
return true;
// test if variable is loaded from
@ -444,7 +459,7 @@ Pass::Status LocalSingleStoreElimPass::ProcessImpl() {
// Call Mem2Reg on all remaining functions.
for (auto& e : module_->entry_points()) {
ir::Function* fn =
id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)];
id2function_[e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)];
modified = LocalSingleStoreElim(fn) || modified;
}
FinalizeNextId(module_);

View File

@ -82,6 +82,10 @@ ir::Instruction* LocalMultiStoreElimPass::GetPtr(
*varId = ip->GetSingleWordInOperand(
op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx);
ir::Instruction* ptrInst = def_use_mgr_->GetDef(*varId);
while (ptrInst->opcode() == SpvOpCopyObject) {
*varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
ptrInst = def_use_mgr_->GetDef(*varId);
}
ir::Instruction* varInst = ptrInst;
while (varInst->opcode() != SpvOpVariable) {
if (IsNonPtrAccessChain(varInst->opcode())) {

View File

@ -391,7 +391,7 @@ OpFunctionEnd
assembly, assembly, false, true);
}
TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfCopyObjectInFunction) {
TEST_F(LocalSingleBlockLoadStoreElimTest, ElimIfCopyObjectInFunction) {
// Note: SPIR-V hand edited to insert CopyObject
//
// #version 140
@ -406,7 +406,7 @@ TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfCopyObjectInFunction) {
// gl_FragData[1] = v2;
// }
const std::string assembly =
const std::string predefs =
R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
@ -435,7 +435,10 @@ OpName %v2 "v2"
%_ptr_Output_v4float = OpTypePointer Output %v4float
%float_0_5 = OpConstant %float 0.5
%int_1 = OpConstant %int 1
%main = OpFunction %void None %8
)";
const std::string before =
R"(%main = OpFunction %void None %8
%22 = OpLabel
%v1 = OpVariable %_ptr_Function_v4float Function
%v2 = OpVariable %_ptr_Function_v4float Function
@ -453,10 +456,26 @@ OpStore %28 %27
OpStore %30 %29
OpReturn
OpFunctionEnd
)";
const std::string after =
R"(%main = OpFunction %void None %8
%22 = OpLabel
%v1 = OpVariable %_ptr_Function_v4float Function
%v2 = OpVariable %_ptr_Function_v4float Function
%23 = OpLoad %v4float %BaseColor
%25 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_0
OpStore %25 %23
%26 = OpLoad %v4float %BaseColor
%27 = OpVectorTimesScalar %v4float %26 %float_0_5
%30 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_1
OpStore %30 %27
OpReturn
OpFunctionEnd
)";
SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
assembly, assembly, false, true);
predefs + before, predefs + after, true, true);
}
// TODO(greg-lunarg): Add tests to verify handling of these cases: