Add narrow type emulation pattern for vector.transfer_read
Reviewed By: mravishankar, hanchung

Differential Revision: https://reviews.llvm.org/D158757
parent 0863051208
commit f4bef787bc
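In effect (a minimal sketch distilled from the new test at the bottom of this diff; the names %m8 and %lin are illustrative stand-ins for the converted memref and the linearized index the pattern computes), a `vector.transfer_read` of sub-byte i4 elements is emulated as a byte-aligned read of the converted memref followed by a `vector.bitcast` back to the original type, with the padding value widened by `arith.extui`:

    // Before: a transfer_read of sub-byte i4 elements.
    %r = vector.transfer_read %m[%i, %j], %pad : memref<3x8xi4>, vector<8xi4>

    // After: the same data read as bytes (two i4 values per i8), then cast back.
    %p  = arith.extui %pad : i4 to i8
    %b  = vector.transfer_read %m8[%lin], %p : memref<12xi8>, vector<4xi8>
    %r2 = vector.bitcast %b : vector<4xi8> to vector<8xi4>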
@@ -36,9 +36,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     auto loc = op.getLoc();
-    auto sourceType = cast<MemRefType>(adaptor.getBase().getType());
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
-    Type newElementType = sourceType.getElementType();
+    Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
 
@@ -81,16 +81,73 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto srcElementType = sourceType.getElementType();
-    auto numElements =
-        static_cast<int>(std::ceil(static_cast<double>(origElements) / scale));
+    auto numElements = (origElements + scale - 1) / scale;
     auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, srcElementType), adaptor.getBase(),
+        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
-    numElements *= scale;
-    auto castType = VectorType::get(numElements, oldElementType);
-    auto bitCast = rewriter.create<vector::BitCastOp>(loc, castType, newLoad);
+    auto bitCast =
+        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
 
     rewriter.replaceOp(op, bitCast->getResult(0));
     return success();
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorTransferRead
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorTransferRead final
+    : OpConversionPattern<vector::TransferReadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
+    Type oldElementType = op.getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    auto origElements = op.getVectorType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
+                                                      adaptor.getPadding());
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newReadType = VectorType::get(numElements, newElementType);
+
+    auto newRead = rewriter.create<vector::TransferReadOp>(
+        loc, newReadType, adaptor.getSource(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
+        newPadding);
+
+    auto bitCast =
+        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
+
+    rewriter.replaceOp(op, bitCast->getResult(0));
+    return success();
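A note on the arithmetic, implied by the diff rather than stated in it: `(origElements + scale - 1) / scale` is an integer ceiling division that replaces the old `std::ceil` float round-trip, and because the emulated type is built from `newElementType` directly, the old `numElements *= scale` bookkeeping disappears. With the i4-over-i8 configuration from the tests, scale = 8 / 4 = 2, so a `vector<8xi4>` load needs (8 + 2 - 1) / 2 = 4 bytes (%alloc and %lin are illustrative names):

    // Emulated vector.load for the i4-over-i8 case: 4 x i8 carries 8 x i4.
    %bytes = vector.load %alloc[%lin] : memref<12xi8>, vector<4xi8>
    %orig  = vector.bitcast %bytes : vector<4xi8> to vector<8xi4>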
@@ -107,5 +164,6 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad>(typeConverter, patterns.getContext());
+  patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
+      typeConverter, patterns.getContext());
 }
@@ -79,3 +79,32 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
 // CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 // CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
 // CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+
+// -----
+
+func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
+    %c0 = arith.constant 0 : i4
+    %0 = memref.alloc() : memref<3x8xi4>
+    %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true]} :
+      memref<3x8xi4>, vector<8xi4>
+    return %1 : vector<8xi4>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_transfer_read_i4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[CONST:.+]] = arith.constant 0 : i4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[PAD:.+]] = arith.extui %[[CONST]] : i4 to i8
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_transfer_read_i4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[CONST:.+]] = arith.constant 0 : i4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[PAD:.+]] = arith.extui %[[CONST]] : i4 to i32
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>