[IR] Improve member ShuffleVectorInst::isReplicationMask()

When we have an actual shuffle, we can impose the additional restriction
that the mask replicates the elements of the first operand, so we know
the replication factor as a ratio of output and op0 vector sizes.
This commit is contained in:
Roman Lebedev 2021-11-05 19:11:55 +03:00
parent 6d48e2505c
commit a5cd27880a
No known key found for this signature in database
GPG Key ID: 083C3EBB4A1689E0
3 changed files with 26 additions and 8 deletions

View File

@ -2373,14 +2373,7 @@ public:
}
/// Return true if this shuffle mask is a replication mask.
bool isReplicationMask(int &ReplicationFactor, int &VF) const {
// Not possible to express a shuffle mask for a scalable vector for this
// case.
if (isa<ScalableVectorType>(getType()))
return false;
return isReplicationMask(ShuffleMask, ReplicationFactor, VF);
}
bool isReplicationMask(int &ReplicationFactor, int &VF) const;
/// Change values in a shuffle permute mask assuming the two vector operands
/// of length InVecNumElts have swapped position.

View File

@ -2502,6 +2502,21 @@ bool ShuffleVectorInst::isReplicationMask(ArrayRef<int> Mask,
return false;
}
bool ShuffleVectorInst::isReplicationMask(int &ReplicationFactor,
int &VF) const {
// Not possible to express a shuffle mask for a scalable vector for this
// case.
if (isa<ScalableVectorType>(getType()))
return false;
VF = cast<FixedVectorType>(Op<0>()->getType())->getNumElements();
if (ShuffleMask.size() % VF != 0)
return false;
ReplicationFactor = ShuffleMask.size() / VF;
return isReplicationMaskWithParams(ShuffleMask, ReplicationFactor, VF);
}
//===----------------------------------------------------------------------===//
// InsertValueInst Class
//===----------------------------------------------------------------------===//

View File

@ -1126,6 +1126,16 @@ TEST(InstructionsTest, ShuffleMaskIsReplicationMask) {
ReplicatedMask, GuessedReplicationFactor, GuessedVF));
EXPECT_EQ(GuessedReplicationFactor, ReplicationFactor);
EXPECT_EQ(GuessedVF, VF);
for (int OpVF : seq_inclusive(VF, 2 * VF + 1)) {
LLVMContext Ctx;
Type *OpVFTy = FixedVectorType::get(IntegerType::getInt1Ty(Ctx), OpVF);
Value *Op = ConstantVector::getNullValue(OpVFTy);
ShuffleVectorInst *SVI = new ShuffleVectorInst(Op, Op, ReplicatedMask);
EXPECT_EQ(SVI->isReplicationMask(GuessedReplicationFactor, GuessedVF),
OpVF == VF);
delete SVI;
}
}
}
}