[mlir][vector] Support transfer op optimizations on tensors

Support store-to-load forwarding and dead store elimination for transfer ops
on tensors.
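
For illustration (an assumed example, not IR taken from the patch), the tensor
variants walk the SSA chain of tensor values: a transfer_read that reads
exactly what an earlier unmasked transfer_write wrote (same indices, vector
type, and permutation map) is replaced by the written vector, and a
transfer_write whose result is fully overwritten by a later matching write and
has no other uses is removed:

  %w = vector.transfer_write %v, %t[%c0, %c0] {masked = [false, false]}
    : vector<1x4xf32>, tensor<4x4xf32>
  %r = vector.transfer_read %w[%c0, %c0], %f0 {masked = [false, false]}
    : tensor<4x4xf32>, vector<1x4xf32>
  // %r can be folded to %v; if %w then becomes unused, the write is dead.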

Differential Revision: https://reviews.llvm.org/D94148
Thomas Raoux 2021-01-06 09:34:50 -08:00
parent f9e858f5fd
commit 080943f752
2 changed files with 131 additions and 15 deletions


@@ -34,13 +34,33 @@ static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  return op;
}
/// Return true if the transfer_write fully writes the data accessed by the
/// transfer_read.
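/// For example (illustrative IR, not taken from the patch), with
///   %w = vector.transfer_write %v, %t[%c0, %c0] {masked = [false, false]}
///     : vector<1x4xf32>, tensor<4x4xf32>
/// an unmasked transfer_read of vector<1x4xf32> from %w at [%c0, %c0] with the
/// same permutation map is fully covered, so this returns true, while a read
/// at [%c0, %i] for an arbitrary %i is not.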
static bool transferEncompasses(vector::TransferWriteOp defWrite,
                                vector::TransferReadOp read) {
  return !defWrite.hasMaskedDim() && defWrite.indices() == read.indices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.permutation_map() == read.permutation_map();
}
/// Return true if the write op fully overwrites the data written by the
/// priorWrite transfer_write op.
static bool transferEncompasses(vector::TransferWriteOp write,
                                vector::TransferWriteOp priorWrite) {
  return priorWrite.indices() == write.indices() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.permutation_map() == write.permutation_map();
}
namespace {

class TransferOptimization {
public:
  TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
  void deadStoreOp(vector::TransferWriteOp);
  void deadStoreOpTensor(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void storeToLoadForwardingTensor(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
@@ -99,9 +119,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidate that can override the store.
-      if (write.indices() == nextWrite.indices() &&
-          write.getVectorType() == nextWrite.getVectorType() &&
-          write.permutation_map() == write.permutation_map() &&
+      if (transferEncompasses(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
@@ -173,10 +191,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
            cast<VectorTransferOpInterface>(write.getOperation()),
            cast<VectorTransferOpInterface>(read.getOperation())))
      continue;
-    if (dominators.dominates(write, read) && !write.hasMaskedDim() &&
-        write.indices() == read.indices() &&
-        write.getVectorType() == read.getVectorType() &&
-        write.permutation_map() == read.permutation_map()) {
+    if (dominators.dominates(write, read) &&
+        transferEncompasses(write, read)) {
      if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
        lastwrite = write;
      else
@@ -214,15 +230,62 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  opToErase.push_back(read.getOperation());
}
/// Walk up the SSA links; if any write gets fully overwritten, we can skip it.
/// If it then has no more uses, it becomes dead.
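/// Illustrative sketch (assumed IR, not from the patch):
///   %w0 = vector.transfer_write %v0, %t[%c0, %c0] : vector<1x4xf32>, tensor<4x4xf32>
///   %w1 = vector.transfer_write %v1, %w0[%c0, %c0] : vector<1x4xf32>, tensor<4x4xf32>
/// Here %w1 fully overwrites %w0, so %w1 is rewired to write into %t directly
/// and, if %w0 has no remaining uses, %w0 is erased.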
void TransferOptimization::deadStoreOpTensor(vector::TransferWriteOp write) {
  auto defWrite = write.source().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (transferEncompasses(write, defWrite)) {
      write.sourceMutable().assign(defWrite.source());
      if (defWrite->use_empty())
        opToErase.push_back(defWrite.getOperation());
      return;
    }
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(write.getOperation())))
      break;
    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
  }
}
/// Walk up the SSA links; if any write fully covers the data read, we can
/// replace the read with the written vector. The read then becomes dead and
/// can be removed.
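/// Illustrative sketch (assumed IR, unmasked transfers, not from the patch):
///   %w0 = vector.transfer_write %v0, %t[%c0, %c0] : vector<1x4xf32>, tensor<4x4xf32>
///   %w1 = vector.transfer_write %v1, %w0[%c2, %c0] : vector<1x4xf32>, tensor<4x4xf32>
///   %r = vector.transfer_read %w1[%c0, %c0], %f0 : tensor<4x4xf32>, vector<1x4xf32>
/// The write producing %w1 is disjoint from the read, so the walk continues to
/// %w0, which fully matches, and %r is replaced by %v0.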
void TransferOptimization::storeToLoadForwardingTensor(
    vector::TransferReadOp read) {
  auto defWrite = read.source().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (transferEncompasses(defWrite, read)) {
      read.replaceAllUsesWith(defWrite.vector());
      opToErase.push_back(read.getOperation());
      return;
    }
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(read.getOperation())))
      break;
    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
  }
}
} // namespace

void mlir::vector::transferOpflowOpt(FuncOp func) {
  TransferOptimization opt(func);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
-  func.walk(
-      [&](vector::TransferReadOp read) { opt.storeToLoadForwarding(read); });
+  func.walk([&](vector::TransferReadOp read) {
+    if (read.getShapedType().isa<MemRefType>())
+      opt.storeToLoadForwarding(read);
+    else
+      opt.storeToLoadForwardingTensor(read);
+  });
  opt.removeDeadOp();
-  func.walk([&](vector::TransferWriteOp write) { opt.deadStoreOp(write); });
+  func.walk([&](vector::TransferWriteOp write) {
+    if (write.getShapedType().isa<MemRefType>())
+      opt.deadStoreOp(write);
+    else
+      opt.deadStoreOpTensor(write);
+  });
  opt.removeDeadOp();
}
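
For reference, these rewrites are exercised from a lit test that runs the pass
through mlir-opt; the exact test-pass flag is assumed here rather than shown in
this diff:

  // RUN: mlir-opt %s -test-vector-transferop-opt | FileCheck %s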


@@ -13,16 +13,16 @@ func @forward_dead_store(%arg0: i1, %arg1 : memref<4x4xf32>,
  %c4 = constant 4 : index
  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32
  vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
    vector<1x4xf32>, memref<4x4xf32>
  %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} :
    memref<4x4xf32>, vector<1x4xf32>
  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
    -> (vector<1x4xf32>) {
    %1 = addf %acc, %acc : vector<1x4xf32>
    scf.yield %1 : vector<1x4xf32>
  }
  vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} :
    vector<1x4xf32>, memref<4x4xf32>
  return
}
@@ -103,7 +103,7 @@ func @forward_nested_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
// CHECK: vector.transfer_read
// CHECK: return
func @dead_store_region(%arg0: i1, %arg1 : memref<4x4xf32>,
  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index)
  -> (vector<1x4xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
@@ -184,3 +184,56 @@ func @dead_store_nested_region(%arg0: i1, %arg1: i1, %arg2 : memref<4x4xf32>,
  return
}
// CHECK-LABEL: func @forward_dead_store_tensor
// CHECK-NOT: vector.transfer_write
// CHECK-NOT: vector.transfer_read
// CHECK: scf.for
// CHECK: }
// CHECK: %[[VTW:.*]] = vector.transfer_write
// CHECK: return %[[VTW]] : tensor<4x4xf32>
func @forward_dead_store_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32
  %w0 = vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
    vector<1x4xf32>, tensor<4x4xf32>
  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {masked = [false, false]} :
    tensor<4x4xf32>, vector<1x4xf32>
  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
    -> (vector<1x4xf32>) {
    %1 = addf %acc, %acc : vector<1x4xf32>
    scf.yield %1 : vector<1x4xf32>
  }
  %w1 = vector.transfer_write %x, %w0[%c1, %c0] {masked = [false, false]} :
    vector<1x4xf32>, tensor<4x4xf32>
  return %w1 : tensor<4x4xf32>
}
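
// For reference, after the optimization the function above is expected to
// reduce to roughly the following (a sketch; the CHECK lines are the
// authoritative expectation):
//   %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %v0) -> (vector<1x4xf32>) { ... }
//   %w = vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} :
//     vector<1x4xf32>, tensor<4x4xf32>
//   return %w : tensor<4x4xf32>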
// CHECK-LABEL: func @forward_dead_store_negative_tensor
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: scf.for
// CHECK: }
// CHECK: %[[VTW:.*]] = vector.transfer_write
// CHECK: return %[[VTW]] : tensor<4x4xf32>
func @forward_dead_store_negative_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32
  %w0 = vector.transfer_write %v0, %arg1[%c1, %i] {masked = [false, false]} :
    vector<1x4xf32>, tensor<4x4xf32>
  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {masked = [false, false]} :
    tensor<4x4xf32>, vector<1x4xf32>
  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
    -> (vector<1x4xf32>) {
    %1 = addf %acc, %acc : vector<1x4xf32>
    scf.yield %1 : vector<1x4xf32>
  }
  %w1 = vector.transfer_write %x, %w0[%c1, %c0] {masked = [false, false]} :
    vector<1x4xf32>, tensor<4x4xf32>
  return %w1 : tensor<4x4xf32>
}