Fold multiply and subtraction into FMA with negation (#4808)

This change adds a folding rule which transforms x * y - a and a - x * y
into FMA(x, y, -a) and FMA(-x, y, a), respectively.

While the SPIR-V instruction count remains the same, target instruction
sets typically feature FMA instruction variants that can negate an
operand. Also this transformation may unlock further optimizations which
eliminate the negation.

(Google bug: b/226145988)
This commit is contained in:
Nicolas Capens 2022-05-31 12:03:56 -04:00 committed by GitHub
parent 82d91083cb
commit 130a05d2e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 127 additions and 2 deletions

View File

@ -1488,6 +1488,74 @@ bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
return false;
}
// Replaces |sub| inplace with an FMA instruction |(x*y)+a| where |a| first gets
// negated if |negate_addition| is true, otherwise |x| gets negated.
void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y,
uint32_t a, bool negate_addition) {
uint32_t ext =
sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
if (ext == 0) {
sub->context()->AddExtInstImport("GLSL.std.450");
ext = sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
assert(ext != 0 &&
"Could not add the GLSL.std.450 extended instruction set");
}
InstructionBuilder ir_builder(
sub->context(), sub,
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
Instruction* neg = ir_builder.AddUnaryOp(sub->type_id(), SpvOpFNegate,
negate_addition ? a : x);
uint32_t neg_op = neg->result_id(); // -a : -x
std::vector<Operand> operands;
operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? x : neg_op}});
operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? neg_op : a}});
sub->SetOpcode(SpvOpExtInst);
sub->SetInOperands(std::move(operands));
}
// Folds a multiply and subtract into an Fma and negation.
//
// Cases:
// (x * y) - a = Fma x y -a
// a - (x * y) = Fma -x y a
bool MergeMulSubArithmetic(IRContext* context, Instruction* sub,
const std::vector<const analysis::Constant*>&) {
assert(sub->opcode() == SpvOpFSub);
if (!sub->IsFloatingPointFoldingAllowed()) {
return false;
}
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
for (int i = 0; i < 2; i++) {
uint32_t op_id = sub->GetSingleWordInOperand(i);
Instruction* mul = def_use_mgr->GetDef(op_id);
if (mul->opcode() != SpvOpFMul) {
continue;
}
if (!mul->IsFloatingPointFoldingAllowed()) {
continue;
}
uint32_t x = mul->GetSingleWordInOperand(0);
uint32_t y = mul->GetSingleWordInOperand(1);
uint32_t a = sub->GetSingleWordInOperand((i + 1) % 2);
ReplaceWithFmaAndNegate(sub, x, y, a, i == 0);
return true;
}
return false;
}
FoldingRule IntMultipleBy1() {
return [](IRContext*, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
@ -2831,6 +2899,7 @@ void FoldingRules::AddFoldingRules() {
rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
rules_[SpvOpFSub].push_back(MergeMulSubArithmetic);
rules_[SpvOpIAdd].push_back(RedundantIAdd());
rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());

View File

@ -7359,7 +7359,7 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test 5: that the OpExtInstImport instruction is generated if it is missing.
// Test 4: that the OpExtInstImport instruction is generated if it is missing.
InstructionFoldingCase<bool>(
std::string() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
@ -7454,7 +7454,63 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, false)
3, false),
// Test case 7: (x * y) - a = Fma(x, y, -a)
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[na:%\\w+]] = OpFNegate {{%\\w+}} [[la]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[na]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
"%y = OpVariable %_ptr_float Function\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%lx = OpLoad %float %x\n" +
"%ly = OpLoad %float %y\n" +
"%mul = OpFMul %float %lx %ly\n" +
"%la = OpLoad %float %a\n" +
"%3 = OpFSub %float %mul %la\n" +
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 8: a - (x * y) = Fma(-x, y, a)
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[nx:%\\w+]] = OpFNegate {{%\\w+}} [[lx]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[nx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
"%y = OpVariable %_ptr_float Function\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%lx = OpLoad %float %x\n" +
"%ly = OpLoad %float %y\n" +
"%mul = OpFMul %float %lx %ly\n" +
"%la = OpLoad %float %a\n" +
"%3 = OpFSub %float %la %mul\n" +
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true)
));
using MatchingInstructionWithNoResultFoldingTest =