[mlir][scf] Retain existing attributes in scf.for transforms

These attributes can carry useful information, e.g., pipelines
might use them to organize and chain patterns.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D126320
This commit is contained in:
Lei Zhang 2022-05-25 10:45:26 -04:00
parent 5a2dbe49be
commit 413fbb045d
4 changed files with 9 additions and 5 deletions

View File

@ -660,6 +660,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newIterArgs);
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
// Replace the null placeholders with newly constructed values.
@ -802,6 +803,7 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newIterOperands);
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
newBlock.getArguments().end());

View File

@ -491,6 +491,7 @@ struct ForOpInterface
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), initArgs);
newForOp->setAttrs(forOp->getAttrs());
ValueRange initArgsRange(initArgs);
TypeRange initArgsTypes(initArgsRange);
Block *loopBody = &newForOp.getLoopBody().front();

View File

@ -30,14 +30,14 @@ func.func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) ->
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<f32>
// CHECK: %[[RESULT_MEMREF:.*]] = scf.for %[[VAL_6:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
// CHECK: scf.yield %[[ITER]] : memref<f32>
// CHECK: }
// CHECK: } {some_attr}
// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_9:.*]] : memref<f32>
// CHECK: return %[[VAL_8]] : tensor<f32>
// CHECK: }
func.func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
%ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor<f32> {
scf.yield %iter : tensor<f32>
}
} {some_attr}
return %ret : tensor<f32>
}

View File

@ -372,7 +372,7 @@ func.func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i
%r:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %a, %2 = %b) -> (i32, i32, i32) {
%c = func.call @make_i32() : () -> (i32)
scf.yield %0, %c, %2 : i32, i32, i32
}
} {some_attr}
return %r#0, %r#1, %r#2 : i32, i32, i32
}
@ -382,7 +382,7 @@ func.func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i
// CHECK-NEXT: %[[r1:.*]] = scf.for {{.*}} iter_args(%arg4 = %[[a]]) -> (i32) {
// CHECK-NEXT: %[[c:.*]] = func.call @make_i32() : () -> i32
// CHECK-NEXT: scf.yield %[[c]] : i32
// CHECK-NEXT: }
// CHECK-NEXT: } {some_attr}
// CHECK-NEXT: return %[[a]], %[[r1]], %[[b]] : i32, i32, i32
// -----
@ -846,11 +846,12 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32
// CHECK: %[[DONE:.*]] = func.call @do(%[[CAST]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
// CHECK: scf.yield %[[UNCAST]] : tensor<32x1024xf32>
// CHECK: } {some_attr}
%0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
%1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
%2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
scf.yield %2 : tensor<?x?xf32>
}
} {some_attr}
// CHECK-NOT: tensor.cast
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
// CHECK: return %[[RES]] : tensor<1024x1024xf32>