mirror of
https://gitee.com/openharmony/third_party_spirv-tools
synced 2024-11-23 07:20:28 +00:00
Fold multiply and subtraction into FMA with negation (#4808)
This change adds a folding rule which transforms x * y - a and a - x * y into FMA(x, y, -a) and FMA(-x, y, a), respectively. While the SPIR-V instruction count remains the same, target instruction sets typically feature FMA instruction variants that can negate an operand. Also this transformation may unlock further optimizations which eliminate the negation. (Google bug: b/226145988)
This commit is contained in:
parent
82d91083cb
commit
130a05d2e3
@ -1488,6 +1488,74 @@ bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
|
||||
return false;
|
||||
}
|
||||
|
||||
// Replaces |sub| inplace with an FMA instruction |(x*y)+a| where |a| first gets
|
||||
// negated if |negate_addition| is true, otherwise |x| gets negated.
|
||||
void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y,
|
||||
uint32_t a, bool negate_addition) {
|
||||
uint32_t ext =
|
||||
sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
|
||||
|
||||
if (ext == 0) {
|
||||
sub->context()->AddExtInstImport("GLSL.std.450");
|
||||
ext = sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
|
||||
assert(ext != 0 &&
|
||||
"Could not add the GLSL.std.450 extended instruction set");
|
||||
}
|
||||
|
||||
InstructionBuilder ir_builder(
|
||||
sub->context(), sub,
|
||||
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
|
||||
|
||||
Instruction* neg = ir_builder.AddUnaryOp(sub->type_id(), SpvOpFNegate,
|
||||
negate_addition ? a : x);
|
||||
uint32_t neg_op = neg->result_id(); // -a : -x
|
||||
|
||||
std::vector<Operand> operands;
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
|
||||
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? x : neg_op}});
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? neg_op : a}});
|
||||
|
||||
sub->SetOpcode(SpvOpExtInst);
|
||||
sub->SetInOperands(std::move(operands));
|
||||
}
|
||||
|
||||
// Folds a multiply and subtract into an Fma and negation.
|
||||
//
|
||||
// Cases:
|
||||
// (x * y) - a = Fma x y -a
|
||||
// a - (x * y) = Fma -x y a
|
||||
bool MergeMulSubArithmetic(IRContext* context, Instruction* sub,
|
||||
const std::vector<const analysis::Constant*>&) {
|
||||
assert(sub->opcode() == SpvOpFSub);
|
||||
|
||||
if (!sub->IsFloatingPointFoldingAllowed()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
|
||||
for (int i = 0; i < 2; i++) {
|
||||
uint32_t op_id = sub->GetSingleWordInOperand(i);
|
||||
Instruction* mul = def_use_mgr->GetDef(op_id);
|
||||
|
||||
if (mul->opcode() != SpvOpFMul) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!mul->IsFloatingPointFoldingAllowed()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t x = mul->GetSingleWordInOperand(0);
|
||||
uint32_t y = mul->GetSingleWordInOperand(1);
|
||||
uint32_t a = sub->GetSingleWordInOperand((i + 1) % 2);
|
||||
ReplaceWithFmaAndNegate(sub, x, y, a, i == 0);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
FoldingRule IntMultipleBy1() {
|
||||
return [](IRContext*, Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>& constants) {
|
||||
@ -2831,6 +2899,7 @@ void FoldingRules::AddFoldingRules() {
|
||||
rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
|
||||
rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
|
||||
rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
|
||||
rules_[SpvOpFSub].push_back(MergeMulSubArithmetic);
|
||||
|
||||
rules_[SpvOpIAdd].push_back(RedundantIAdd());
|
||||
rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
|
||||
|
@ -7359,7 +7359,7 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test 5: that the OpExtInstImport instruction is generated if it is missing.
|
||||
// Test 4: that the OpExtInstImport instruction is generated if it is missing.
|
||||
InstructionFoldingCase<bool>(
|
||||
std::string() +
|
||||
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
|
||||
@ -7454,7 +7454,63 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
|
||||
"OpStore %a %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, false)
|
||||
3, false),
|
||||
// Test case 7: (x * y) - a = Fma(x, y, -a)
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
|
||||
"; CHECK: OpFunction\n" +
|
||||
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
|
||||
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
|
||||
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
|
||||
"; CHECK: [[na:%\\w+]] = OpFNegate {{%\\w+}} [[la]]\n" +
|
||||
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[na]]\n" +
|
||||
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%x = OpVariable %_ptr_float Function\n" +
|
||||
"%y = OpVariable %_ptr_float Function\n" +
|
||||
"%a = OpVariable %_ptr_float Function\n" +
|
||||
"%lx = OpLoad %float %x\n" +
|
||||
"%ly = OpLoad %float %y\n" +
|
||||
"%mul = OpFMul %float %lx %ly\n" +
|
||||
"%la = OpLoad %float %a\n" +
|
||||
"%3 = OpFSub %float %mul %la\n" +
|
||||
"OpStore %a %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test case 8: a - (x * y) = Fma(-x, y, a)
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
|
||||
"; CHECK: OpFunction\n" +
|
||||
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
|
||||
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
|
||||
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
|
||||
"; CHECK: [[nx:%\\w+]] = OpFNegate {{%\\w+}} [[lx]]\n" +
|
||||
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[nx]] [[ly]] [[la]]\n" +
|
||||
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%x = OpVariable %_ptr_float Function\n" +
|
||||
"%y = OpVariable %_ptr_float Function\n" +
|
||||
"%a = OpVariable %_ptr_float Function\n" +
|
||||
"%lx = OpLoad %float %x\n" +
|
||||
"%ly = OpLoad %float %y\n" +
|
||||
"%mul = OpFMul %float %lx %ly\n" +
|
||||
"%la = OpLoad %float %a\n" +
|
||||
"%3 = OpFSub %float %la %mul\n" +
|
||||
"OpStore %a %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true)
|
||||
));
|
||||
|
||||
using MatchingInstructionWithNoResultFoldingTest =
|
||||
|
Loading…
Reference in New Issue
Block a user