diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0364b3b2b96e..2e13d6ae62c4 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -829,6 +829,17 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
     return lBin.getLHS() % rhsConst.getValue();
   }
 
+  // Simplify (e % a) % b to e % b when b evenly divides a. The inner modulus
+  // `a` must be a constant that is >= 1 and a multiple of `b`; otherwise this
+  // rule does not fire (e.g. (d1 mod 35) mod 4 is kept as is).
+  if (lBin && lBin.getKind() == AffineExprKind::Mod) {
+    auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
+    if (intermediate && intermediate.getValue() >= 1 &&
+        mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
+      return lBin.getLHS() % rhsConst.getValue();
+    }
+  }
+
   return nullptr;
 }
 
diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir
index 3e3e2c3fe6f6..414741dab38f 100644
--- a/mlir/test/IR/affine-map.mlir
+++ b/mlir/test/IR/affine-map.mlir
@@ -189,6 +189,9 @@
 // CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 * 3, (d0 + d1) * 2, d0 mod 2)>
 #map58 = affine_map<(d0, d1) -> (4*d0 - 2*d0 + d0, (d0 + d1) + (d0 + d1), 2 * (d0 mod 2) - d0 mod 2)>
 
+// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 mod 5, (d1 mod 35) mod 4)>
+#map59 = affine_map<(d0, d1) -> ((d0 mod 35) mod 5, (d1 mod 35) mod 4)>
+
 // Single identity maps are removed.
 // CHECK: @f0(memref<2x4xi8, 1>)
 func private @f0(memref<2x4xi8, #map0, 1>)
@@ -373,3 +376,6 @@ func private @f56(memref<1x1xi8, #map56>)
 
 // CHECK: "f58"() {map = #map{{[0-9]+}}} : () -> ()
 "f58"() {map = #map58} : () -> ()
+
+// CHECK: "f59"() {map = #map{{[0-9]+}}} : () -> ()
+"f59"() {map = #map59} : () -> ()
diff --git a/mlir/test/Transforms/loop-fusion-2.mlir b/mlir/test/Transforms/loop-fusion-2.mlir
index c214e44296df..ccd701b3dc98 100644
--- a/mlir/test/Transforms/loop-fusion-2.mlir
+++ b/mlir/test/Transforms/loop-fusion-2.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s
 // RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal" -split-input-file | FileCheck %s --check-prefix=MAXIMAL
 
-// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir. 
+// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
 // Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir
 // Part IV of fusion tests in mlir/test/Transforms/loop-fusion-4.mlir
 
@@ -576,9 +576,9 @@ func @fuse_across_varying_dims_complex(%arg0: f32) {
 }
 // MAXIMAL-DAG: [[$MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 72 + d1) floordiv 2304)>
 // MAXIMAL-DAG: [[$MAP1:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 72 + d1) mod 2304) floordiv 1152)>
-// MAXIMAL-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) floordiv 9) floordiv 8)>
-// MAXIMAL-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) floordiv 3)>
-// MAXIMAL-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) mod 3)>
+// MAXIMAL-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> ((((d0 * 72 + d1) mod 1152) floordiv 9) floordiv 8)>
+// MAXIMAL-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> ((d1 mod 9) floordiv 3)>
+// MAXIMAL-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (d1 mod 3)>
 // MAXIMAL-DAG: [[$MAP7:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
 // MAXIMAL-DAG: [[$MAP8:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 16 - d1 + 15)>
 // MAXIMAL-LABEL: func @fuse_across_varying_dims_complex
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index 3086c682a8c8..5bef80ef07ba 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s
 
-// Part II of fusion tests in mlir/test/Transforms/loop-fusion=2.mlir. 
+// Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir.
 // Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir
 // Part IV of fusion tests in mlir/test/Transforms/loop-fusion-4.mlir
 
@@ -737,15 +737,15 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
 //
 // CHECK-DAG: [[$MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 9 + d1) floordiv 288)>
 // CHECK-DAG: [[$MAP1:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 288) floordiv 144)>
-// CHECK-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> ((((d0 * 9 + d1) mod 288) mod 144) floordiv 48)>
-// CHECK-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 9 + d1) mod 288) mod 144) mod 48) floordiv 16)>
-// CHECK-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 9 + d1) mod 288) mod 144) mod 48) mod 16)>
+// CHECK-DAG: [[$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 144) floordiv 48)>
+// CHECK-DAG: [[$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 48) floordiv 16)>
+// CHECK-DAG: [[$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 9 + d1) mod 16)>
 // CHECK-DAG: [[$MAP11:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 9 + d1)>
 // CHECK-DAG: [[$MAP12:#map[0-9]+]] = affine_map<(d0) -> (d0 floordiv 288)>
 // CHECK-DAG: [[$MAP13:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 288) floordiv 144)>
-// CHECK-DAG: [[$MAP14:#map[0-9]+]] = affine_map<(d0) -> (((d0 mod 288) mod 144) floordiv 48)>
-// CHECK-DAG: [[$MAP15:#map[0-9]+]] = affine_map<(d0) -> ((((d0 mod 288) mod 144) mod 48) floordiv 16)>
-// CHECK-DAG: [[$MAP16:#map[0-9]+]] = affine_map<(d0) -> ((((d0 mod 288) mod 144) mod 48) mod 16)>
+// CHECK-DAG: [[$MAP14:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 144) floordiv 48)>
+// CHECK-DAG: [[$MAP15:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 48) floordiv 16)>
+// CHECK-DAG: [[$MAP16:#map[0-9]+]] = affine_map<(d0) -> (d0 mod 16)>
 // CHECK-DAG: [[$MAP17:#map[0-9]+]] = affine_map<(d0) -> (0)>
 //
 
@@ -761,7 +761,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
 // CHECK-NEXT: affine.apply [[$MAP3]](%{{.*}}, %{{.*}})
 // CHECK-NEXT: affine.apply [[$MAP4]](%{{.*}}, %{{.*}})
 // CHECK-NEXT: "foo"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (index, index, index, index, index, index) -> i32
-// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, (((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) floordiv 48, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) floordiv 16, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) mod 16, 0] : memref<1x2x3x3x16x1xi32>
+// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, ((%{{.*}} * 9 + %{{.*}}) mod 144) floordiv 48, ((%{{.*}} * 9 + %{{.*}}) mod 48) floordiv 16, (%{{.*}} * 9 + %{{.*}}) mod 16, 0] : memref<1x2x3x3x16x1xi32>
 // CHECK-NEXT: affine.apply [[$MAP11]](%{{.*}}, %{{.*}})
 // CHECK-NEXT: affine.apply [[$MAP12]](%{{.*}})
 // CHECK-NEXT: affine.apply [[$MAP13]](%{{.*}})
@@ -769,7 +769,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
 // CHECK-NEXT: affine.apply [[$MAP15]](%{{.*}})
 // CHECK-NEXT: affine.apply [[$MAP16]](%{{.*}})
 // CHECK-NEXT: affine.apply [[$MAP17]](%{{.*}})
-// CHECK-NEXT: affine.load %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, (((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) floordiv 48, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) floordiv 16, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) mod 16, 0] : memref<1x2x3x3x16x1xi32>
+// CHECK-NEXT: affine.load %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, ((%{{.*}} * 9 + %{{.*}}) mod 144) floordiv 48, ((%{{.*}} * 9 + %{{.*}}) mod 48) floordiv 16, (%{{.*}} * 9 + %{{.*}}) mod 16, 0] : memref<1x2x3x3x16x1xi32>
 // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xi32>
 // CHECK-NEXT: affine.load %{{.*}}[0, 0] : memref<1x1xi32>
 // CHECK-NEXT: muli %{{.*}}, %{{.*}} : i32