mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-24 06:10:12 +00:00
[mlir][linalg] Relax convolution vectorization to support mixed types
Support the case where convolution does float extension of the inputs. Differential Revision: https://reviews.llvm.org/D127925
This commit is contained in:
parent
6ed81ec164
commit
046ebeb605
@ -1374,10 +1374,29 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
|
||||
maybeKind = getCombinerOpKind(reduceOp);
|
||||
if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
|
||||
return;
|
||||
maybeKind = getCombinerOpKind(&(linalgOp->getRegion(0).front().front()));
|
||||
if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
|
||||
// Check for single `mul` predecessor. The `mul` operands must be block
|
||||
// arguments or extension of block arguments.
|
||||
Operation *mulOp = nullptr;
|
||||
for (Value operand : reduceOp->getOperands()) {
|
||||
if (operand.isa<BlockArgument>())
|
||||
continue;
|
||||
if (mulOp)
|
||||
return;
|
||||
mulOp = operand.getDefiningOp();
|
||||
if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
|
||||
return;
|
||||
}
|
||||
if (!mulOp)
|
||||
return;
|
||||
|
||||
for (Value operand : mulOp->getOperands()) {
|
||||
if (Operation *def = operand.getDefiningOp()) {
|
||||
if (!isa<arith::ExtFOp>(def))
|
||||
return;
|
||||
operand = def->getOperand(0);
|
||||
}
|
||||
if (!operand.isa<BlockArgument>())
|
||||
return;
|
||||
}
|
||||
// The op is now known to be valid.
|
||||
valid = true;
|
||||
}
|
||||
|
@ -224,3 +224,29 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
|
||||
|
||||
// Write the result back in one shot.
|
||||
// CHECK: vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {
|
||||
linalg.conv_1d_nwc_wcf
|
||||
{dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
|
||||
ins(%input, %filter : memref<1x2x3xf16>, memref<1x3x2xf16>)
|
||||
outs(%output : memref<1x2x2xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @conv_1d_nwc_wcf_mixed_type_memref
|
||||
// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xf16>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xf16>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>)
|
||||
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
|
||||
|
||||
/// Read the whole data in one shot.
|
||||
// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
|
||||
// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]]
|
||||
// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
|
||||
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x2xf16>
|
||||
// CHECK: %[[CONT:.*]] = vector.contract
|
||||
// {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
|
||||
// CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
|
||||
|
Loading…
Reference in New Issue
Block a user