diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h index 5af0da2f8528..b95b527d0639 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h @@ -65,10 +65,6 @@ void populateDistributeTransferWriteOpPatterns( /// region. void moveScalarUniformCode(WarpExecuteOnLane0Op op); -/// Collect patterns to propagate warp distribution. -void populatePropagateWarpVectorDistributionPatterns( - RewritePatternSet &pattern); - } // namespace vector } // namespace mlir #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_ diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 7f8566aa6c47..586604f6fd6c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -11,7 +11,6 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" -#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/SideEffectUtils.h" using namespace mlir; @@ -182,60 +181,6 @@ static bool canBeHoisted(Operation *op, isSideEffectFree(op) && op->getNumRegions() == 0; } -/// Return a value yielded by `warpOp` which statifies the filter lamdba -/// condition and is not dead. -static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp, - std::function fn) { - auto yield = cast( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - for (OpOperand &yieldOperand : yield->getOpOperands()) { - Value yieldValues = yieldOperand.get(); - Operation *definedOp = yieldValues.getDefiningOp(); - if (definedOp && fn(definedOp)) { - if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty()) - return &yieldOperand; - } - } - return {}; -} - -// Clones `op` into a new operation that takes `operands` and returns -// `resultTypes`. -static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter, - Location loc, Operation *op, - ArrayRef operands, - ArrayRef resultTypes) { - OperationState res(loc, op->getName().getStringRef(), operands, resultTypes, - op->getAttrs()); - return rewriter.create(res); -} - -/// Currently the distribution map is implicit based on the vector shape. In the -/// future it will be part of the op. -/// Example: -/// ``` -/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) { -/// ... -/// vector.yield %3 : vector<32x16x64xf32> -/// } -/// ``` -/// Would have an implicit map of: -/// `(d0, d1, d2) -> (d0, d2)` -static AffineMap calculateImplicitMap(Value yield, Value ret) { - auto srcType = yield.getType().cast(); - auto dstType = ret.getType().cast(); - SmallVector perm; - // Check which dimensions of the yield value are different than the dimensions - // of the result to know the distributed dimensions. Then associate each - // distributed dimension to an ID in order. - for (unsigned i = 0, e = srcType.getRank(); i < e; i++) { - if (srcType.getDimSize(i) != dstType.getDimSize(i)) - perm.push_back(getAffineDimExpr(i, yield.getContext())); - } - auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext()); - return map; -} - namespace { struct WarpOpToScfForPattern : public OpRewritePattern { @@ -405,322 +350,6 @@ private: DistributionMapFn distributionMapFn; }; -/// Sink out elementwise op feeding into a warp op yield. -/// ``` -/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { -/// ... -/// %3 = arith.addf %1, %2 : vector<32xf32> -/// vector.yield %3 : vector<32xf32> -/// } -/// ``` -/// To -/// ``` -/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>, -/// vector<1xf32>, vector<1xf32>) { -/// ... -/// %4 = arith.addf %2, %3 : vector<32xf32> -/// vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>, -/// vector<32xf32> -/// } -/// %0 = arith.addf %r#1, %r#2 : vector<1xf32> -struct WarpOpElementwise : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, - PatternRewriter &rewriter) const override { - OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) { - return OpTrait::hasElementwiseMappableTraits(op); - }); - if (!yieldOperand) - return failure(); - Operation *elementWise = yieldOperand->get().getDefiningOp(); - unsigned operandIndex = yieldOperand->getOperandNumber(); - Value distributedVal = warpOp.getResult(operandIndex); - SmallVector yieldValues; - SmallVector retTypes; - Location loc = warpOp.getLoc(); - for (OpOperand &operand : elementWise->getOpOperands()) { - Type targetType; - if (auto vecType = distributedVal.getType().dyn_cast()) { - // If the result type is a vector, the operands must also be vectors. - auto operandType = operand.get().getType().cast(); - targetType = - VectorType::get(vecType.getShape(), operandType.getElementType()); - } else { - auto operandType = operand.get().getType(); - assert(!operandType.isa() && - "unexpected yield of vector from op with scalar result type"); - targetType = operandType; - } - retTypes.push_back(targetType); - yieldValues.push_back(operand.get()); - } - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, yieldValues, retTypes); - rewriter.setInsertionPointAfter(newWarpOp); - SmallVector newOperands(elementWise->getOperands().begin(), - elementWise->getOperands().end()); - for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) { - newOperands[i] = newWarpOp.getResult(i + warpOp.getNumResults()); - } - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointAfter(newWarpOp); - Operation *newOp = cloneOpWithOperandsAndTypes( - rewriter, loc, elementWise, newOperands, - {newWarpOp.getResult(operandIndex).getType()}); - newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0)); - return success(); - } -}; - -/// Sink out transfer_read op feeding into a warp op yield. -/// ``` -/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { -/// ... -// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, -// vector<32xf32> -/// vector.yield %2 : vector<32xf32> -/// } -/// ``` -/// To -/// ``` -/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>, -/// vector<1xf32>, vector<1xf32>) { -/// ... -/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, -/// vector<32xf32> vector.yield %2 : vector<32xf32> -/// } -/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32> -struct WarpOpTransferRead : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, - PatternRewriter &rewriter) const override { - OpOperand *operand = getWarpResult( - warpOp, [](Operation *op) { return isa(op); }); - if (!operand) - return failure(); - auto read = operand->get().getDefiningOp(); - unsigned operandIndex = operand->getOperandNumber(); - Value distributedVal = warpOp.getResult(operandIndex); - - SmallVector indices(read.getIndices().begin(), - read.getIndices().end()); - AffineMap map = calculateImplicitMap(read.getResult(), distributedVal); - AffineMap indexMap = map.compose(read.getPermutationMap()); - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointAfter(warpOp); - for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) { - AffineExpr d0, d1; - bindDims(read.getContext(), d0, d1); - auto indexExpr = std::get<0>(it).dyn_cast(); - if (!indexExpr) - continue; - unsigned indexPos = indexExpr.getPosition(); - unsigned vectorPos = std::get<1>(it).cast().getPosition(); - int64_t scale = - distributedVal.getType().cast().getDimSize(vectorPos); - indices[indexPos] = - makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1, - {indices[indexPos], warpOp.getLaneid()}); - } - Value newRead = rewriter.create( - read.getLoc(), distributedVal.getType(), read.getSource(), indices, - read.getPermutationMapAttr(), read.getPadding(), read.getMask(), - read.getInBoundsAttr()); - distributedVal.replaceAllUsesWith(newRead); - return success(); - } -}; - -/// Remove any result that has no use along with the matching yieldOp operand. -// TODO: Move this in WarpExecuteOnLane0Op canonicalization. -struct WarpOpDeadResult : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, - PatternRewriter &rewriter) const override { - SmallVector resultTypes; - SmallVector yieldValues; - auto yield = cast( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - for (OpResult result : warpOp.getResults()) { - if (result.use_empty()) - continue; - resultTypes.push_back(result.getType()); - yieldValues.push_back(yield.getOperand(result.getResultNumber())); - } - if (yield.getNumOperands() == yieldValues.size()) - return failure(); - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, yieldValues, resultTypes); - unsigned resultIndex = 0; - for (OpResult result : warpOp.getResults()) { - if (result.use_empty()) - continue; - result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++)); - } - rewriter.eraseOp(warpOp); - return success(); - } -}; - -// If an operand is directly yielded out of the region we can forward it -// directly and it doesn't need to go through the region. -struct WarpOpForwardOperand : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, - PatternRewriter &rewriter) const override { - SmallVector resultTypes; - SmallVector yieldValues; - auto yield = cast( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - Value valForwarded; - unsigned resultIndex; - for (OpOperand &operand : yield->getOpOperands()) { - Value result = warpOp.getResult(operand.getOperandNumber()); - if (result.use_empty()) - continue; - - // Assume all the values coming from above are uniform. - if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) { - if (result.getType() != operand.get().getType()) - continue; - valForwarded = operand.get(); - resultIndex = operand.getOperandNumber(); - break; - } - auto arg = operand.get().dyn_cast(); - if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation()) - continue; - Value warpOperand = warpOp.getArgs()[arg.getArgNumber()]; - if (result.getType() != warpOperand.getType()) - continue; - valForwarded = warpOperand; - resultIndex = operand.getOperandNumber(); - break; - } - if (!valForwarded) - return failure(); - warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded); - return success(); - } -}; - -struct WarpOpBroadcast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, - PatternRewriter &rewriter) const override { - OpOperand *operand = getWarpResult( - warpOp, [](Operation *op) { return isa(op); }); - if (!operand) - return failure(); - unsigned int operandNumber = operand->getOperandNumber(); - auto broadcastOp = operand->get().getDefiningOp(); - Location loc = broadcastOp.getLoc(); - auto destVecType = - warpOp->getResultTypes()[operandNumber].cast(); - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {broadcastOp.getSource()}, - {broadcastOp.getSource().getType()}); - rewriter.setInsertionPointAfter(newWarpOp); - Value broadcasted = rewriter.create( - loc, destVecType, newWarpOp->getResults().back()); - newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted); - - return success(); - } -}; - -/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if -/// the scf.ForOp is the last operation in the region so that it doesn't change -/// the order of execution. This creates a new scf.for region after the -/// WarpExecuteOnLane0Op. The new scf.for region will contain a new -/// WarpExecuteOnLane0Op region. Example: -/// ``` -/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) { -/// ... -/// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v) -/// -> (vector<128xf32>) { -/// ... -/// scf.yield %r : vector<128xf32> -/// } -/// vector.yield %v1 : vector<128xf32> -/// } -/// ``` -/// To: -/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) { -/// ... -/// vector.yield %v : vector<128xf32> -/// } -/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0) -/// -> (vector<4xf32>) { -/// %iw = vector.warp_execute_on_lane_0(%laneid) -/// args(%varg : vector<4xf32>) -> (vector<4xf32>) { -/// ^bb0(%arg: vector<128xf32>): -/// ... -/// vector.yield %ir : vector<128xf32> -/// } -/// scf.yield %iw : vector<4xf32> -/// } -/// ``` -struct WarpOpScfForOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, - PatternRewriter &rewriter) const override { - auto yield = cast( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - // Only pick up forOp if it is the last op in the region. - Operation *lastNode = yield->getPrevNode(); - auto forOp = dyn_cast_or_null(lastNode); - if (!forOp) - return failure(); - SmallVector newOperands; - SmallVector resultIdx; - // Collect all the outputs coming from the forOp. - for (OpOperand &yieldOperand : yield->getOpOperands()) { - if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) - continue; - auto forResult = yieldOperand.get().cast(); - newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber())); - yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]); - resultIdx.push_back(yieldOperand.getOperandNumber()); - } - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointAfter(warpOp); - // Create a new for op outside the region with a WarpExecuteOnLane0Op region - // inside. - auto newForOp = rewriter.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newOperands); - rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); - auto innerWarp = rewriter.create( - warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(), - warpOp.getWarpSize(), newForOp.getRegionIterArgs(), - forOp.getResultTypes()); - - SmallVector argMapping; - argMapping.push_back(newForOp.getInductionVar()); - for (Value args : innerWarp.getBody()->getArguments()) { - argMapping.push_back(args); - } - SmallVector yieldOperands; - for (Value operand : forOp.getBody()->getTerminator()->getOperands()) - yieldOperands.push_back(operand); - rewriter.eraseOp(forOp.getBody()->getTerminator()); - rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping); - rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end()); - rewriter.create(innerWarp.getLoc(), yieldOperands); - rewriter.setInsertionPointAfter(innerWarp); - rewriter.create(forOp.getLoc(), innerWarp.getResults()); - rewriter.eraseOp(forOp); - // Replace the warpOp result coming from the original ForOp. - for (const auto &res : llvm::enumerate(resultIdx)) { - warpOp.getResult(res.value()) - .replaceAllUsesWith(newForOp.getResult(res.index())); - newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value())); - } - return success(); - } -}; - } // namespace void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( @@ -734,13 +363,6 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns( patterns.add(patterns.getContext(), distributionMapFn); } -void mlir::vector::populatePropagateWarpVectorDistributionPatterns( - RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); -} - void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) { Block *body = warpOp.getBody(); diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index b57791ad04a1..dc4dfee861fb 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1,7 +1,6 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s -// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s // CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3> // CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3> @@ -127,310 +126,4 @@ func.func @warp_extract(%laneid: index, %arg1: memref<1024xf32>, %gid : index) { vector.transfer_write %v, %sa[%c0] : vector<1xf32>, memref<128xf32, #map2> } return -} - -// ----- - -// CHECK-PROP-LABEL: func @warp_dead_result( -func.func @warp_dead_result(%laneid: index) -> (vector<1xf32>) { - // CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) - %r:3 = vector.warp_execute_on_lane_0(%laneid)[32] -> - (vector<1xf32>, vector<1xf32>, vector<1xf32>) { - %2 = "some_def"() : () -> (vector<32xf32>) - %3 = "some_def"() : () -> (vector<32xf32>) - %4 = "some_def"() : () -> (vector<32xf32>) - // CHECK-PROP: vector.yield %{{.*}} : vector<32xf32> - vector.yield %2, %3, %4 : vector<32xf32>, vector<32xf32>, vector<32xf32> - } - // CHECK-PROP: return %[[R]] : vector<1xf32> - return %r#1 : vector<1xf32> -} - -// ----- - -// CHECK-PROP-LABEL: func @warp_propagate_operand( -// CHECK-PROP-SAME: %[[ID:.*]]: index, %[[V:.*]]: vector<4xf32>) -func.func @warp_propagate_operand(%laneid: index, %v0: vector<4xf32>) - -> (vector<4xf32>) { - %r = vector.warp_execute_on_lane_0(%laneid)[32] - args(%v0 : vector<4xf32>) -> (vector<4xf32>) { - ^bb0(%arg0 : vector<128xf32>) : - vector.yield %arg0 : vector<128xf32> - } - // CHECK-PROP: return %[[V]] : vector<4xf32> - return %r : vector<4xf32> -} - -// ----- - -#map0 = affine_map<()[s0] -> (s0 * 2)> - -// CHECK-PROP-LABEL: func @warp_propagate_elementwise( -func.func @warp_propagate_elementwise(%laneid: index, %dest: memref<1024xf32>) { - %c0 = arith.constant 0 : index - %c32 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - // CHECK-PROP: %[[R:.*]]:4 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xf32>, vector<2xf32>, vector<2xf32>) - %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] -> - (vector<1xf32>, vector<2xf32>) { - // CHECK-PROP: %[[V0:.*]] = "some_def"() : () -> vector<32xf32> - // CHECK-PROP: %[[V1:.*]] = "some_def"() : () -> vector<32xf32> - // CHECK-PROP: %[[V2:.*]] = "some_def"() : () -> vector<64xf32> - // CHECK-PROP: %[[V3:.*]] = "some_def"() : () -> vector<64xf32> - // CHECK-PROP: vector.yield %[[V0]], %[[V1]], %[[V2]], %[[V3]] : vector<32xf32>, vector<32xf32>, vector<64xf32>, vector<64xf32> - %2 = "some_def"() : () -> (vector<32xf32>) - %3 = "some_def"() : () -> (vector<32xf32>) - %4 = "some_def"() : () -> (vector<64xf32>) - %5 = "some_def"() : () -> (vector<64xf32>) - %6 = arith.addf %2, %3 : vector<32xf32> - %7 = arith.addf %4, %5 : vector<64xf32> - vector.yield %6, %7 : vector<32xf32>, vector<64xf32> - } - // CHECK-PROP: %[[A0:.*]] = arith.addf %[[R]]#2, %[[R]]#3 : vector<2xf32> - // CHECK-PROP: %[[A1:.*]] = arith.addf %[[R]]#0, %[[R]]#1 : vector<1xf32> - %id2 = affine.apply #map0()[%laneid] - // CHECK-PROP: vector.transfer_write %[[A1]], {{.*}} : vector<1xf32>, memref<1024xf32> - // CHECK-PROP: vector.transfer_write %[[A0]], {{.*}} : vector<2xf32>, memref<1024xf32> - vector.transfer_write %r#0, %dest[%laneid] : vector<1xf32>, memref<1024xf32> - vector.transfer_write %r#1, %dest[%id2] : vector<2xf32>, memref<1024xf32> - return -} - -// ----- - -// CHECK-PROP-LABEL: func @warp_propagate_scalar_arith( -// CHECK-PROP: %[[r:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} { -// CHECK-PROP: %[[some_def0:.*]] = "some_def" -// CHECK-PROP: %[[some_def1:.*]] = "some_def" -// CHECK-PROP: vector.yield %[[some_def0]], %[[some_def1]] -// CHECK-PROP: } -// CHECK-PROP: arith.addf %[[r]]#0, %[[r]]#1 : f32 -func.func @warp_propagate_scalar_arith(%laneid: index) { - %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { - %0 = "some_def"() : () -> (f32) - %1 = "some_def"() : () -> (f32) - %2 = arith.addf %0, %1 : f32 - vector.yield %2 : f32 - } - vector.print %r : f32 - return -} - -// ----- - -// CHECK-PROP-LABEL: func @warp_propagate_cast( -// CHECK-PROP-NOT: vector.warp_execute_on_lane_0 -// CHECK-PROP: %[[result:.*]] = arith.sitofp %{{.*}} : i32 to f32 -// CHECK-PROP: return %[[result]] -func.func @warp_propagate_cast(%laneid : index, %i : i32) -> (f32) { - %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { - %casted = arith.sitofp %i : i32 to f32 - vector.yield %casted : f32 - } - return %r : f32 -} - -// ----- - -#map0 = affine_map<()[s0] -> (s0 * 2)> - -// CHECK-PROP-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)> - -// CHECK-PROP: func @warp_propagate_read -// CHECK-PROP-SAME: (%[[ID:.*]]: index -func.func @warp_propagate_read(%laneid: index, %src: memref<1024xf32>, %dest: memref<1024xf32>) { -// CHECK-PROP-NOT: warp_execute_on_lane_0 -// CHECK-PROP-DAG: %[[R0:.*]] = vector.transfer_read %arg1[%[[ID]]], %{{.*}} : memref<1024xf32>, vector<1xf32> -// CHECK-PROP-DAG: %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID]]] -// CHECK-PROP-DAG: %[[R1:.*]] = vector.transfer_read %arg1[%[[ID2]]], %{{.*}} : memref<1024xf32>, vector<2xf32> -// CHECK-PROP: vector.transfer_write %[[R0]], {{.*}} : vector<1xf32>, memref<1024xf32> -// CHECK-PROP: vector.transfer_write %[[R1]], {{.*}} : vector<2xf32>, memref<1024xf32> - %c0 = arith.constant 0 : index - %c32 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] ->(vector<1xf32>, vector<2xf32>) { - %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<32xf32> - %3 = vector.transfer_read %src[%c32], %cst : memref<1024xf32>, vector<64xf32> - vector.yield %2, %3 : vector<32xf32>, vector<64xf32> - } - %id2 = affine.apply #map0()[%laneid] - vector.transfer_write %r#0, %dest[%laneid] : vector<1xf32>, memref<1024xf32> - vector.transfer_write %r#1, %dest[%id2] : vector<2xf32>, memref<1024xf32> - return -} - -// ----- - -// CHECK-PROP-LABEL: func @fold_vector_broadcast( -// CHECK-PROP: %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>) -// CHECK-PROP: %[[some_def:.*]] = "some_def" -// CHECK-PROP: vector.yield %[[some_def]] : vector<1xf32> -// CHECK-PROP: vector.print %[[r]] : vector<1xf32> -func.func @fold_vector_broadcast(%laneid: index) { - %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { - %0 = "some_def"() : () -> (vector<1xf32>) - %1 = vector.broadcast %0 : vector<1xf32> to vector<32xf32> - vector.yield %1 : vector<32xf32> - } - vector.print %r : vector<1xf32> - return -} - -// ----- - -// CHECK-PROP-LABEL: func @extract_vector_broadcast( -// CHECK-PROP: %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>) -// CHECK-PROP: %[[some_def:.*]] = "some_def" -// CHECK-PROP: vector.yield %[[some_def]] : vector<1xf32> -// CHECK-PROP: %[[broadcasted:.*]] = vector.broadcast %[[r]] : vector<1xf32> to vector<2xf32> -// CHECK-PROP: vector.print %[[broadcasted]] : vector<2xf32> -func.func @extract_vector_broadcast(%laneid: index) { - %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { - %0 = "some_def"() : () -> (vector<1xf32>) - %1 = vector.broadcast %0 : vector<1xf32> to vector<64xf32> - vector.yield %1 : vector<64xf32> - } - vector.print %r : vector<2xf32> - return -} - -// ----- - -// CHECK-PROP-LABEL: func @extract_scalar_vector_broadcast( -// CHECK-PROP: %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (f32) -// CHECK-PROP: %[[some_def:.*]] = "some_def" -// CHECK-PROP: vector.yield %[[some_def]] : f32 -// CHECK-PROP: %[[broadcasted:.*]] = vector.broadcast %[[r]] : f32 to vector<2xf32> -// CHECK-PROP: vector.print %[[broadcasted]] : vector<2xf32> -func.func @extract_scalar_vector_broadcast(%laneid: index) { - %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { - %0 = "some_def"() : () -> (f32) - %1 = vector.broadcast %0 : f32 to vector<64xf32> - vector.yield %1 : vector<64xf32> - } - vector.print %r : vector<2xf32> - return -} - -// ----- - -// CHECK-PROP-LABEL: func @warp_scf_for( -// CHECK-PROP: %[[INI:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>) { -// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32> -// CHECK-PROP: vector.yield %[[INI1]] : vector<128xf32> -// CHECK-PROP: } -// CHECK-PROP: %[[F:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG:.*]] = %[[INI]]) -> (vector<4xf32>) { -// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG]] : vector<4xf32>) -> (vector<4xf32>) { -// CHECK-PROP: ^bb0(%[[ARG:.*]]: vector<128xf32>): -// CHECK-PROP: %[[ACC:.*]] = "some_def"(%[[ARG]]) : (vector<128xf32>) -> vector<128xf32> -// CHECK-PROP: vector.yield %[[ACC]] : vector<128xf32> -// CHECK-PROP: } -// CHECK-PROP: scf.yield %[[W]] : vector<4xf32> -// CHECK-PROP: } -// CHECK-PROP: "some_use"(%[[F]]) : (vector<4xf32>) -> () -func.func @warp_scf_for(%arg0: index) { - %c128 = arith.constant 128 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) { - %ini = "some_def"() : () -> (vector<128xf32>) - %3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) { - %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>) - scf.yield %acc : vector<128xf32> - } - vector.yield %3 : vector<128xf32> - } - "some_use"(%0) : (vector<4xf32>) -> () - return -} - -// ----- - -// CHECK-PROP-LABEL: func @warp_scf_for_swap( -// CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) { -// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32> -// CHECK-PROP: %[[INI2:.*]] = "some_def"() : () -> vector<128xf32> -// CHECK-PROP: vector.yield %[[INI1]], %[[INI2]] : vector<128xf32>, vector<128xf32> -// CHECK-PROP: } -// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG1:.*]] = %[[INI]]#0, %[[FARG2:.*]] = %[[INI]]#1) -> (vector<4xf32>, vector<4xf32>) { -// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG1]], %[[FARG2]] : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) { -// CHECK-PROP: ^bb0(%[[ARG1:.*]]: vector<128xf32>, %[[ARG2:.*]]: vector<128xf32>): -// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%[[ARG1]]) : (vector<128xf32>) -> vector<128xf32> -// CHECK-PROP: %[[ACC2:.*]] = "some_def"(%[[ARG2]]) : (vector<128xf32>) -> vector<128xf32> -// CHECK-PROP: vector.yield %[[ACC2]], %[[ACC1]] : vector<128xf32>, vector<128xf32> -// CHECK-PROP: } -// CHECK-PROP: scf.yield %[[W]]#0, %[[W]]#1 : vector<4xf32>, vector<4xf32> -// CHECK-PROP: } -// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> () -// CHECK-PROP: "some_use"(%[[F]]#1) : (vector<4xf32>) -> () -func.func @warp_scf_for_swap(%arg0: index) { - %c128 = arith.constant 128 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0:2 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>) { - %ini1 = "some_def"() : () -> (vector<128xf32>) - %ini2 = "some_def"() : () -> (vector<128xf32>) - %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2) -> (vector<128xf32>, vector<128xf32>) { - %acc1 = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>) - %acc2 = "some_def"(%arg5) : (vector<128xf32>) -> (vector<128xf32>) - scf.yield %acc2, %acc1 : vector<128xf32>, vector<128xf32> - } - vector.yield %3#0, %3#1 : vector<128xf32>, vector<128xf32> - } - "some_use"(%0#0) : (vector<4xf32>) -> () - "some_use"(%0#1) : (vector<4xf32>) -> () - return -} - -// ----- - -#map = affine_map<()[s0] -> (s0 * 4)> -#map1 = affine_map<()[s0] -> (s0 * 128 + 128)> -#map2 = affine_map<()[s0] -> (s0 * 4 + 128)> - -// CHECK-PROP-LABEL: func @warp_scf_for_multiple_yield( -// CHECK-PROP: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) { -// CHECK-PROP-NEXT: "some_def"() : () -> vector<32xf32> -// CHECK-PROP-NEXT: vector.yield %{{.*}} : vector<32xf32> -// CHECK-PROP-NEXT: } -// CHECK-PROP-NOT: vector.warp_execute_on_lane_0 -// CHECK-PROP: vector.transfer_read {{.*}} : memref, vector<4xf32> -// CHECK-PROP: vector.transfer_read {{.*}} : memref, vector<4xf32> -// CHECK-PROP: %{{.*}}:2 = scf.for {{.*}} -> (vector<4xf32>, vector<4xf32>) { -// CHECK-PROP-NOT: vector.warp_execute_on_lane_0 -// CHECK-PROP: vector.transfer_read {{.*}} : memref, vector<4xf32> -// CHECK-PROP: vector.transfer_read {{.*}} : memref, vector<4xf32> -// CHECK-PROP: arith.addf {{.*}} : vector<4xf32> -// CHECK-PROP: arith.addf {{.*}} : vector<4xf32> -// CHECK-PROP: scf.yield {{.*}} : vector<4xf32>, vector<4xf32> -// CHECK-PROP: } -func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref, %arg2: memref) { - %c256 = arith.constant 256 : index - %c128 = arith.constant 128 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %0:3 = vector.warp_execute_on_lane_0(%arg0)[32] -> - (vector<1xf32>, vector<4xf32>, vector<4xf32>) { - %def = "some_def"() : () -> (vector<32xf32>) - %r1 = vector.transfer_read %arg2[%c0], %cst {in_bounds = [true]} : memref, vector<128xf32> - %r2 = vector.transfer_read %arg2[%c128], %cst {in_bounds = [true]} : memref, vector<128xf32> - %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %r1, %arg5 = %r2) - -> (vector<128xf32>, vector<128xf32>) { - %o1 = affine.apply #map1()[%arg3] - %o2 = affine.apply #map2()[%arg3] - %4 = vector.transfer_read %arg1[%o1], %cst {in_bounds = [true]} : memref, vector<128xf32> - %5 = vector.transfer_read %arg1[%o2], %cst {in_bounds = [true]} : memref, vector<128xf32> - %6 = arith.addf %4, %arg4 : vector<128xf32> - %7 = arith.addf %5, %arg5 : vector<128xf32> - scf.yield %6, %7 : vector<128xf32>, vector<128xf32> - } - vector.yield %def, %3#0, %3#1 : vector<32xf32>, vector<128xf32>, vector<128xf32> - } - %1 = affine.apply #map()[%arg0] - vector.transfer_write %0#1, %arg2[%1] {in_bounds = [true]} : vector<4xf32>, memref - %2 = affine.apply #map2()[%arg0] - vector.transfer_write %0#2, %arg2[%2] {in_bounds = [true]} : vector<4xf32>, memref - "some_use"(%0#0) : (vector<1xf32>) -> () - return -} +} \ No newline at end of file diff --git a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir index 159c677b9631..2205079f246a 100644 --- a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir +++ b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir @@ -11,31 +11,6 @@ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \ // RUN: FileCheck %s -// Run the same test cases with distribution and propagation. -// RUN: mlir-opt %s -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" \ -// RUN: -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \ -// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \ -// RUN: -gpu-kernel-outlining \ -// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,reconcile-unrealized-casts,gpu-to-cubin)' \ -// RUN: -gpu-to-llvm -reconcile-unrealized-casts |\ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_cuda_runtime%shlibext \ -// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ -// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \ -// RUN: FileCheck %s - -// RUN: mlir-opt %s -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \ -// RUN: -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \ -// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \ -// RUN: -gpu-kernel-outlining \ -// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,reconcile-unrealized-casts,gpu-to-cubin)' \ -// RUN: -gpu-to-llvm -reconcile-unrealized-casts |\ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_cuda_runtime%shlibext \ -// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ -// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \ -// RUN: FileCheck %s - func.func @gpu_func(%arg1: memref<32xf32>, %arg2: memref<32xf32>) { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index cfdeb1e632e0..e1ffddc5f068 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -835,10 +835,6 @@ struct TestVectorDistribution llvm::cl::desc("Test hoist uniform"), llvm::cl::init(false)}; - Option propagateDistribution{ - *this, "propagate-distribution", - llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)}; - void runOnOperation() override { RewritePatternSet patterns(&getContext()); @@ -866,11 +862,7 @@ struct TestVectorDistribution populateDistributeTransferWriteOpPatterns(patterns, distributionFn); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } - if (propagateDistribution) { - RewritePatternSet patterns(ctx); - vector::populatePropagateWarpVectorDistributionPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } + WarpExecuteOnLane0LoweringOptions options; options.warpAllocationFn = allocateGlobalSharedMemory; options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,