[MLIR][Affine] Simplify nested modulo operations when able

It is the case that, for all positive integers a and b such that b
divides a, (e mod a) mod b = e mod b. For example, ((d0 mod 35) mod 5) can
be simplified to (d0 mod 5), but ((d0 mod 35) mod 4) cannot be simplified
further (d0 = 36 is a counterexample: (36 mod 35) mod 4 = 1, while
36 mod 4 = 0).

This change enables more complex simplifications. For example,
((d0 * 72 + d1) mod 144) mod 9 can now simplify to (d0 * 72 + d1) mod 9
(because 9 divides 144) and thus to d1 mod 9 (because 9 also divides 72).
Expressions with chained modulus operators are reasonably common in
tensor applications, and this change _should_ improve code generation
for such expressions.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D109930
This commit is contained in:
Krzysztof Drewniak 2021-09-16 21:25:20 +00:00
parent 08f0cb7719
commit 121aab84d1
4 changed files with 28 additions and 13 deletions

View File

@ -829,6 +829,15 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
return lBin.getLHS() % rhsConst.getValue();
}
// Simplify (e % a) % b to e % b when b evenly divides a
if (lBin && lBin.getKind() == AffineExprKind::Mod) {
auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
if (intermediate && intermediate.getValue() >= 1 &&
mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
return lBin.getLHS() % rhsConst.getValue();
}
}
return nullptr;
}

View File

@ -189,6 +189,9 @@
// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 * 3, (d0 + d1) * 2, d0 mod 2)>
#map58 = affine_map<(d0, d1) -> (4*d0 - 2*d0 + d0, (d0 + d1) + (d0 + d1), 2 * (d0 mod 2) - d0 mod 2)>
// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 mod 5, (d1 mod 35) mod 4)>
#map59 = affine_map<(d0, d1) -> ((d0 mod 35) mod 5, (d1 mod 35) mod 4)>
// Single identity maps are removed.
// CHECK: @f0(memref<2x4xi8, 1>)
func private @f0(memref<2x4xi8, #map0, 1>)
@ -373,3 +376,6 @@ func private @f56(memref<1x1xi8, #map56>)
// CHECK: "f58"() {map = #map{{[0-9]+}}} : () -> ()
"f58"() {map = #map58} : () -> ()
// CHECK: "f59"() {map = #map{{[0-9]+}}} : () -> ()
"f59"() {map = #map59} : () -> ()

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal" -split-input-file | FileCheck %s --check-prefix=MAXIMAL
// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
// Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir
// Part IV of fusion tests in mlir/test/Transforms/loop-fusion-4.mlir
@ -576,9 +576,9 @@ func @fuse_across_varying_dims_complex(%arg0: f32) {
}
// MAXIMAL-DAG: [[$MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 72 + d1) floordiv 2304)>
// MAXIMAL-DAG: [[$MAP1:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 72 + d1) mod 2304) floordiv 1152)>
// MAXIMAL-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) floordiv 9) floordiv 8)>
// MAXIMAL-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) floordiv 3)>
// MAXIMAL-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) mod 3)>
// MAXIMAL-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> ((((d0 * 72 + d1) mod 1152) floordiv 9) floordiv 8)>
// MAXIMAL-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> ((d1 mod 9) floordiv 3)>
// MAXIMAL-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (d1 mod 3)>
// MAXIMAL-DAG: [[$MAP7:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
// MAXIMAL-DAG: [[$MAP8:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 16 - d1 + 15)>
// MAXIMAL-LABEL: func @fuse_across_varying_dims_complex

View File

@ -1,6 +1,6 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s
// Part II of fusion tests in mlir/test/Transforms/loop-fusion=2.mlir.
// Part II of fusion tests in mlir/test/Transforms/loop-fusion=2.mlir.
// Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir
// Part IV of fusion tests in mlir/test/Transforms/loop-fusion-4.mlir
@ -737,15 +737,15 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
//
// CHECK-DAG: [[$MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 9 + d1) floordiv 288)>
// CHECK-DAG: [[$MAP1:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 288) floordiv 144)>
// CHECK-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> ((((d0 * 9 + d1) mod 288) mod 144) floordiv 48)>
// CHECK-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 9 + d1) mod 288) mod 144) mod 48) floordiv 16)>
// CHECK-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 9 + d1) mod 288) mod 144) mod 48) mod 16)>
// CHECK-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 144) floordiv 48)>
// CHECK-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 48) floordiv 16)>
// CHECK-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 9 + d1) mod 16)>
// CHECK-DAG: [[$MAP11:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 9 + d1)>
// CHECK-DAG: [[$MAP12:#map[0-9]+]] = affine_map<(d0) -> (d0 floordiv 288)>
// CHECK-DAG: [[$MAP13:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 288) floordiv 144)>
// CHECK-DAG: [[$MAP14:#map[0-9]+]] = affine_map<(d0) -> (((d0 mod 288) mod 144) floordiv 48)>
// CHECK-DAG: [[$MAP15:#map[0-9]+]] = affine_map<(d0) -> ((((d0 mod 288) mod 144) mod 48) floordiv 16)>
// CHECK-DAG: [[$MAP16:#map[0-9]+]] = affine_map<(d0) -> ((((d0 mod 288) mod 144) mod 48) mod 16)>
// CHECK-DAG: [[$MAP14:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 144) floordiv 48)>
// CHECK-DAG: [[$MAP15:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 48) floordiv 16)>
// CHECK-DAG: [[$MAP16:#map[0-9]+]] = affine_map<(d0) -> (d0 mod 16)>
// CHECK-DAG: [[$MAP17:#map[0-9]+]] = affine_map<(d0) -> (0)>
//
@ -761,7 +761,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
// CHECK-NEXT: affine.apply [[$MAP3]](%{{.*}}, %{{.*}})
// CHECK-NEXT: affine.apply [[$MAP4]](%{{.*}}, %{{.*}})
// CHECK-NEXT: "foo"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (index, index, index, index, index, index) -> i32
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, (((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) floordiv 48, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) floordiv 16, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) mod 16, 0] : memref<1x2x3x3x16x1xi32>
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, ((%{{.*}} * 9 + %{{.*}}) mod 144) floordiv 48, ((%{{.*}} * 9 + %{{.*}}) mod 48) floordiv 16, (%{{.*}} * 9 + %{{.*}}) mod 16, 0] : memref<1x2x3x3x16x1xi32>
// CHECK-NEXT: affine.apply [[$MAP11]](%{{.*}}, %{{.*}})
// CHECK-NEXT: affine.apply [[$MAP12]](%{{.*}})
// CHECK-NEXT: affine.apply [[$MAP13]](%{{.*}})
@ -769,7 +769,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
// CHECK-NEXT: affine.apply [[$MAP15]](%{{.*}})
// CHECK-NEXT: affine.apply [[$MAP16]](%{{.*}})
// CHECK-NEXT: affine.apply [[$MAP17]](%{{.*}})
// CHECK-NEXT: affine.load %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, (((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) floordiv 48, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) floordiv 16, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) mod 16, 0] : memref<1x2x3x3x16x1xi32>
// CHECK-NEXT: affine.load %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, ((%{{.*}} * 9 + %{{.*}}) mod 144) floordiv 48, ((%{{.*}} * 9 + %{{.*}}) mod 48) floordiv 16, (%{{.*}} * 9 + %{{.*}}) mod 16, 0] : memref<1x2x3x3x16x1xi32>
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xi32>
// CHECK-NEXT: affine.load %{{.*}}[0, 0] : memref<1x1xi32>
// CHECK-NEXT: muli %{{.*}}, %{{.*}} : i32