[MLIR][linalg] Make integer matmul ops cast before multiplying

Right now they multiply before casting, which means they would
frequently overflow. There are various reasonable ways to do this, but
until we have robust op description infrastructure, casting each
operand before multiplying is a simple and safe default. More careful
treatments are likely to be hardware-specific as well (e.g. using an
i8*i8->i16 multiply instruction).
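
As a quick illustration of the overflow this avoids, here is a small
standalone C++ sketch (illustrative only, not part of this change)
comparing the two orderings for i8 inputs:

  #include <cstdint>
  #include <cstdio>

  int main() {
    int8_t a = 100, b = 100;
    // Old order: multiply in i8 (the product wraps), then sign-extend
    // to i32. 100 * 100 = 10000, which wraps to 16 in i8.
    int8_t narrow = static_cast<int8_t>(a * b);
    int32_t mul_then_cast = narrow;  // sign-extends: still 16
    // New order: sign-extend both operands to i32 first, then multiply.
    int32_t cast_then_mul =
        static_cast<int32_t>(a) * static_cast<int32_t>(b);  // 10000
    std::printf("%d vs %d\n", mul_then_cast, cast_then_mul);  // 16 vs 10000
    return 0;
  }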

Reviewed By: nicolasvasilache, mravishankar

Differential Revision: https://reviews.llvm.org/D97505
commit 21bb63893e (parent 301551ae8e)
Author: Geoffrey Martin-Noble
Date:   2021-02-25 17:20:25 -08:00

2 changed files with 16 additions and 15 deletions

@@ -15,13 +15,13 @@ implements_interface<LinalgContractionOpInterface> :
 def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
   // TODO: ideally something closer to
   //   C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
-  C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
+  C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
 }

 ods_def<MatmulI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) {
-  C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
+  C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
 }

 ods_def<MatmulI32I32I32Op>
@@ -39,13 +39,13 @@ def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
 ods_def<MatvecI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
-  x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
+  x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
 }

 ods_def<MatvecI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) {
-  x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
+  x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
 }

 ods_def<MatvecI32I32I32Op>
@@ -63,13 +63,13 @@ def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
 ods_def<VecmatI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
-  x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
+  x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
 }

 ods_def<VecmatI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) {
-  x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
+  x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
 }

 ods_def<VecmatI32I32I32Op>
@@ -87,13 +87,13 @@ def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
 ods_def<DotI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
-  C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
+  C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
 }

 ods_def<DotI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) {
-  C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
+  C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
 }

 ods_def<DotI32I32I32Op>
@@ -112,14 +112,14 @@ ods_def<BatchMatmulI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
   C(b, m, n) =
-      std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+      std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
 }

 ods_def<BatchMatmulI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) {
   C(b, m, n) =
-      std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+      std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
 }

@@ -373,17 +373,18 @@ func @matmul_tensors(
 // CHECK-SAME: %[[ARG2:[a-z0-9]+]]: memref<4x12xi32>
 func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
-// CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi8>
+// CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi32>
 // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8>
 // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8>
 // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32>
+// CHECK-DAG: %[[V0_32:.*]] = sexti %[[V0]] : vector<4x6xi8> to vector<4x6xi32>
+// CHECK-DAG: %[[V1_32:.*]] = sexti %[[V1]] : vector<6x12xi8> to vector<6x12xi32>
 //
 // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
 // a later canonicalization fuses the add into vector.contract.
-// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]]
-// CHECK-SAME: vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8>
-// CHECK: %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32>
-// CHECK: %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32>
+// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0_32]], %[[V1_32]], %[[VEC_C0]]
+// CHECK-SAME: vector<4x6xi32>, vector<6x12xi32> into vector<4x12xi32>
+// CHECK: %[[RES:.*]] = addi %[[V2]], %[[C]] : vector<4x12xi32>
 // CHECK: vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]}
 // CHECK-SAME: vector<4x12xi32>, memref<4x12xi32>
   linalg.matmul_i8_i8_i32 ins(%a, %b : memref<4x6xi8>, memref<6x12xi8>)