Add missing linalg.batch_vecmat named op (#70218)

Linalg currently has these named ops:
* `matmul`
* `matvec`
* `vecmat`
* `batch_matmul`
* `batch_matvec`

But it does not have:
* `batch_vecmat`

This PR adds that op for consistency. I also have a short-term need for it
(https://github.com/openxla/iree/issues/15158), so not having it would cause
some contortion on my end.
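
For reference, `batch_vecmat` contracts a batch of vectors against a batch of matrices over the shared K dimension. A minimal NumPy sketch of the intended semantics (NumPy and the shapes here are only for illustration, not part of the change):

```python
import numpy as np

# Hypothetical sizes: B (batch) = 4, K (reduction) = 8, N = 5.
vec = np.random.randn(4, 8)        # "A" operand: (B, K)
mat = np.random.randn(4, 8, 5)     # "B" operand: (B, K, N)

# linalg.batch_vecmat computes C[b, n] = sum_k vec[b, k] * mat[b, k, n]
out = np.einsum("bk,bkn->bn", vec, mat)
assert out.shape == (4, 5)
```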
Authored by bjacob on 2023-10-25 11:41:24 -04:00, committed by GitHub
parent 8e00d59dce
commit 8c8336fcad
3 changed files with 111 additions and 0 deletions


@@ -1796,6 +1796,74 @@ structured_op: !LinalgStructuredOpConfig
                - !ScalarExpression
                  scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: batch_vecmat
  cpp_class_name: BatchVecmatOp
  doc: |-
    Performs a batched vector-matrix multiplication.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
  implements:
  - LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: A
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
  - !LinalgOperandDefConfig
    name: B
    kind: input_tensor
    type_var: T2
    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
  - !LinalgOperandDefConfig
    name: C
    kind: output_tensor
    type_var: U
    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2, d1)>
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
  iterator_types:
  - parallel
  - parallel
  - reduction
  assignments:
  - !ScalarAssign
    arg: C
    value: !ScalarExpression
      scalar_fn:
        kind: binary
        fn_name: add
        operands:
        - !ScalarExpression
          scalar_arg: C
        - !ScalarExpression
          scalar_fn:
            kind: binary
            fn_name: mul
            operands:
            - !ScalarExpression
              scalar_fn:
                kind: type
                fn_name: cast_signed
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: A
            - !ScalarExpression
              scalar_fn:
                kind: type
                fn_name: cast_signed
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: dot
  cpp_class_name: DotOp

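Read together, the static indexing maps and iterator types above describe a loop nest in which d0 (batch) and d1 (N) are parallel and d2 (K) is the reduction. A plain-Python sketch of that reading (illustrative only; this is not generated code):

```python
def batch_vecmat_reference(A, B, C, cast):
    """Loop-nest reading of the maps: A -> (d0, d2), B -> (d0, d2, d1), C -> (d0, d1)."""
    batch, n = len(C), len(C[0])      # d0, d1
    k = len(A[0])                     # d2
    for d0 in range(batch):           # parallel
        for d1 in range(n):           # parallel
            for d2 in range(k):       # reduction
                # cast stands in for the cast_signed promotion to the accumulator type U
                C[d0][d1] += cast(A[d0][d2]) * cast(B[d0][d2][d1])
    return C
```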

@@ -517,6 +517,24 @@ def batch_matvec(
    )


@linalg_structured_op
def batch_vecmat(
    A=TensorDef(T1, Batch, S.K),
    B=TensorDef(T2, Batch, S.K, S.N),
    C=TensorDef(U, Batch, S.N, output=True),
):
"""Performs a batched matrix-vector multiplication.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
    domain(D.b, D.n, D.k)
    implements(ContractionOpInterface)
    C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
        U, B[D.b, D.k, D.n]
    )


@linalg_structured_op
def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
    """Performs a dot product of two vectors to a scalar result.

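The `TypeFn.cast_signed` calls promote each operand to the accumulator type before the multiply; for the i8-to-f32 case exercised by the test below, that promotion shows up as `arith.sitofp` followed by `arith.mulf` and `arith.addf`. A tiny NumPy sketch of that promotion (illustrative values only):

```python
import numpy as np

a = np.int8(-3)          # i8 element of the vector operand
b = np.int8(7)           # i8 element of the matrix operand
acc = np.float32(0.0)    # f32 accumulator element

# cast_signed to f32 (signed int -> float, i.e. sitofp), then multiply-accumulate.
acc += np.float32(a) * np.float32(b)
```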

@@ -251,6 +251,31 @@ func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi
// -----
func.func @generalize_batch_vecmat(%lhs : memref<?x?xi8>, %rhs: memref<?x?x?xi8>, %out: memref<?x?xf32>) {
  linalg.batch_vecmat ins(%lhs, %rhs: memref<?x?xi8>, memref<?x?x?xi8>)
                      outs(%out: memref<?x?xf32>)
  return
}
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK: @generalize_batch_vecmat
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?xi8>, memref<?x?x?xi8>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
// CHECK: %[[BBARG0_F32:.+]] = arith.sitofp %[[BBARG0]] : i8 to f32
// CHECK: %[[BBARG1_F32:.+]] = arith.sitofp %[[BBARG1]] : i8 to f32
// CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
// CHECK: linalg.yield %[[ADD]] : f32
// -----
func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
                             outs(%out: memref<8x8xf32>)