[mlir][Linalg] Add a vectorization pattern for linalg::PadTensorOp

The new pattern is exercised from the TestLinalgTransforms pass.

Differential Revision: https://reviews.llvm.org/D96410
commit bb69de3f41 (parent 6f9db455a5)
@@ -231,6 +231,28 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     static linalg::PadTensorOp createPadScalarOp(
         Type type, Value source, Value pad, ArrayRef<OpFoldResult> low,
         ArrayRef<OpFoldResult> high, Location loc, OpBuilder & builder);
+
+    // Return a vector of all the static or dynamic values (low/high padding)
+    // of the op.
+    inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayAttr staticAttrs,
+                                                     ValueRange values) {
+      SmallVector<OpFoldResult> res;
+      unsigned numDynamic = 0;
+      unsigned count = staticAttrs.size();
+      for (unsigned idx = 0; idx < count; ++idx) {
+        if (ShapedType::isDynamic(staticAttrs[idx].cast<IntegerAttr>().getInt()))
+          res.push_back(values[numDynamic++]);
+        else
+          res.push_back(staticAttrs[idx]);
+      }
+      return res;
+    }
+    SmallVector<OpFoldResult> getMixedLowPad() {
+      return getMixedPadImpl(static_low(), low());
+    }
+    SmallVector<OpFoldResult> getMixedHighPad() {
+      return getMixedPadImpl(static_high(), high());
+    }
   }];

   let builders = [
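For context on the getMixedPadImpl helper added above: the static padding amounts live in an attribute array that marks dynamic entries with a sentinel, and each dynamic slot consumes the next SSA operand in order. Below is a minimal standalone C++ sketch of that interleaving, using hypothetical stand-ins (int64_t for IntegerAttr, a string for Value, and an assumed kDynamic sentinel) rather than the MLIR API:

#include <cstdint>
#include <iostream>
#include <limits>
#include <string>
#include <variant>
#include <vector>

// Stand-in for OpFoldResult: either a known constant or an SSA value name.
using OpFoldResult = std::variant<int64_t, std::string>;
// Assumed sentinel marking a dynamic entry in the static list.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

std::vector<OpFoldResult> getMixedPad(const std::vector<int64_t> &staticVals,
                                      const std::vector<std::string> &values) {
  std::vector<OpFoldResult> res;
  unsigned numDynamic = 0;
  for (int64_t s : staticVals) {
    if (s == kDynamic)
      res.push_back(values[numDynamic++]); // dynamic slot: next SSA value
    else
      res.push_back(s);                    // static slot: the constant itself
  }
  return res;
}

int main() {
  // Mirrors `low[2, %low, 3, 3]` from the test below: one dynamic entry.
  for (const OpFoldResult &ofr : getMixedPad({2, kDynamic, 3, 3}, {"%low"})) {
    if (const int64_t *c = std::get_if<int64_t>(&ofr))
      std::cout << *c << ' ';
    else
      std::cout << std::get<std::string>(ofr) << ' ';
  }
  std::cout << '\n'; // prints: 2 %low 3 3
}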
@@ -809,6 +809,16 @@ void populateLinalgConvGeneralizationPatterns(
 //===----------------------------------------------------------------------===//
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//
+
+/// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`,
+/// so it needs a specific pattern to vectorize.
+struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 /// Match and rewrite for the pattern:
 /// ```
 /// %alloc = ...
@@ -1213,6 +1213,10 @@ def Vector_TransferReadOp :
     OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
       "ValueRange":$indices, "AffineMap":$permutationMap,
       CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
+    // Builder that sets permutation map to 'getMinorIdentityMap'.
+    OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
+      "ValueRange":$indices, "Value":$padding,
+      CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
     // Builder that sets permutation map (resp. padding) to
     // 'getMinorIdentityMap' (resp. zero).
     OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
@@ -448,9 +448,71 @@ Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
 }

 //----------------------------------------------------------------------------//
-// Misc. conv vectorization patterns.
+// Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
-// TODO: cleanup all this.

+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
+/// TransferWriteOp. For now, this only applies when all low and high paddings
+/// are determined to be zero.
+LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
+    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
+  // Helper function to determine whether an OpFoldResult is not a zero Index.
+  auto isNotZeroIndex = [](OpFoldResult ofr) {
+    if (Attribute attr = ofr.dyn_cast<Attribute>())
+      return attr.cast<IntegerAttr>().getInt() != 0;
+    Value v = ofr.get<Value>();
+    if (auto constOp = v.getDefiningOp<ConstantIntOp>())
+      return constOp.getValue() != 0;
+    return true;
+  };
+
+  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
+  // Bail on non-static shapes.
+  if (!resultShapedType.hasStaticShape())
+    return failure();
+
+  // If any pad_low is not a static 0, needs a mask. Bail for now.
+  if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
+    return failure();
+  VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
+  if (!vectorType)
+    return failure();
+
+  // Only support padding with a constant for now, i.e. either:
+  //   1. A BBarg from a different block.
+  //   2. A value defined outside of the current block.
+  Block &block = padOp.region().front();
+  auto yieldOp = cast<YieldOp>(block.getTerminator());
+  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
+  Value padValue = yieldOp.values().front();
+  Operation *definingOp = padValue.getDefiningOp();
+  if (definingOp && definingOp->getBlock() == &block)
+    return failure();
+  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
+    return failure();
+
+  // TODO: if any pad_high is not a static 0, needs a mask. For now, just bail.
+  if (llvm::any_of(padOp.getMixedHighPad(),
+                   [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
+    return failure();
+
+  // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
+  // TransferWriteOp@[0..0].
+  SmallVector<Value> indices(
+      resultShapedType.getRank(),
+      rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
+  Value read = rewriter.create<vector::TransferReadOp>(
+      padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
+  Value init =
+      rewriter.create<InitTensorOp>(padOp.getLoc(), resultShapedType.getShape(),
+                                    resultShapedType.getElementType());
+  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
+                                                       indices);
+
+  return success();
+}
+
+// TODO: cleanup all the convolution vectorization patterns.
 template <class ConvOp, int N>
 LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
     ConvOp op, PatternRewriter &rewriter) const {
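A note on the zero-padding guard above: isNotZeroIndex only treats a value as zero when it is provably the constant 0 (a static IntegerAttr, or a Value defined by a ConstantIntOp); anything it cannot prove is conservatively counted as nonzero so the pattern bails rather than miscompile. A minimal standalone C++ sketch of that conservative check, with std::optional standing in for "known constant or not" (not the MLIR API):

#include <cassert>
#include <cstdint>
#include <optional>

// nullopt models an SSA value whose constant value is unknown.
bool isNotZeroIndex(std::optional<int64_t> maybeConst) {
  if (maybeConst)
    return *maybeConst != 0; // known constant: compare directly
  return true;               // unknown: conservatively assume nonzero
}

int main() {
  assert(!isNotZeroIndex(0));           // static 0: vectorization may proceed
  assert(isNotZeroIndex(3));            // static nonzero padding: bail
  assert(isNotZeroIndex(std::nullopt)); // dynamic padding such as %low: bail
}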
@@ -1122,8 +1122,8 @@ public:

 } // namespace

-void BroadcastOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
+void BroadcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
   results.insert<BroadcastToShapeCast>(context);
 }

@@ -2026,17 +2026,32 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,

 /// Builder that sets padding to zero.
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vector, Value source, ValueRange indices,
-                           AffineMap permutationMap,
+                           VectorType vectorType, Value source,
+                           ValueRange indices, AffineMap permutationMap,
                            ArrayRef<bool> maybeMasked) {
   Type elemType = source.getType().cast<ShapedType>().getElementType();
   Value padding = builder.create<ConstantOp>(result.location, elemType,
                                              builder.getZeroAttr(elemType));
   if (maybeMasked.empty())
-    return build(builder, result, vector, source, indices, permutationMap,
+    return build(builder, result, vectorType, source, indices, permutationMap,
                  padding, ArrayAttr());
   ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
-  build(builder, result, vector, source, indices, permutationMap, padding,
+  build(builder, result, vectorType, source, indices, permutationMap, padding,
         maskedArrayAttr);
 }

+/// Builder that sets permutation map to 'getMinorIdentityMap'.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vectorType, Value source,
+                           ValueRange indices, Value padding,
+                           ArrayRef<bool> maybeMasked) {
+  auto permMap = getTransferMinorIdentityMap(
+      source.getType().cast<ShapedType>(), vectorType);
+  if (maybeMasked.empty())
+    return build(builder, result, vectorType, source, indices, permMap, padding,
+                 ArrayAttr());
+  ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
+  build(builder, result, vectorType, source, indices, permMap, padding,
+        maskedArrayAttr);
+}
+
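The two builders above follow a delegation idiom: each convenience overload computes one default (a zero padding constant, or a minor-identity permutation map via getTransferMinorIdentityMap) and forwards to the fullest build overload, so the invariants live in one place. A minimal standalone C++ sketch of the idiom with hypothetical types, not the actual TransferReadOp API:

#include <iostream>
#include <string>

struct Transfer {
  // Fullest form: everything spelled out explicitly.
  static void build(const std::string &permMap, int padding) {
    std::cout << "map=" << permMap << " padding=" << padding << '\n';
  }
  // Convenience overload: defaults the padding to zero, then delegates.
  static void build(const std::string &permMap) {
    build(permMap, /*padding=*/0);
  }
  // Convenience overload: defaults the map to a minor identity, then delegates.
  static void build(int padding) { build("minor-identity", padding); }
};

int main() {
  Transfer::build("(d0, d1) -> (d1)", 7); // fully explicit
  Transfer::build("(d0, d1) -> (d1)");    // padding defaulted to zero
  Transfer::build(7);                     // permutation map defaulted
}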
@@ -390,3 +390,44 @@ func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>)
     outs(%c: memref<4x12xi32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @pad_static
+//   CHECK-NOT: linalg.pad_tensor
+func @pad_static(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+  //      CHECK: %[[C0:.*]] = constant 0 : index
+  //      CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]]
+  // CHECK-SAME:   : tensor<?x?x?xf32>, vector<2x3x4xf32>
+  //      CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32>
+  //      CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]]
+  // CHECK-SAME:   {masked = [false, false, false]} : vector<2x3x4xf32>, tensor<2x3x4xf32>
+  %0 = linalg.pad_tensor %arg0 low[0, 0, 0] high[0, 0, 0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+    } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+
+  // CHECK: return %[[WRITTEN]] : tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
+}
+
+// CHECK-LABEL: func @pad_static_high_padding
+//       CHECK: linalg.pad_tensor
+func @pad_static_high_padding(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0, 0] high[0, 1, 0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+    } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
+}
+
+// CHECK-LABEL: func @pad_dynamic
+//       CHECK: linalg.pad_tensor
+func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
+                  %pad_value: f32) -> tensor<6x?x?x?xf32> {
+  %0 = linalg.pad_tensor %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+      linalg.yield %pad_value : f32
+    } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+  return %0 : tensor<6x?x?x?xf32>
+}
@@ -491,6 +491,7 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
   patterns.insert<LinalgVectorizationPattern>(
      LinalgTransformationFilter()
          .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
+  patterns.insert<PadTensorOpVectorizationPattern>(funcOp.getContext());
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }