mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-05 15:42:24 +00:00
[flang][hlfir] allow recursive intrinsic lowering
We need to allow recursive application of the intrinsic lowering patterns; otherwise we cannot lower nested calls of the same intrinsic, e.g. matmul(matmul(a, b), c). Note that matmul(matmul(a, b), matmul(c, d)) additionally requires hlfir.associate of an hlfir.expr with more than one use (TODO). Differential Revision: https://reviews.llvm.org/D152284
This commit is contained in:
parent
631c965483
commit
7c8ef818f8
@ -1,4 +1,4 @@
|
||||
//===- LowerHLFIRIntrinsics.cpp - Bufferize HLFIR ------------------------===//
|
||||
//===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
@ -37,7 +37,23 @@ namespace {
|
||||
/// runtime calls
|
||||
template <class OP>
|
||||
class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
|
||||
using mlir::OpRewritePattern<OP>::OpRewritePattern;
|
||||
public:
|
||||
/// Creates the pattern for \p ctx and opts in to bounded recursive rewriting.
explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
    : mlir::OpRewritePattern<OP>(ctx) {
  // Intrinsics can be chained together, e.g.
  //   matmul(matmul(a, b), c)
  // Converting the inner operation invalidates the outer operation, which
  // causes this pattern to be applied recursively, so recursion has to be
  // allowed explicitly.
  //
  // Doing so is safe: every iteration makes progress, and circular
  // applications of operations cannot be expressed in MLIR because SSA form
  // requires one of the operations to come first. For example
  //   %a = hlfir.matmul %b %d
  //   %b = hlfir.matmul %a %d
  // cannot be written.
  //
  // The explicit this-> is needed by MSVC.
  this->setHasBoundedRewriteRecursion(true);
}
|
||||
|
||||
protected:
|
||||
struct IntrinsicArgument {
|
||||
|
@ -43,3 +43,39 @@ func.func @_QPmatmul1(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lh
|
||||
// CHECK: hlfir.destroy %[[ASEXPR]]
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// nested matmuls leading to recursive pattern application
|
||||
// Test input computing out = matmul(matmul(a, b), c) on 3x3 f32 arrays.
func.func @_QPtest(%arg0: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "a"}, %arg1: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "b"}, %arg2: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "c"}, %arg3: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "out"}) {
|
||||
%c3 = arith.constant 3 : index
|
||||
%c3_0 = arith.constant 3 : index
|
||||
%0 = fir.shape %c3, %c3_0 : (index, index) -> !fir.shape<2>
|
||||
%1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFtestEa"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
|
||||
%c3_1 = arith.constant 3 : index
|
||||
%c3_2 = arith.constant 3 : index
|
||||
%2 = fir.shape %c3_1, %c3_2 : (index, index) -> !fir.shape<2>
|
||||
%3:2 = hlfir.declare %arg1(%2) {uniq_name = "_QFtestEb"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
|
||||
%c3_3 = arith.constant 3 : index
|
||||
%c3_4 = arith.constant 3 : index
|
||||
%4 = fir.shape %c3_3, %c3_4 : (index, index) -> !fir.shape<2>
|
||||
%5:2 = hlfir.declare %arg2(%4) {uniq_name = "_QFtestEc"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
|
||||
%c3_5 = arith.constant 3 : index
|
||||
%c3_6 = arith.constant 3 : index
|
||||
%6 = fir.shape %c3_5, %c3_6 : (index, index) -> !fir.shape<2>
|
||||
%7:2 = hlfir.declare %arg3(%6) {uniq_name = "_QFtestEout"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
|
||||
// Inner matmul of a and b; its result is an hlfir.expr value.
%8 = hlfir.matmul %1#0 %3#0 {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>) -> !hlfir.expr<3x3xf32>
|
||||
// Outer matmul: takes the inner matmul result %8 as its first operand and c as its second.
%9 = hlfir.matmul %8 %5#0 {fastmath = #arith.fastmath<contract>} : (!hlfir.expr<3x3xf32>, !fir.ref<!fir.array<3x3xf32>>) -> !hlfir.expr<3x3xf32>
|
||||
hlfir.assign %9 to %7#0 : !hlfir.expr<3x3xf32>, !fir.ref<!fir.array<3x3xf32>>
|
||||
// Both expression values are destroyed after the assignment to "out".
hlfir.destroy %9 : !hlfir.expr<3x3xf32>
|
||||
hlfir.destroy %8 : !hlfir.expr<3x3xf32>
|
||||
return
|
||||
}
|
||||
// just check that we apply the patterns successfully. The details are checked above
|
||||
// CHECK-LABEL: func.func @_QPtest(
|
||||
// CHECK-SAME: %arg0: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "a"},
|
||||
// CHECK-SAME: %arg1: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "b"},
|
||||
// CHECK-SAME: %arg2: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "c"},
|
||||
// CHECK-SAME: %arg3: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "out"}) {
|
||||
// CHECK: fir.call @_FortranAMatmul(
|
||||
// CHECK: fir.call @_FortranAMatmul(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
|
||||
// CHECK: return
|
||||
// CHECK-NEXT: }
|
||||
|
Loading…
Reference in New Issue
Block a user