[mlir][linalg] Set explicit insertion point in pad_tensor patterns.

Insert ops replacing pad_tensor in front of the associated tansfer_write / insert_slice op. Otherwise we may end up with invalid ir if one of the remaining tansfer_write / insert_slice operands is defined after the pad_tensor op.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D106162
This commit is contained in:
Tobias Gysi 2021-07-19 08:16:28 +00:00
parent 73e4b5cfa8
commit 3f8f292330
2 changed files with 34 additions and 16 deletions

View File

@ -870,6 +870,9 @@ struct PadTensorOpVectorizationWithTransferWritePattern
// trimPadding must remove the amount of padding that was added earlier.
if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure();
// Insert the new TransferWriteOp at position of the old TransferWriteOp.
rewriter.setInsertionPoint(xferOp);
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(),
@ -1014,6 +1017,10 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
}))
return failure();
// Insert the TransferReadOp and TransferWriteOp at the position of the
// InsertSliceOp.
rewriter.setInsertionPoint(insertOp);
// Generate TransferReadOp: Read entire source tensor and add high padding.
SmallVector<Value> readIndices(
vecRank, rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));

View File

@ -621,38 +621,44 @@ func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
// -----
func private @make_vector() -> vector<7x9xf32>
// CHECK-LABEL: func @pad_and_transfer_write_static
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: vector<7x9xf32>
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: linalg.pad_tensor
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
// CHECK: return %[[RESULT]]
func @pad_and_transfer_write_static(
%arg0: tensor<5x6xf32>, %arg1: vector<7x9xf32>) -> tensor<5x6xf32> {
%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
%c0 = constant 0 : index
%c5 = constant 5.0 : f32
%0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
^bb0(%arg2: index, %arg3: index):
linalg.yield %c5 : f32
} : tensor<5x6xf32> to tensor<10x13xf32>
%1 = vector.transfer_write %arg1, %0[%c0, %c0]
%1 = call @make_vector() : () -> vector<7x9xf32>
%2 = vector.transfer_write %1, %0[%c0, %c0]
: vector<7x9xf32>, tensor<10x13xf32>
%2 = tensor.extract_slice %1[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
return %2 : tensor<5x6xf32>
%3 = tensor.extract_slice %2[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
return %3 : tensor<5x6xf32>
}
// -----
func private @make_vector() -> vector<7x9xf32>
// CHECK-LABEL: func @pad_and_transfer_write_dynamic_static
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: vector<7x9xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
// CHECK-NOT: linalg.pad_tensor
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[SUB:.*]] = tensor.extract_slice %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
// CHECK: return %[[RESULT]]
func @pad_and_transfer_write_dynamic_static(
%arg0: tensor<?x?xf32>, %arg1: vector<7x9xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
%arg0: tensor<?x?xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
%c0 = constant 0 : index
%c5 = constant 5.0 : f32
%s = tensor.extract_slice %arg0[0, 0] [%size, 6] [1, 1]
@ -661,31 +667,36 @@ func @pad_and_transfer_write_dynamic_static(
^bb0(%arg2: index, %arg3: index):
linalg.yield %c5 : f32
} : tensor<?x6xf32> to tensor<?x13xf32>
%1 = vector.transfer_write %arg1, %0[%c0, %c0]
%1 = call @make_vector() : () -> vector<7x9xf32>
%2 = vector.transfer_write %1, %0[%c0, %c0]
: vector<7x9xf32>, tensor<?x13xf32>
%2 = tensor.extract_slice %1[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
return %2 : tensor<?x6xf32>
%3 = tensor.extract_slice %2[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
return %3 : tensor<?x6xf32>
}
// -----
func private @make_vector() -> tensor<12x13xf32>
// CHECK-LABEL: func @pad_and_insert_slice
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: tensor<12x13xf32>
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: linalg.pad_tensor
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C5:.*]] = constant 5.0
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> tensor<12x13xf32>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[VEC0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
// CHECK: return %[[WRITE]]
func @pad_and_insert_slice(
%arg0: tensor<5x6xf32>, %arg1: tensor<12x13xf32>) -> tensor<12x13xf32> {
%arg0: tensor<5x6xf32>) -> tensor<12x13xf32> {
%c0 = constant 0 : index
%c5 = constant 5.0 : f32
%0 = linalg.pad_tensor %arg0 low[0, 0] high[2, 3] {
^bb0(%arg2: index, %arg3: index):
linalg.yield %c5 : f32
} : tensor<5x6xf32> to tensor<7x9xf32>
%r = tensor.insert_slice %0 into %arg1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
%1 = call @make_vector() : () -> tensor<12x13xf32>
%r = tensor.insert_slice %0 into %1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
return %r : tensor<12x13xf32>
}