diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index a40d425f7f2e..6916fa78abbb 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -231,6 +231,28 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", static linalg::PadTensorOp createPadScalarOp( Type type, Value source, Value pad, ArrayRef low, ArrayRef high, Location loc, OpBuilder & builder); + + // Return a vector of all the static or dynamic values (low/high padding) of + // the op. + inline SmallVector getMixedPadImpl(ArrayAttr staticAttrs, + ValueRange values) { + SmallVector res; + unsigned numDynamic = 0; + unsigned count = staticAttrs.size(); + for (unsigned idx = 0; idx < count; ++idx) { + if (ShapedType::isDynamic(staticAttrs[idx].cast().getInt())) + res.push_back(values[numDynamic++]); + else + res.push_back(staticAttrs[idx]); + } + return res; + } + SmallVector getMixedLowPad() { + return getMixedPadImpl(static_low(), low()); + } + SmallVector getMixedHighPad() { + return getMixedPadImpl(static_high(), high()); + } }]; let builders = [ diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 669f127f6434..4b5580a62abc 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -809,6 +809,16 @@ void populateLinalgConvGeneralizationPatterns( //===----------------------------------------------------------------------===// // Op-specific patterns. //===----------------------------------------------------------------------===// + +/// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`, +/// it needs a specific pattern to vectorize. +struct PadTensorOpVectorizationPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PadTensorOp padOp, + PatternRewriter &rewriter) const override; +}; + /// Match and rewrite for the pattern: /// ``` /// %alloc = ... diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 8bc21b179037..1aeb92a2faf8 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1213,6 +1213,10 @@ def Vector_TransferReadOp : OpBuilderDAG<(ins "VectorType":$vector, "Value":$source, "ValueRange":$indices, "AffineMap":$permutationMap, CArg<"ArrayRef", "{}">:$maybeMasked)>, + // Builder that sets padding to 'getMinorIdentityMap'. + OpBuilderDAG<(ins "VectorType":$vector, "Value":$source, + "ValueRange":$indices, "Value":$padding, + CArg<"ArrayRef", "{}">:$maybeMasked)>, // Builder that sets permutation map (resp. padding) to // 'getMinorIdentityMap' (resp. zero). OpBuilderDAG<(ins "VectorType":$vector, "Value":$source, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index a9a43e194d75..86f05c38ed89 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -448,9 +448,71 @@ Optional mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, } //----------------------------------------------------------------------------// -// Misc. conv vectorization patterns. +// Misc. vectorization patterns. //----------------------------------------------------------------------------// -// TODO: cleanup all this. + +/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and +/// TransferWriteOp. For now, this only applies when all low and high paddings +/// are determined to be zero. +LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite( + linalg::PadTensorOp padOp, PatternRewriter &rewriter) const { + // Helper function to determine whether an OpFoldResult is not a zero Index. + auto isNotZeroIndex = [](OpFoldResult ofr) { + if (Attribute attr = ofr.dyn_cast()) + return attr.cast().getInt() != 0; + Value v = ofr.get(); + if (auto constOp = v.getDefiningOp()) + return constOp.getValue() != 0; + return true; + }; + + auto resultShapedType = padOp.result().getType().cast(); + // Bail on non-static shapes. + if (!resultShapedType.hasStaticShape()) + return failure(); + + // If any pad_low is not a static 0, needs a mask. Bail for now. + if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex)) + return failure(); + VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result()); + if (!vectorType) + return failure(); + + // Only support padding with a constant for now, i.e. either: + // 1. A BBarg from a different block. + // 2. A value defined outside of the current block. + Block &block = padOp.region().front(); + auto yieldOp = cast(block.getTerminator()); + assert(yieldOp.getNumOperands() == 1 && "expected single operand yield"); + Value padValue = yieldOp.values().front(); + Operation *definingOp = padValue.getDefiningOp(); + if (definingOp && definingOp->getBlock() == &block) + return failure(); + if (!definingOp && padValue.cast().getOwner() == &block) + return failure(); + + // TODO: if any pad_high is not a static 0, needs a mask. For now, just bail. + if (llvm::any_of(padOp.getMixedHighPad(), + [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); })) + return failure(); + + // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] + + // TransferWriteOp@[0..0]. + SmallVector indices( + resultShapedType.getRank(), + rewriter.create(padOp.getLoc(), 0)); + Value read = rewriter.create( + padOp.getLoc(), vectorType, padOp.source(), indices, padValue); + Value init = + rewriter.create(padOp.getLoc(), resultShapedType.getShape(), + resultShapedType.getElementType()); + rewriter.replaceOpWithNewOp(padOp, read, init, + indices); + + return success(); +} + +// TODO: cleanup all the convolution vectorization patterns. template LogicalResult ConvOpVectorization::matchAndRewrite( ConvOp op, PatternRewriter &rewriter) const { diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 9fe8cf23c162..99b978895c7e 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1122,8 +1122,8 @@ public: } // namespace -void BroadcastOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { +void BroadcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { results.insert(context); } @@ -2026,17 +2026,32 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType, /// Builder that sets padding to zero. void TransferReadOp::build(OpBuilder &builder, OperationState &result, - VectorType vector, Value source, ValueRange indices, - AffineMap permutationMap, + VectorType vectorType, Value source, + ValueRange indices, AffineMap permutationMap, ArrayRef maybeMasked) { Type elemType = source.getType().cast().getElementType(); Value padding = builder.create(result.location, elemType, builder.getZeroAttr(elemType)); if (maybeMasked.empty()) - return build(builder, result, vector, source, indices, permutationMap, + return build(builder, result, vectorType, source, indices, permutationMap, padding, ArrayAttr()); ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked); - build(builder, result, vector, source, indices, permutationMap, padding, + build(builder, result, vectorType, source, indices, permutationMap, padding, + maskedArrayAttr); +} + +/// Builder that sets permutation map to 'getMinorIdentityMap'. +void TransferReadOp::build(OpBuilder &builder, OperationState &result, + VectorType vectorType, Value source, + ValueRange indices, Value padding, + ArrayRef maybeMasked) { + auto permMap = getTransferMinorIdentityMap( + source.getType().cast(), vectorType); + if (maybeMasked.empty()) + return build(builder, result, vectorType, source, indices, permMap, padding, + ArrayAttr()); + ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked); + build(builder, result, vectorType, source, indices, permMap, padding, maskedArrayAttr); } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index 3904353287c5..961a9307c1f5 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -390,3 +390,44 @@ func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12x outs(%c: memref<4x12xi32>) return } + +// ----- + +// CHECK-LABEL: func @pad_static +// CHECK-NOT: linalg.pad_tensor +func @pad_static(%arg0: tensor, %pad_value: f32) -> tensor<2x3x4xf32> { + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] + // CHECK-SAME: : tensor, vector<2x3x4xf32> + // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32> + // CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] + // CHECK-SAME: {masked = [false, false, false]} : vector<2x3x4xf32>, tensor<2x3x4xf32> + %0 = linalg.pad_tensor %arg0 low[0, 0, 0] high[0, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index): + linalg.yield %pad_value : f32 + } : tensor to tensor<2x3x4xf32> + + // CHECK: return %[[WRITTEN]] : tensor<2x3x4xf32> + return %0 : tensor<2x3x4xf32> +} + +// CHECK-LABEL: func @pad_static_high_padding +// CHECK: linalg.pad_tensor +func @pad_static_high_padding(%arg0: tensor, %pad_value: f32) -> tensor<2x3x4xf32> { + %0 = linalg.pad_tensor %arg0 low[0, 0, 0] high[0, 1, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index): + linalg.yield %pad_value : f32 + } : tensor to tensor<2x3x4xf32> + return %0 : tensor<2x3x4xf32> +} + +// CHECK-LABEL: func @pad_dynamic +// CHECK: linalg.pad_tensor +func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index, + %pad_value: f32) -> tensor<6x?x?x?xf32> { + %0 = linalg.pad_tensor %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): + linalg.yield %pad_value : f32 + } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32> + return %0 : tensor<6x?x?x?xf32> +} diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index f9dea42f3a8a..a492d496af51 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -491,6 +491,7 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) { patterns.insert( LinalgTransformationFilter() .addOpFilter()); + patterns.insert(funcOp.getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); }