[mlir][Vector] Support 0-D vectors in BroadcastOp

This changes the op's result type to `AnyVectorOfAnyRank`, so that it can also produce 0-D vectors; the lowering mostly follows the existing code for 1-D vectors.

Depends On D114598

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114550
This commit is contained in:
Michal Terepeta 2021-11-26 17:17:13 +00:00 committed by Nicolas Vasilache
parent d0f927121e
commit 7e65fc9a60
4 changed files with 89 additions and 16 deletions

View File

@ -302,7 +302,7 @@ def Vector_MultiDimReductionOp :
Results<(outs AnyType:$dest)> {
let summary = "Multi-dimensional reduction operation";
let description = [{
Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n)
Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n)
using the given operation (add/mul/min/max for int/fp and and/or/xor for
int only).
@ -380,7 +380,7 @@ def Vector_BroadcastOp :
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(ins AnyType:$source)>,
Results<(outs AnyVector:$vector)> {
Results<(outs AnyVectorOfAnyRank:$vector)> {
let summary = "broadcast operation";
let description = [{
Broadcasts the scalar or k-D vector value in the source operand

View File

@ -546,10 +546,27 @@ public:
VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
Type eltType = dstType.getElementType();
// Scalar to any vector can use splat.
if (!srcType) {
rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
return success();
}
// Determine rank of source and destination.
int64_t srcRank = srcType ? srcType.getRank() : 0;
int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
if (srcRank <= 1 && dstRank == 1) {
Value ext;
if (srcRank == 0)
ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
else
ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
return success();
}
// Duplicate this rank.
// For example:
// %x = broadcast %y : k-D to n-D, k < n
@ -560,11 +577,6 @@ public:
// %b = [%y,%y] : (n-1)-D
// %x = [%b,%b,%b,%b] : n-D
if (srcRank < dstRank) {
// Scalar to any vector can use splat.
if (srcRank == 0) {
rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
return success();
}
// Duplication.
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType);
@ -593,14 +605,6 @@ public:
return success();
}
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
if (srcRank == 1) {
assert(m == 0);
Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
return success();
}
// Any non-matching dimension forces a stretch along this rank.
// For example:
// %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>

View File

@ -35,6 +35,27 @@ func @bitcast_index_to_i8_vector(%input: vector<16xindex>) -> vector<128xi8> {
// -----
// Test input: broadcasts an f32 scalar into a 0-D vector (vector<f32>).
// The expectations that follow this function look for a single splat of the
// argument.
func @broadcast_vec0d_from_f32(%arg0: f32) -> vector<f32> {
%0 = vector.broadcast %arg0 : f32 to vector<f32>
return %0 : vector<f32>
}
// CHECK-LABEL: @broadcast_vec0d_from_f32
// CHECK-SAME: %[[A:.*]]: f32)
// CHECK: %[[T0:.*]] = splat %[[A]] : vector<f32>
// CHECK: return %[[T0]] : vector<f32>
// -----
// Test input: broadcasts a 0-D vector to the identical 0-D vector type
// (source and result are both vector<f32>). The expectations that follow this
// function look for the argument being returned directly.
func @broadcast_vec0d_from_vec0d(%arg0: vector<f32>) -> vector<f32> {
%0 = vector.broadcast %arg0 : vector<f32> to vector<f32>
return %0 : vector<f32>
}
// CHECK-LABEL: @broadcast_vec0d_from_vec0d(
// CHECK-SAME: %[[A:.*]]: vector<f32>)
// CHECK: return %[[A]] : vector<f32>
// -----
func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<2xf32>
return %0 : vector<2xf32>
@ -89,6 +110,26 @@ func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
// -----
// Test input: broadcasts a 0-D vector (vector<f32>) into a 2-D vector of shape
// 3x2. The expectations that follow this function look for one element
// extraction, a splat to vector<2xf32>, and three insertions into an
// !llvm.array<3 x vector<2xf32>>.
func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
%0 = vector.broadcast %arg0 : vector<f32> to vector<3x2xf32>
return %0 : vector<3x2xf32>
}
// CHECK-LABEL: @broadcast_vec2d_from_vec0d(
// CHECK-SAME: %[[A:.*]]: vector<f32>)
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<2xf32>
// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][1] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T6]], %[[T8]][2] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<3 x vector<2xf32>> to vector<3x2xf32>
// CHECK: return %[[T10]] : vector<3x2xf32>
// -----
func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32>
return %0 : vector<3x2xf32>

View File

@ -28,6 +28,33 @@ func @splat_0d(%a: f32) {
return
}
// Exercises vector.broadcast with 0-D vectors: first builds a 0-D vector from
// the scalar argument, then broadcasts that 0-D vector to 0-D, 1-D and 2-D
// result types, printing each result. The expected printed values all use 42,
// which matches the constant passed in by the caller visible in this file.
func @broadcast_0d(%a: f32) {
// Scalar f32 -> 0-D vector.
%1 = vector.broadcast %a : f32 to vector<f32>
// CHECK: ( 42 )
vector.print %1: vector<f32>
// 0-D vector -> 0-D vector (same type on both sides).
%2 = vector.broadcast %1 : vector<f32> to vector<f32>
// CHECK: ( 42 )
vector.print %2: vector<f32>
// 0-D vector -> 1-D vector with a single element.
%3 = vector.broadcast %1 : vector<f32> to vector<1xf32>
// CHECK: ( 42 )
vector.print %3: vector<1xf32>
// 0-D vector -> 1-D vector with two elements.
%4 = vector.broadcast %1 : vector<f32> to vector<2xf32>
// CHECK: ( 42, 42 )
vector.print %4: vector<2xf32>
// 0-D vector -> 2-D vector whose innermost dimension has size 1.
%5 = vector.broadcast %1 : vector<f32> to vector<2x1xf32>
// CHECK: ( ( 42 ), ( 42 ) )
vector.print %5: vector<2x1xf32>
// 0-D vector -> 2-D vector of shape 2x3.
%6 = vector.broadcast %1 : vector<f32> to vector<2x3xf32>
// CHECK: ( ( 42, 42, 42 ), ( 42, 42, 42 ) )
vector.print %6: vector<2x3xf32>
return
}
func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
@ -39,6 +66,7 @@ func @entry() {
%4 = arith.constant 42.0 : f32
call @splat_0d(%4) : (f32) -> ()
call @broadcast_0d(%4) : (f32) -> ()
return
}