mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-01 18:12:44 +00:00
[MLIR][linalg] Make integer matmul ops cast before multiplying
Right now they multiply before casting, which means they would frequently overflow. There are various reasonable ways to do this, but until we have robust op description infra, this is a simple and safe default. More careful treatments are likely to be hardware-specific as well (e.g. using an i8*i8->i16 mul instruction). Reviewed By: nicolasvasilache, mravishankar Differential Revision: https://reviews.llvm.org/D97505
This commit is contained in:
parent
301551ae8e
commit
21bb63893e
@ -15,13 +15,13 @@ implements_interface<LinalgContractionOpInterface> :
|
||||
def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
|
||||
// TODO: ideally something closer to
|
||||
// C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
|
||||
C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
|
||||
C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
|
||||
}
|
||||
|
||||
ods_def<MatmulI16I16I32Op>
|
||||
implements_interface<LinalgContractionOpInterface> :
|
||||
def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) {
|
||||
C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
|
||||
C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
|
||||
}
|
||||
|
||||
ods_def<MatmulI32I32I32Op>
|
||||
@ -39,13 +39,13 @@ def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
|
||||
ods_def<MatvecI8I8I32Op>
|
||||
implements_interface<LinalgContractionOpInterface> :
|
||||
def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
|
||||
x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
|
||||
x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
|
||||
}
|
||||
|
||||
ods_def<MatvecI16I16I32Op>
|
||||
implements_interface<LinalgContractionOpInterface> :
|
||||
def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) {
|
||||
x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
|
||||
x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
|
||||
}
|
||||
|
||||
ods_def<MatvecI32I32I32Op>
|
||||
@ -63,13 +63,13 @@ def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
|
||||
ods_def<VecmatI8I8I32Op>
|
||||
implements_interface<LinalgContractionOpInterface> :
|
||||
def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
|
||||
x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
|
||||
x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
|
||||
}
|
||||
|
||||
ods_def<VecmatI16I16I32Op>
|
||||
implements_interface<LinalgContractionOpInterface> :
|
||||
def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) {
|
||||
x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
|
||||
x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
|
||||
}
|
||||
|
||||
ods_def<VecmatI32I32I32Op>
|
||||
@ -87,13 +87,13 @@ def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
|
||||
ods_def<DotI8I8I32Op>
|
||||
implements_interface<LinalgContractionOpInterface> :
|
||||
def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
|
||||
C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
|
||||
C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
|
||||
}
|
||||
|
||||
ods_def<DotI16I16I32Op>
|
||||
implements_interface<LinalgContractionOpInterface> :
|
||||
def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) {
|
||||
C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
|
||||
C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
|
||||
}
|
||||
|
||||
ods_def<DotI32I32I32Op>
|
||||
@ -112,14 +112,14 @@ ods_def<BatchMatmulI8I8I32Op>
|
||||
implements_interface<LinalgContractionOpInterface> :
|
||||
def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
|
||||
C(b, m, n) =
|
||||
std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
|
||||
std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
|
||||
}
|
||||
|
||||
ods_def<BatchMatmulI16I16I32Op>
|
||||
implements_interface<LinalgContractionOpInterface> :
|
||||
def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) {
|
||||
C(b, m, n) =
|
||||
std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
|
||||
std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
|
||||
}
|
||||
|
||||
|
||||
|
@ -373,17 +373,18 @@ func @matmul_tensors(
|
||||
// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: memref<4x12xi32>
|
||||
func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>) {
|
||||
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi8>
|
||||
// CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi32>
|
||||
// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8>
|
||||
// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8>
|
||||
// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32>
|
||||
// CHECK-DAG: %[[V0_32:.*]] = sexti %[[V0]] : vector<4x6xi8> to vector<4x6xi32>
|
||||
// CHECK-DAG: %[[V1_32:.*]] = sexti %[[V1]] : vector<6x12xi8> to vector<6x12xi32>
|
||||
//
|
||||
// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addi %c, %tmp.
|
||||
// a later canonicalization fuses the add into vector.contract.
|
||||
// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]]
|
||||
// CHECK-SAME: vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8>
|
||||
// CHECK: %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32>
|
||||
// CHECK: %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32>
|
||||
// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0_32]], %[[V1_32]], %[[VEC_C0]]
|
||||
// CHECK-SAME: vector<4x6xi32>, vector<6x12xi32> into vector<4x12xi32>
|
||||
// CHECK: %[[RES:.*]] = addi %[[V2]], %[[C]] : vector<4x12xi32>
|
||||
// CHECK: vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]}
|
||||
// CHECK-SAME: vector<4x12xi32>, memref<4x12xi32>
|
||||
linalg.matmul_i8_i8_i32 ins(%a, %b : memref<4x6xi8>, memref<6x12xi8>)
|
||||
|
Loading…
Reference in New Issue
Block a user