Add narrow type emulation pattern for vector.transfer_read

Reviewed By: mravishankar, hanchung

Differential Revision: https://reviews.llvm.org/D158757
This commit is contained in:
yzhang93 2023-08-29 13:14:47 -07:00 committed by Hanhan Wang
parent 0863051208
commit f4bef787bc
2 changed files with 97 additions and 10 deletions

View File

@@ -36,9 +36,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto sourceType = cast<MemRefType>(adaptor.getBase().getType());
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = sourceType.getElementType();
Type newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -81,16 +81,73 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
auto srcElementType = sourceType.getElementType();
auto numElements =
static_cast<int>(std::ceil(static_cast<double>(origElements) / scale));
auto numElements = (origElements + scale - 1) / scale;
auto newLoad = rewriter.create<vector::LoadOp>(
loc, VectorType::get(numElements, srcElementType), adaptor.getBase(),
loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
numElements *= scale;
auto castType = VectorType::get(numElements, oldElementType);
auto bitCast = rewriter.create<vector::BitCastOp>(loc, castType, newLoad);
auto bitCast =
rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
rewriter.replaceOp(op, bitCast->getResult(0));
return success();
}
};
//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//
struct ConvertVectorTransferRead final
: OpConversionPattern<vector::TransferReadOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}
int scale = dstBits / srcBits;
auto origElements = op.getVectorType().getNumElements();
if (origElements % scale != 0)
return failure();
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
adaptor.getPadding());
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
auto numElements = (origElements + scale - 1) / scale;
auto newReadType = VectorType::get(numElements, newElementType);
auto newRead = rewriter.create<vector::TransferReadOp>(
loc, newReadType, adaptor.getSource(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
newPadding);
auto bitCast =
rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
rewriter.replaceOp(op, bitCast->getResult(0));
return success();
@@ -107,5 +164,6 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `vector.*` conversion patterns.
patterns.add<ConvertVectorLoad>(typeConverter, patterns.getContext());
patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());
}

View File

@@ -79,3 +79,32 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
// -----
func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
%c0 = arith.constant 0 : i4
%0 = memref.alloc() : memref<3x8xi4>
%1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true]} :
memref<3x8xi4>, vector<8xi4>
return %1 : vector<8xi4>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
// CHECK: func @vector_transfer_read_i4
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
// CHECK: %[[CONST:.+]] = arith.constant 0 : i4
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
// CHECK: %[[PAD:.+]] = arith.extui %[[CONST]] : i4 to i8
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
// CHECK: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<12xi8>, vector<4xi8>
// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
// CHECK32: func @vector_transfer_read_i4
// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
// CHECK32: %[[CONST:.+]] = arith.constant 0 : i4
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
// CHECK32: %[[PAD:.+]] = arith.extui %[[CONST]] : i4 to i32
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>