[MLIR][Linalg] Named op 'add' element-wise
This adds the first strict element-wise named op to Linalg. The semantics here disallow auto-casting and broadcasting, and restrict the operation to identical types. Any remaining semantics must come in the form of surrounding operations on the operands, to avoid ambiguity. Examples:

```
// Cast int-to-fp
%0 = linalg.copy ins(%in: tensor<32x32xi32>) outs(%out: tensor<32x32xf32>)
%1 = linalg.add ins(%arg, %0: tensor<32x32xf32>, tensor<32x32xf32>)
                outs(%0: tensor<32x32xf32>)

// This can be lowered to
%1 = linalg.generic {...}
     ins(%arg, %in: tensor<32x32xf32>, tensor<32x32xi32>)
     outs(%0: tensor<32x32xf32>) {
  ^bb0(%a: f32, %i: i32, %out: f32):
    %f = arith.uitofp %i : i32 to f32
    %2 = arith.addf %a, %f : f32
    linalg.yield %2 : f32
}

// Broadcast
%0 = linalg.broadcast ins(%in: tensor<32xf32>) init(%out: tensor<32x32xf32>)
%1 = linalg.add ins(%arg, %0: tensor<32x32xf32>, tensor<32x32xf32>)
                outs(%0: tensor<32x32xf32>)

// This can be lowered to
#bcast_map = affine_map<(d0, d1) -> (d0)>
%1 = linalg.generic {... indexing_maps = [..., #bcast_map, ...]}
     ins(%arg, %in: tensor<32x32xf32>, tensor<32xf32>)
     outs(%0: tensor<32x32xf32>) {
  ^bb0(%a: f32, %b: f32, %out: f32):
    %2 = arith.addf %a, %b : f32
    linalg.yield %2 : f32
}
```

Once this gets accepted, other arithmetic and math operations will be added accordingly, with the same semantics.

Differential Revision: https://reviews.llvm.org/D154500
parent 85128d8b6a
commit 7e486d5c2d
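For reviewers who want to exercise the new op end to end, here is a minimal usage sketch from the MLIR Python bindings. It is illustrative only and assumes an `mlir` Python package built from a tree containing this patch; `linalg.add` is the wrapper generated from the opdsl definition in this commit, and the function name `strict_add` is just an example.

```python
# Illustrative sketch only: assumes MLIR Python bindings built with this patch.
from mlir import ir
from mlir.dialects import func, linalg, tensor

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    f32 = ir.F32Type.get()
    ty = ir.RankedTensorType.get((32, 32), f32)
    with ir.InsertionPoint(module.body):

        @func.FuncOp.from_py_func(ty, ty)
        def strict_add(lhs, rhs):
            # Strict semantics: inputs and init share one type; any casts or
            # broadcasts must be emitted as separate ops beforehand.
            init = tensor.EmptyOp([32, 32], f32)
            return linalg.add(lhs, rhs, outs=[init.result])

    print(module)
```

Running this should print a module containing a single `linalg.add` on `tensor<32x32xf32>` values.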
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

@@ -156,6 +156,55 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: add
+  cpp_class_name: AddOp
+  doc: |-
+    Adds two tensors elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.add` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: out
+    kind: output_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: out
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp
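Worth keeping in mind while reviewing this hunk: the YAML file is generated, not hand-edited. It is produced from the opdsl definitions in the next hunk. A sketch of the regeneration step, assuming a source checkout with the MLIR Python sources importable (the `dump_oplib` module path follows the upstream opdsl docs):

```python
# Sketch: regenerate LinalgNamedStructuredOps.yaml from the opdsl definitions.
# Assumes the MLIR Python bindings/sources are on PYTHONPATH.
import subprocess

yaml_text = subprocess.run(
    ["python", "-m", "mlir.dialects.linalg.opdsl.dump_oplib",
     ".ops.core_named_ops"],
    check=True,
    capture_output=True,
    text=True,
).stdout
print(yaml_text[:200])  # a stream of !LinalgOpConfig documents, as above
```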
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

@@ -51,6 +51,25 @@ def elemwise_binary(
     O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
 
 
+@linalg_structured_op
+def add(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Adds two tensors elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.add` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = lhs[None] + rhs[None]
+
+
 @linalg_structured_op
 def matmul(
     A=TensorDef(T1, S.M, S.K),
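As the commit message says, more arithmetic and math ops are meant to follow with the same strict semantics. For illustration only (it is not part of this patch), a hypothetical follow-up would mirror `add` nearly verbatim; opdsl's `-` on tensor uses maps to `BinaryFn.sub` just as `+` maps to `BinaryFn.add`:

```python
# Hypothetical follow-up, NOT part of this patch: an elementwise `sub` with
# the same strict contract. Uses the same opdsl symbols already imported in
# core_named_ops.py (linalg_structured_op, TensorDef, T1).
@linalg_structured_op
def sub(
    lhs=TensorDef(T1),
    rhs=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Subtracts two tensors elementwise.

    The shapes and element types must be identical, exactly as for `add`.
    """
    O[None] = lhs[None] - rhs[None]
```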
mlir/test/Dialect/Linalg/generalize-named-ops.mlir

@@ -286,3 +286,28 @@ func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
   }
   return
 }
+
+// -----
+
+func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+                          %out: memref<7x14x21xf32>) {
+  linalg.add ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+             outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_add
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT: linalg.yield %[[SUM]] : f32
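The test above pins down the generalization contract: `linalg.add` becomes a `linalg.generic` with identity indexing maps and all-parallel iterators. A sketch of driving that pass over the module from the earlier sketch, again assuming bindings from this revision (the pass name `linalg-generalize-named-ops` is upstream; the exact `PassManager.run` signature has shifted across releases):

```python
# Sketch: generalize named ops from Python (`module` from the earlier sketch).
from mlir.passmanager import PassManager

pm = PassManager.parse("builtin.module(func.func(linalg-generalize-named-ops))")
pm.run(module.operation)  # linalg.add -> linalg.generic with identity maps
print(module)
```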
mlir/test/Dialect/Linalg/named-ops-fail.mlir (new file, 16 lines)
@@ -0,0 +1,16 @@
+// RUN: not mlir-opt -split-input-file -verify-diagnostics %s 2>&1 | FileCheck %s
+
+func.func @add_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op requires the same type for all operands and results
+  linalg.add ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @add_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.add ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
mlir/test/Dialect/Linalg/named-ops.mlir

@@ -1184,3 +1184,37 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
   linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @add_dynamic
+func.func @add_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // CHECK: linalg.add
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.add ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @add_static
+func.func @add_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: linalg.add
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.add ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @add_tensor
+func.func @add_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.add
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.add ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}