Add missing linalg.batch_vecmat named op (#70218)
Linalg currently has these named ops:

* `matmul`
* `matvec`
* `vecmat`
* `batch_matmul`
* `batch_matvec`

But it does not have:

* `batch_vecmat`

This PR adds that op for consistency. I also have a short-term need for it (https://github.com/openxla/iree/issues/15158), so not having it would cause some contortion on my end.
This commit is contained in:
parent 8e00d59dce
commit 8c8336fcad
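For orientation before the diff: `batch_vecmat` contracts a batch of vectors against a batch of matrices, computing `C[b, n] += A[b, k] * B[b, k, n]` with a reduction over `k`. Below is a minimal usage sketch on tensors; the shapes and the function name are illustrative, not from this patch (the patch itself exercises the memref form in its test):

```mlir
// Illustrative shapes: batch = 4, k = 8, n = 16.
func.func @batch_vecmat_example(%vec: tensor<4x8xf32>, %mat: tensor<4x8x16xf32>,
                                %acc: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // C[b, n] += A[b, k] * B[b, k, n], reducing over k.
  %0 = linalg.batch_vecmat ins(%vec, %mat : tensor<4x8xf32>, tensor<4x8x16xf32>)
                           outs(%acc : tensor<4x16xf32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}
```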
@@ -1796,6 +1796,74 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_vecmat
+  cpp_class_name: BatchVecmatOp
+  doc: |-
+    Performs a batched matrix-vector multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2, d1)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: dot
   cpp_class_name: DotOp
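The three indexing maps and iterator types above fully determine how the named op generalizes: `d0` is the batch dimension, `d1` the result dimension `n`, and `d2` the reduction dimension `k`. For reference, here is a hand-written sketch of the `linalg.generic` that `--linalg-generalize-named-ops` produces for the i8-to-f32 case, matching the CHECK lines in the test hunk further down (`cast_signed` lowers to `arith.sitofp` for this type pair):

```mlir
#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>      // A: (batch, k)
#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>  // B: (batch, k, n)
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>      // C: (batch, n)

func.func @batch_vecmat_as_generic(%A: memref<?x?xi8>, %B: memref<?x?x?xi8>,
                                   %C: memref<?x?xf32>) {
  linalg.generic {indexing_maps = [#map0, #map1, #map2],
                  iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%A, %B : memref<?x?xi8>, memref<?x?x?xi8>)
      outs(%C : memref<?x?xf32>) {
  ^bb0(%a: i8, %b: i8, %c: f32):
    // cast_signed promotes both i8 operands to the f32 accumulator type.
    %a_f32 = arith.sitofp %a : i8 to f32
    %b_f32 = arith.sitofp %b : i8 to f32
    %mul = arith.mulf %a_f32, %b_f32 : f32
    %add = arith.addf %c, %mul : f32
    linalg.yield %add : f32
  }
  return
}
```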
@@ -517,6 +517,24 @@ def batch_matvec(
     )


+@linalg_structured_op
+def batch_vecmat(
+    A=TensorDef(T1, Batch, S.K),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    C=TensorDef(U, Batch, S.N, output=True),
+):
+    """Performs a batched matrix-vector multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.b, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.k, D.n]
+    )
+
+
 @linalg_structured_op
 def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
     """Performs a dot product of two vectors to a scalar result.
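This opdsl definition is the source of truth from which the YAML entry above is generated, and it mirrors the existing `batch_matvec` with the vector operand moved to the other side of the product. A sketch contrasting the two, with illustrative shapes and function name (only the `batch_vecmat` half is added by this patch):

```mlir
// batch_matvec: (B, M, K) x (B, K)    -> (B, M)
// batch_vecmat: (B, K)    x (B, K, N) -> (B, N)
func.func @contrast(%mat: tensor<2x4x8xf32>, %vec: tensor<2x8xf32>,
                    %acc0: tensor<2x4xf32>, %rhs: tensor<2x8x16xf32>,
                    %acc1: tensor<2x16xf32>) -> (tensor<2x4xf32>, tensor<2x16xf32>) {
  %0 = linalg.batch_matvec ins(%mat, %vec : tensor<2x4x8xf32>, tensor<2x8xf32>)
                           outs(%acc0 : tensor<2x4xf32>) -> tensor<2x4xf32>
  %1 = linalg.batch_vecmat ins(%vec, %rhs : tensor<2x8xf32>, tensor<2x8x16xf32>)
                           outs(%acc1 : tensor<2x16xf32>) -> tensor<2x16xf32>
  return %0, %1 : tensor<2x4xf32>, tensor<2x16xf32>
}
```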
@@ -251,6 +251,31 @@ func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi

 // -----

+func.func @generalize_batch_vecmat(%lhs : memref<?x?xi8>, %rhs: memref<?x?x?xi8>, %out: memref<?x?xf32>) {
+  linalg.batch_vecmat ins(%lhs, %rhs: memref<?x?xi8>, memref<?x?x?xi8>)
+                      outs(%out: memref<?x?xf32>)
+  return
+}
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: @generalize_batch_vecmat
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?xi8>, memref<?x?x?xi8>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
+// CHECK: %[[BBARG0_F32:.+]] = arith.sitofp %[[BBARG0]] : i8 to f32
+// CHECK: %[[BBARG1_F32:.+]] = arith.sitofp %[[BBARG1]] : i8 to f32
+// CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
+// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
+// CHECK: linalg.yield %[[ADD]] : f32
+
+// -----
+
 func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
   linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
     outs(%out: memref<8x8xf32>)