From 2c4f5690ab5e435691aafe554725dbbd521b3754 Mon Sep 17 00:00:00 2001 From: Ahmed Taei Date: Tue, 22 Jun 2021 12:50:10 -0700 Subject: [PATCH] Add linalg.batch_matvec named op Similarly to batch_mat vec outer most dim is a batching dim and this op does |b| matrix-vector-products : C[b, i] = sum_k(A[b, i, k] * B[b, k]) Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D104739 --- .../Linalg/IR/LinalgNamedStructuredOps.yaml | 62 +++++++++++++++++++ .../linalg/opdsl/ops/core_named_ops.py | 15 +++++ .../Dialect/Linalg/generalize-named-ops.mlir | 25 ++++++++ 3 files changed, 102 insertions(+) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index e536b44fe6fb..8781e16bba34 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -247,6 +247,68 @@ structured_op: !LinalgStructuredOpConfig - !ScalarExpression scalar_arg: A --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: batch_matvec + cpp_class_name: BatchMatvecOp + doc: |- + Performs a batched matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + implements: + - LinalgContractionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: A + usage: InputOperand + type_var: T1 + shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)> + - !LinalgOperandDefConfig + name: B + usage: InputOperand + type_var: T2 + shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> + - !LinalgOperandDefConfig + name: C + usage: OutputOperand + type_var: U + shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1, d2)> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)> + iterator_types: + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: C + value: !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_arg: C + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: A + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: B +--- !LinalgOpConfig metadata: !LinalgOpMetadata name: dot cpp_class_name: DotOp diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 5867109279aa..561cd2e7d08d 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -66,6 +66,21 @@ def vecmat( x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n]) +@linalg_structured_op +def batch_matvec( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K), + C=TensorDef(U, Batch, S.M, output=True)): + """Performs a batched matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.k) + implements(ContractionOpInterface) + C[D.b, D.m] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k]) + + @linalg_structured_op def dot( A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)): diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir index 412309a0f743..405c7b156da6 100644 --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -490,3 +490,28 @@ func @generalize_fill(%output: memref, %value : f32) { // CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32) // CHECK-NEXT: linalg.yield %[[BBARG0]] : f32 + +// ----- + +func @generalize_batch_matm_vec(%lhs : memref, %rhs: memref, %out: memref) { + linalg.batch_matvec ins(%lhs, %rhs: memref, memref) + outs(%out: memref) + return +} +// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK: @generalize_batch_matm_vec + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref, memref) +// CHECK-SAME: outs(%{{.+}} : memref) +// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32) +// CHECK: %[[BBARG0_F32:.+]] = sitofp %[[BBARG0]] : i8 to f32 +// CHECK: %[[BBARG1_F32:.+]] = sitofp %[[BBARG1]] : i8 to f32 +// CHECK: %[[MUL:.+]] = mulf %[[BBARG0_F32]], %[[BBARG1_F32]] +// CHECK: %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] +// CHECK: linalg.yield %[[ADD]] : f32