[mlir][vector] Fix vector.broadcast lowering for scalable vectors (#66344)

This patch makes sure that the following case is lowered correctly
("duplication"):
```
  func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
    %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
    return %res : vector<1x[32]xf32>
  }
```
This commit is contained in:
Andrzej Warzyński 2023-09-15 16:35:47 +01:00 committed by GitHub
parent cadabb58f1
commit 57cf6896cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 2 deletions

View File

@ -84,8 +84,7 @@ public:
// %x = [%b,%b,%b,%b] : n-D
if (srcRank < dstRank) {
// Duplication.
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType);
VectorType resType = VectorType::Builder(dstType).dropDim(0);
Value bcst =
rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
Value result = rewriter.create<arith::ConstantOp>(

View File

@ -162,6 +162,17 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
return %0 : vector<4x3x2xf32>
}
// CHECK-LABEL: func.func @broadcast_scalable_duplication
// CHECK-SAME: %[[ARG0:.*]]: vector<[32]xf32>)
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x[32]xf32>
// CHECK: %[[RES:.*]] = vector.insert %[[ARG0]], %[[CST]] [0] : vector<[32]xf32> into vector<1x[32]xf32>
// CHECK: return %[[RES]] : vector<1x[32]xf32>
func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
%res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
return %res : vector<1x[32]xf32>
}
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
%f = transform.structured.match ops{["func.func"]} in %module_op