[mlir][vector] Support transfer op on tensor optimizations
Support store-to-load forwarding and dead-store transformations for transfer ops on tensors.

Differential Revision: https://reviews.llvm.org/D94148
This commit is contained in:
parent f9e858f5fd
commit 080943f752
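To make the change easier to follow, here is a minimal sketch of the store-to-load forwarding rewrite on tensors. It is not part of the diff; the function and value names are invented, and the shapes follow the tests further down. On tensors the transfers form an SSA use-def chain, so a transfer_read that reads back exactly what an unmasked transfer_write produced, with the same indices, vector type, and permutation map, can be replaced by the written vector. A second sketch for the dead-store rewrite follows the C++ part of the diff.

// Illustrative sketch only, not part of the commit.
func @forwarding_sketch(%t : tensor<4x4xf32>, %v : vector<1x4xf32>) -> vector<1x4xf32> {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32
  %w = vector.transfer_write %v, %t[%c0, %c0] {masked = [false, false]} :
    vector<1x4xf32>, tensor<4x4xf32>
  %r = vector.transfer_read %w[%c0, %c0], %f0 {masked = [false, false]} :
    tensor<4x4xf32>, vector<1x4xf32>
  // Same indices, vector type, and permutation map, and the write is unmasked,
  // so storeToLoadForwardingTensor replaces all uses of %r with %v and erases
  // the now-dead transfer_read.
  return %r : vector<1x4xf32>
}

Both tensor rewrites only walk the use-def chain of the tensor operand and give up as soon as a potentially overlapping transfer is found, which is what the isDisjointTransferIndices checks in the diff below guard.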
@@ -34,13 +34,33 @@ static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
   return op;
 }

+/// Return true if the transfer_write fully writes the data accessed by the
+/// transfer_read.
+static bool transferEncompasses(vector::TransferWriteOp defWrite,
+                                vector::TransferReadOp read) {
+  return !defWrite.hasMaskedDim() && defWrite.indices() == read.indices() &&
+         defWrite.getVectorType() == read.getVectorType() &&
+         defWrite.permutation_map() == read.permutation_map();
+}
+
+/// Return true if the write op fully over-writes the priorWrite transfer_write
+/// op.
+static bool transferEncompasses(vector::TransferWriteOp write,
+                                vector::TransferWriteOp priorWrite) {
+  return priorWrite.indices() == write.indices() &&
+         priorWrite.getVectorType() == write.getVectorType() &&
+         priorWrite.permutation_map() == write.permutation_map();
+}
+
 namespace {

 class TransferOptimization {
 public:
   TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
   void deadStoreOp(vector::TransferWriteOp);
+  void deadStoreOpTensor(vector::TransferWriteOp);
   void storeToLoadForwarding(vector::TransferReadOp);
+  void storeToLoadForwardingTensor(vector::TransferReadOp);
   void removeDeadOp() {
     for (Operation *op : opToErase)
       op->erase();
@@ -99,9 +119,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
       continue;
     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
       // Check candidate that can override the store.
-      if (write.indices() == nextWrite.indices() &&
-          write.getVectorType() == nextWrite.getVectorType() &&
-          write.permutation_map() == write.permutation_map() &&
+      if (transferEncompasses(nextWrite, write) &&
           postDominators.postDominates(nextWrite, write)) {
         if (firstOverwriteCandidate == nullptr ||
             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
@@ -173,10 +191,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
             cast<VectorTransferOpInterface>(write.getOperation()),
             cast<VectorTransferOpInterface>(read.getOperation())))
       continue;
-    if (dominators.dominates(write, read) && !write.hasMaskedDim() &&
-        write.indices() == read.indices() &&
-        write.getVectorType() == read.getVectorType() &&
-        write.permutation_map() == read.permutation_map()) {
+    if (dominators.dominates(write, read) &&
+        transferEncompasses(write, read)) {
       if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
         lastwrite = write;
       else
@@ -214,15 +230,62 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }

+/// Walk up the SSA links; if any write gets fully overwritten we can skip it.
+/// If it has no more uses it becomes dead.
+void TransferOptimization::deadStoreOpTensor(vector::TransferWriteOp write) {
+  auto defWrite = write.source().getDefiningOp<vector::TransferWriteOp>();
+  while (defWrite) {
+    if (transferEncompasses(write, defWrite)) {
+      write.sourceMutable().assign(defWrite.source());
+      if (defWrite->use_empty())
+        opToErase.push_back(defWrite.getOperation());
+      return;
+    }
+    if (!isDisjointTransferIndices(
+            cast<VectorTransferOpInterface>(defWrite.getOperation()),
+            cast<VectorTransferOpInterface>(write.getOperation())))
+      break;
+    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+  }
+}
+
+/// Walk up the SSA links; if any write fully matches the written vector we can
+/// replace the read by the vector. The read becomes dead and can be removed.
+void TransferOptimization::storeToLoadForwardingTensor(
+    vector::TransferReadOp read) {
+  auto defWrite = read.source().getDefiningOp<vector::TransferWriteOp>();
+  while (defWrite) {
+    if (transferEncompasses(defWrite, read)) {
+      read.replaceAllUsesWith(defWrite.vector());
+      opToErase.push_back(read.getOperation());
+      return;
+    }
+    if (!isDisjointTransferIndices(
+            cast<VectorTransferOpInterface>(defWrite.getOperation()),
+            cast<VectorTransferOpInterface>(read.getOperation())))
+      break;
+    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+  }
+}
+
 } // namespace

 void mlir::vector::transferOpflowOpt(FuncOp func) {
   TransferOptimization opt(func);
   // Run store to load forwarding first since it can expose more dead store
   // opportunity.
-  func.walk(
-      [&](vector::TransferReadOp read) { opt.storeToLoadForwarding(read); });
+  func.walk([&](vector::TransferReadOp read) {
+    if (read.getShapedType().isa<MemRefType>())
+      opt.storeToLoadForwarding(read);
+    else
+      opt.storeToLoadForwardingTensor(read);
+  });
   opt.removeDeadOp();
-  func.walk([&](vector::TransferWriteOp write) { opt.deadStoreOp(write); });
+  func.walk([&](vector::TransferWriteOp write) {
+    if (write.getShapedType().isa<MemRefType>())
+      opt.deadStoreOp(write);
+    else
+      opt.deadStoreOpTensor(write);
+  });
   opt.removeDeadOp();
 }
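And a matching sketch for the dead-store rewrite on tensors (again illustrative, with invented names, not part of the diff): a transfer_write whose tensor operand is produced by an earlier transfer_write that it fully over-writes is rebased onto that earlier write's source, and the earlier write is erased once nothing else uses it.

// Illustrative sketch only, not part of the commit.
func @dead_store_sketch(%t : tensor<4x4xf32>, %v0 : vector<1x4xf32>,
    %v1 : vector<1x4xf32>) -> tensor<4x4xf32> {
  %c0 = constant 0 : index
  %w0 = vector.transfer_write %v0, %t[%c0, %c0] {masked = [false, false]} :
    vector<1x4xf32>, tensor<4x4xf32>
  %w1 = vector.transfer_write %v1, %w0[%c0, %c0] {masked = [false, false]} :
    vector<1x4xf32>, tensor<4x4xf32>
  // %w1 fully over-writes what %w0 stored, so deadStoreOpTensor rewrites %w1
  // to take %t as its source; %w0 then has no uses left and is erased.
  return %w1 : tensor<4x4xf32>
}

The test diff below exercises exactly these two patterns on tensor<4x4xf32>.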
@@ -13,16 +13,16 @@ func @forward_dead_store(%arg0: i1, %arg1 : memref<4x4xf32>,
   %c4 = constant 4 : index
   %c0 = constant 0 : index
   %cf0 = constant 0.0 : f32
-  vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
+  vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
     vector<1x4xf32>, memref<4x4xf32>
-  %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} :
+  %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} :
     memref<4x4xf32>, vector<1x4xf32>
-  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
+  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
     -> (vector<1x4xf32>) {
     %1 = addf %acc, %acc : vector<1x4xf32>
     scf.yield %1 : vector<1x4xf32>
   }
-  vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} :
+  vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} :
     vector<1x4xf32>, memref<4x4xf32>
   return
 }
@@ -103,7 +103,7 @@ func @forward_nested_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
 // CHECK: vector.transfer_read
 // CHECK: return
 func @dead_store_region(%arg0: i1, %arg1 : memref<4x4xf32>,
-  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index)
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index)
   -> (vector<1x4xf32>) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -184,3 +184,56 @@ func @dead_store_nested_region(%arg0: i1, %arg1: i1, %arg2 : memref<4x4xf32>,
   return
 }

+// CHECK-LABEL: func @forward_dead_store_tensor
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: %[[VTW:.*]] = vector.transfer_write
+// CHECK: return %[[VTW]] : tensor<4x4xf32>
+func @forward_dead_store_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {masked = [false, false]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
+    -> (vector<1x4xf32>) {
+    %1 = addf %acc, %acc : vector<1x4xf32>
+    scf.yield %1 : vector<1x4xf32>
+  }
+  %w1 = vector.transfer_write %x, %w0[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %w1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func @forward_dead_store_negative_tensor
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: %[[VTW:.*]] = vector.transfer_write
+// CHECK: return %[[VTW]] : tensor<4x4xf32>
+func @forward_dead_store_negative_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg1[%c1, %i] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {masked = [false, false]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
+    -> (vector<1x4xf32>) {
+    %1 = addf %acc, %acc : vector<1x4xf32>
+    scf.yield %1 : vector<1x4xf32>
+  }
+  %w1 = vector.transfer_write %x, %w0[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %w1 : tensor<4x4xf32>
+}