mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-27 07:31:28 +00:00
Add linalg.batch_matvec named op
Similarly to batch_matmul, the outermost dimension is a batching dimension, and this op computes |b| matrix-vector products: C[b, i] = sum_k(A[b, i, k] * B[b, k]). Reviewed By: rsuderman. Differential Revision: https://reviews.llvm.org/D104739
This commit is contained in:
parent
03051f7ac8
commit
2c4f5690ab
@ -247,6 +247,68 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_arg: A
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
name: batch_matvec
|
||||
cpp_class_name: BatchMatvecOp
|
||||
doc: |-
|
||||
Performs a batched matrix-vector multiplication.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
implements:
|
||||
- LinalgContractionOpInterface
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
name: A
|
||||
usage: InputOperand
|
||||
type_var: T1
|
||||
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: B
|
||||
usage: InputOperand
|
||||
type_var: T2
|
||||
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: C
|
||||
usage: OutputOperand
|
||||
type_var: U
|
||||
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1, d2)>
|
||||
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
|
||||
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
|
||||
iterator_types:
|
||||
- parallel
|
||||
- parallel
|
||||
- reduction
|
||||
assignments:
|
||||
- !ScalarAssign
|
||||
arg: C
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: C
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: A
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: B
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
name: dot
|
||||
cpp_class_name: DotOp
|
||||
|
@ -66,6 +66,21 @@ def vecmat(
|
||||
x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def batch_matvec(
|
||||
A=TensorDef(T1, Batch, S.M, S.K),
|
||||
B=TensorDef(T2, Batch, S.K),
|
||||
C=TensorDef(U, Batch, S.M, output=True)):
|
||||
"""Performs a batched matrix-vector multiplication.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.b, D.m, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.b, D.m] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def dot(
|
||||
A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
|
||||
|
@ -490,3 +490,28 @@ func @generalize_fill(%output: memref<?x?xf32>, %value : f32) {
|
||||
|
||||
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
|
||||
// CHECK-NEXT: linalg.yield %[[BBARG0]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi8>, %out: memref<?x?xf32>) {
|
||||
linalg.batch_matvec ins(%lhs, %rhs: memref<?x?x?xi8>, memref<?x?xi8>)
|
||||
outs(%out: memref<?x?xf32>)
|
||||
return
|
||||
}
|
||||
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK: @generalize_batch_matm_vec
|
||||
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
|
||||
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xi8>, memref<?x?xi8>)
|
||||
// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
|
||||
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
|
||||
// CHECK: %[[BBARG0_F32:.+]] = sitofp %[[BBARG0]] : i8 to f32
|
||||
// CHECK: %[[BBARG1_F32:.+]] = sitofp %[[BBARG1]] : i8 to f32
|
||||
// CHECK: %[[MUL:.+]] = mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
|
||||
// CHECK: %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]]
|
||||
// CHECK: linalg.yield %[[ADD]] : f32
|
||||
|
Loading…
Reference in New Issue
Block a user