[GlobalISel] Add matchers for constant splat.

This change exposes isBuildVectorConstantSplat() to the llvm namespace
and uses it to implement the constant splat versions of
m_SpecificICst().

CombinerHelper::matchOrShiftToFunnelShift() can now work with vector
types and CombinerHelper::matchMulOBy2()'s match for a constant splat is
simplified.

Differential Revision: https://reviews.llvm.org/D114625
This commit is contained in:
Abinav Puthan Purayil 2021-11-26 16:31:37 +05:30
parent 5d602120c3
commit bc5dbb0bae
7 changed files with 188 additions and 30 deletions

View File

@ -129,6 +129,43 @@ inline SpecificConstantMatch m_SpecificICst(int64_t RequestedValue) {
return SpecificConstantMatch(RequestedValue);
}
/// Matcher for a specific constant splat.
struct SpecificConstantSplatMatch {
int64_t RequestedVal;
SpecificConstantSplatMatch(int64_t RequestedVal)
: RequestedVal(RequestedVal) {}
bool match(const MachineRegisterInfo &MRI, Register Reg) {
return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
/* AllowUndef */ false);
}
};
/// Matches a constant splat of \p RequestedValue.
inline SpecificConstantSplatMatch m_SpecificICstSplat(int64_t RequestedValue) {
return SpecificConstantSplatMatch(RequestedValue);
}
/// Matcher for a specific constant or constant splat.
struct SpecificConstantOrSplatMatch {
int64_t RequestedVal;
SpecificConstantOrSplatMatch(int64_t RequestedVal)
: RequestedVal(RequestedVal) {}
bool match(const MachineRegisterInfo &MRI, Register Reg) {
int64_t MatchedVal;
if (mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal)
return true;
return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
/* AllowUndef */ false);
}
};
/// Matches a \p RequestedValue constant or a constant splat of \p
/// RequestedValue.
inline SpecificConstantOrSplatMatch
m_SpecificICstOrSplat(int64_t RequestedValue) {
return SpecificConstantOrSplatMatch(RequestedValue);
}
///{
/// Convenience matchers for specific integer values.
inline SpecificConstantMatch m_ZeroInt() { return SpecificConstantMatch(0); }

View File

@ -378,6 +378,18 @@ Optional<FPValueAndVReg> getFConstantSplat(Register VReg,
const MachineRegisterInfo &MRI,
bool AllowUndef = true);
/// Return true if the specified register is defined by G_BUILD_VECTOR or
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
bool isBuildVectorConstantSplat(const Register Reg,
const MachineRegisterInfo &MRI,
int64_t SplatValue, bool AllowUndef);
/// Return true if the specified instruction is a G_BUILD_VECTOR or
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
bool isBuildVectorConstantSplat(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
int64_t SplatValue, bool AllowUndef);
/// Return true if the specified instruction is a G_BUILD_VECTOR or
/// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef.
bool isBuildVectorAllZeros(const MachineInstr &MI,

View File

@ -3878,21 +3878,21 @@ bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
Register ShlSrc, ShlAmt, LShrSrc, LShrAmt;
unsigned FshOpc = 0;
// TODO: Handle vector types.
// Match (or (shl x, amt), (lshr y, sub(bw, amt))).
if (mi_match(Dst, MRI,
// m_GOr() handles the commuted version as well.
m_GOr(m_GShl(m_Reg(ShlSrc), m_Reg(ShlAmt)),
m_GLShr(m_Reg(LShrSrc), m_GSub(m_SpecificICst(BitWidth),
m_Reg(LShrAmt)))))) {
if (mi_match(
Dst, MRI,
// m_GOr() handles the commuted version as well.
m_GOr(m_GShl(m_Reg(ShlSrc), m_Reg(ShlAmt)),
m_GLShr(m_Reg(LShrSrc), m_GSub(m_SpecificICstOrSplat(BitWidth),
m_Reg(LShrAmt)))))) {
FshOpc = TargetOpcode::G_FSHL;
// Match (or (shl x, sub(bw, amt)), (lshr y, amt)).
} else if (mi_match(
Dst, MRI,
m_GOr(m_GLShr(m_Reg(LShrSrc), m_Reg(LShrAmt)),
m_GShl(m_Reg(ShlSrc), m_GSub(m_SpecificICst(BitWidth),
m_Reg(ShlAmt)))))) {
} else if (mi_match(Dst, MRI,
m_GOr(m_GLShr(m_Reg(LShrSrc), m_Reg(LShrAmt)),
m_GShl(m_Reg(ShlSrc),
m_GSub(m_SpecificICstOrSplat(BitWidth),
m_Reg(ShlAmt)))))) {
FshOpc = TargetOpcode::G_FSHR;
} else {
@ -4543,20 +4543,9 @@ bool CombinerHelper::matchNarrowBinopFeedingAnd(
bool CombinerHelper::matchMulOBy2(MachineInstr &MI, BuildFnTy &MatchInfo) {
unsigned Opc = MI.getOpcode();
assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO);
// Check for a constant 2 or a splat of 2 on the RHS.
auto RHS = MI.getOperand(3).getReg();
bool IsVector = MRI.getType(RHS).isVector();
if (!IsVector && !mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICst(2)))
if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(2)))
return false;
if (IsVector) {
// FIXME: There's no mi_match pattern for this yet.
auto *RHSDef = getDefIgnoringCopies(RHS, MRI);
if (!RHSDef)
return false;
auto Splat = getBuildVectorConstantSplat(*RHSDef, MRI);
if (!Splat || *Splat != 2)
return false;
}
MatchInfo = [=, &MI](MachineIRBuilder &B) {
Observer.changingInstr(MI);

View File

@ -1030,16 +1030,22 @@ Optional<ValueAndVReg> getAnyConstantSplat(Register VReg,
return SplatValAndReg;
}
bool isBuildVectorConstantSplat(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
int64_t SplatValue, bool AllowUndef) {
if (auto SplatValAndReg =
getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, AllowUndef))
} // end anonymous namespace
bool llvm::isBuildVectorConstantSplat(const Register Reg,
const MachineRegisterInfo &MRI,
int64_t SplatValue, bool AllowUndef) {
if (auto SplatValAndReg = getAnyConstantSplat(Reg, MRI, AllowUndef))
return mi_match(SplatValAndReg->VReg, MRI, m_SpecificICst(SplatValue));
return false;
}
} // end anonymous namespace
bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
int64_t SplatValue, bool AllowUndef) {
return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue,
AllowUndef);
}
Optional<int64_t>
llvm::getBuildVectorConstantSplat(const MachineInstr &MI,

View File

@ -27,6 +27,33 @@ body: |
$vgpr3 = COPY %or
...
---
name: fshl_v2i32
tracksRegLiveness: true
body: |
bb.0:
liveins: $vgpr0_vgpr1, $vgpr2_vgpr3, $vgpr4_vgpr5, $vgpr6_vgpr7
; CHECK-LABEL: name: fshl_v2i32
; CHECK: liveins: $vgpr0_vgpr1, $vgpr2_vgpr3, $vgpr4_vgpr5, $vgpr6_vgpr7
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
; CHECK-NEXT: %b:_(<2 x s32>) = COPY $vgpr2_vgpr3
; CHECK-NEXT: %amt:_(<2 x s32>) = COPY $vgpr4_vgpr5
; CHECK-NEXT: %or:_(<2 x s32>) = G_FSHL %a, %b, %amt(<2 x s32>)
; CHECK-NEXT: $vgpr6_vgpr7 = COPY %or(<2 x s32>)
%a:_(<2 x s32>) = COPY $vgpr0_vgpr1
%b:_(<2 x s32>) = COPY $vgpr2_vgpr3
%amt:_(<2 x s32>) = COPY $vgpr4_vgpr5
%scalar_bw:_(s32) = G_CONSTANT i32 32
%bw:_(<2 x s32>) = G_BUILD_VECTOR %scalar_bw(s32), %scalar_bw(s32)
%shl:_(<2 x s32>) = G_SHL %a:_, %amt:_(<2 x s32>)
%sub:_(<2 x s32>) = G_SUB %bw:_, %amt:_
%lshr:_(<2 x s32>) = G_LSHR %b:_, %sub:_(<2 x s32>)
%or:_(<2 x s32>) = G_OR %shl:_, %lshr:_
$vgpr6_vgpr7 = COPY %or
...
---
name: fshl_commute_i32
tracksRegLiveness: true

View File

@ -25,6 +25,31 @@ body: |
$vgpr2 = COPY %or
...
---
name: rotl_v2i32
tracksRegLiveness: true
body: |
bb.0:
liveins: $vgpr0_vgpr1, $vgpr2_vgpr3, $vgpr4_vgpr5
; CHECK-LABEL: name: rotl_v2i32
; CHECK: liveins: $vgpr0_vgpr1, $vgpr2_vgpr3, $vgpr4_vgpr5
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
; CHECK-NEXT: %amt:_(<2 x s32>) = COPY $vgpr2_vgpr3
; CHECK-NEXT: %or:_(<2 x s32>) = G_ROTL %a, %amt(<2 x s32>)
; CHECK-NEXT: $vgpr4_vgpr5 = COPY %or(<2 x s32>)
%a:_(<2 x s32>) = COPY $vgpr0_vgpr1
%amt:_(<2 x s32>) = COPY $vgpr2_vgpr3
%scalar_bw:_(s32) = G_CONSTANT i32 32
%bw:_(<2 x s32>) = G_BUILD_VECTOR %scalar_bw(s32), %scalar_bw(s32)
%shl:_(<2 x s32>) = G_SHL %a:_, %amt:_(<2 x s32>)
%sub:_(<2 x s32>) = G_SUB %bw:_, %amt:_
%lshr:_(<2 x s32>) = G_LSHR %a:_, %sub:_(<2 x s32>)
%or:_(<2 x s32>) = G_OR %shl:_, %lshr:_
$vgpr4_vgpr5 = COPY %or
...
---
name: rotl_commute_i32
tracksRegLiveness: true
@ -55,6 +80,7 @@ tracksRegLiveness: true
body: |
bb.0:
liveins: $vgpr0, $vgpr1, $vgpr2
; CHECK-LABEL: name: rotr_i32
; CHECK: liveins: $vgpr0, $vgpr1, $vgpr2
; CHECK-NEXT: {{ $}}

View File

@ -533,6 +533,67 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstant) {
EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(42)));
}
TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
setUp();
if (!TM)
return;
LLT s64 = LLT::scalar(64);
LLT v4s64 = LLT::fixed_vector(4, s64);
MachineInstrBuilder FortyTwoSplat =
B.buildSplatVector(v4s64, B.buildConstant(s64, 42));
MachineInstrBuilder FortyTwo = B.buildConstant(s64, 42);
EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstSplat(42)));
EXPECT_FALSE(
mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstSplat(43)));
EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(42)));
MachineInstrBuilder NonConstantSplat =
B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
MachineInstrBuilder AddSplat =
B.buildAdd(v4s64, NonConstantSplat, FortyTwoSplat);
EXPECT_TRUE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(42)));
EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(43)));
EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(42)));
MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
EXPECT_FALSE(mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(42)));
}
TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
setUp();
if (!TM)
return;
LLT s64 = LLT::scalar(64);
LLT v4s64 = LLT::fixed_vector(4, s64);
MachineInstrBuilder FortyTwoSplat =
B.buildSplatVector(v4s64, B.buildConstant(s64, 42));
MachineInstrBuilder FortyTwo = B.buildConstant(s64, 42);
EXPECT_TRUE(
mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstOrSplat(42)));
EXPECT_FALSE(
mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstOrSplat(43)));
EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(42)));
MachineInstrBuilder NonConstantSplat =
B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
MachineInstrBuilder AddSplat =
B.buildAdd(v4s64, NonConstantSplat, FortyTwoSplat);
EXPECT_TRUE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(42)));
EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(43)));
EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(42)));
MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
EXPECT_TRUE(mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(42)));
}
TEST_F(AArch64GISelMITest, MatchZeroInt) {
setUp();
if (!TM)