[mlir][linalg] Relax convolution vectorization to support mixed types

Support the case where the convolution body extends its floating-point inputs
(e.g. f16 inputs accumulated into an f32 output) before multiplying them.

Differential Revision: https://reviews.llvm.org/D127925
This commit is contained in:
Thomas Raoux 2022-06-16 00:52:25 +00:00
parent 6ed81ec164
commit 046ebeb605
2 changed files with 48 additions and 3 deletions

View File

@ -1374,10 +1374,29 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
maybeKind = getCombinerOpKind(reduceOp);
if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
return;
maybeKind = getCombinerOpKind(&(linalgOp->getRegion(0).front().front()));
if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
// Check for a single `mul` predecessor. Each `mul` operand must be either a
// block argument or a floating-point extension (`arith.extf`) of one.
Operation *mulOp = nullptr;
for (Value operand : reduceOp->getOperands()) {
if (operand.isa<BlockArgument>())
continue;
if (mulOp)
return;
mulOp = operand.getDefiningOp();
if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
return;
}
if (!mulOp)
return;
for (Value operand : mulOp->getOperands()) {
if (Operation *def = operand.getDefiningOp()) {
if (!isa<arith::ExtFOp>(def))
return;
operand = def->getOperand(0);
}
if (!operand.isa<BlockArgument>())
return;
}
// The op is now known to be valid.
valid = true;
}

View File

@ -224,3 +224,29 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
// Write the result back in one shot.
// CHECK: vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// -----
// Mixed-type convolution: the input and filter are f16 while the output is
// f32. The checks below expect the f16 operands to feed `vector.contract`
// directly, with the contraction producing the f32 result.
func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {
linalg.conv_1d_nwc_wcf
{dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
ins(%input, %filter : memref<1x2x3xf16>, memref<1x3x2xf16>)
outs(%output : memref<1x2x2xf32>)
return
}
// CHECK: func @conv_1d_nwc_wcf_mixed_type_memref
// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xf16>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xf16>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
/// Read the whole data in one shot.
// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x2xf16>
/// The contraction consumes the f16 operands and accumulates into f32. The
/// operand/type check must be a CHECK-SAME directive; a bare `//` line is
/// ignored by FileCheck.
// CHECK: %[[CONT:.*]] = vector.contract
// CHECK-SAME: %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
// CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]