[mlir][Vector] Add a rewrite pattern for better low-precision ext(bitcast) expansion (#66648)

This revision adds a rewrite of vector `ext(bitcast)` sequences into a more efficient sequence of vector operations comprising `shuffle` and bitwise ops. Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect.

The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance in the source vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`, and `ori` ops with shifts.

The rewrite currently only applies to 1-D non-scalable vectors and bails out if the bitwidth of the final vector element type is not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting type is not an even number of bytes, further complexities arise that this pattern does not improve; the heavy lifting still needs to be done by LLVM.
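For illustration, the shape of the rewrite on the `@f1ext` test case added below (the `%mask0`, `%shr0`, `%mask1`, `%shl1` constants are named here for readability; their exact dense values appear in the test's CHECK lines):

```mlir
// Before: a sub-byte ext(bitcast) sequence.
%0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
%1 = arith.extsi %0 : vector<8xi5> to vector<8xi16>

// After: the i5 elements are assembled in i8 lanes by two shuffles plus
// bitwise ops, and a single extension remains at the end.
%v0 = vector.shuffle %a, %a [0, 0, 1, 1, 2, 3, 3, 4] : vector<5xi8>, vector<5xi8>
%a0 = arith.andi %v0, %mask0 : vector<8xi8>
%s0 = arith.shrui %a0, %shr0 : vector<8xi8>
%v1 = vector.shuffle %a, %a [0, 1, 0, 2, 3, 0, 4, 0] : vector<5xi8>, vector<5xi8>
%a1 = arith.andi %v1, %mask1 : vector<8xi8>
%s1 = arith.shli %a1, %shl1 : vector<8xi8>
%o1 = arith.ori %s0, %s1 : vector<8xi8>
%res = arith.extsi %o1 : vector<8xi8> to vector<8xi16>
```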
parent 2a38d83918
commit 04ba475e85
@@ -300,6 +300,8 @@ def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
     This is usually a late step that is run after bufferization as part of the
     process of lowering to e.g. LLVM or NVVM.
+
+    Warning: these patterns currently only work for little endian targets.
   }];
 
   let assemblyFormat = "attr-dict";
@@ -23,6 +23,7 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arith {
+class AndIOp;
 class NarrowTypeEmulationConverter;
 class TruncIOp;
 } // namespace arith
@@ -304,13 +305,22 @@ void populateVectorNarrowTypeEmulationPatterns(
 
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
+/// Warning: these patterns currently only work for little endian targets.
 FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
                                         vector::BitCastOp bitCastOp,
                                         arith::TruncIOp truncOp,
                                         vector::BroadcastOp maybeBroadcastOp);
 
+/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
+/// vector operations comprising `shuffle` and `bitwise` ops.
+/// Warning: these patterns currently only work for little endian targets.
+FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                     vector::BitCastOp bitCastOp,
+                                     vector::BroadcastOp maybeBroadcastOp);
+
 /// Appends patterns for rewriting vector operations over narrow types with
 /// ops over wider types.
+/// Warning: these patterns currently only work for little endian targets.
 void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);
 
@@ -224,6 +224,106 @@ struct BitCastBitsEnumerator {
   SmallVector<SourceElementRangeList> sourceElementRanges;
 };
 
+/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information to avoid leaving LLVM to scramble with
+/// peephole optimizations.
+/// BitCastBitsEnumerator encodes for each element of the target vector the
+/// provenance of the bits in the source vector. We can "transpose" this
+/// information to build a sequence of shuffles and bitwise ops that will
+/// produce the desired result.
+///
+/// Consider the following motivating example:
+/// ```
+///   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
+/// ```
+///
+/// BitCastBitsEnumerator contains the following information:
+/// ```
+///   { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
+///   { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
+///   { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
+///   { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
+///   { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
+///   { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
+///   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
+///   {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
+///   {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
+///   {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
+///   {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
+///   {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
+///   {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
+///   {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
+///   {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
+///   {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
+///   {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
+///   {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
+///   {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
+///   {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
+/// ```
+///
+/// In the above, each row represents one target vector element and each
+/// column represents one bit contribution from a source vector element.
+/// The algorithm creates vector.shuffle operations (in this case there are 3
+/// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
+/// algorithm populates the bits as follows:
+/// ```
+///     src bits 0 ...
+/// 1st shuffle |xxxxx   |xx      |...
+/// 2nd shuffle |     xxx|  xxxxx |...
+/// 3rd shuffle |        |       x|...
+/// ```
+///
+/// The algorithm proceeds as follows:
+/// 1. for each vector.shuffle, collect the source vectors that participate in
+///    this shuffle. One source vector per target element of the resulting
+///    vector.shuffle. If there is no source element contributing bits for the
+///    current vector.shuffle, take 0 (i.e. row 0 in the above example has only
+///    2 columns).
+/// 2. represent the bitrange in the source vector as a mask. If there is no
+///    source element contributing bits for the current vector.shuffle, take 0.
+/// 3. shift right by the proper amount to align the source bitrange at
+///    position 0. This is exactly the low end of the bitrange. For instance,
+///    the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
+///    shift right by 3 to get the bits contributed by the source element #1
+///    into position 0.
+/// 4. shift left by the proper amount to align to the desired position in
+///    the result element vector. For instance, the contribution of the second
+///    source element for the first row needs to be shifted by `5` to form the
+///    first i8 result element.
+///
+/// Eventually, we end up building the sequence
+/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
+/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
+/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
+struct BitCastRewriter {
+  /// Helper metadata struct to hold the static quantities for the rewrite.
+  struct Metadata {
+    SmallVector<int64_t> shuffles;
+    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+  };
+
+  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
+
+  /// Verify that the preconditions for the rewrite are met.
+  LogicalResult precondition(PatternRewriter &rewriter,
+                             VectorType preconditionVectorType, Operation *op);
+
+  /// Precompute the metadata for the rewrite.
+  SmallVector<BitCastRewriter::Metadata>
+  precomputeMetadata(IntegerType shuffledElementType);
+
+  /// Rewrite one step of the sequence:
+  ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
+  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
+                    Value runningResult,
+                    const BitCastRewriter::Metadata &metadata);
+
+private:
+  /// Underlying enumerator that encodes the provenance of the bits in each
+  /// element of the result vector.
+  BitCastBitsEnumerator enumerator;
+};
+
 } // namespace
 
 [[maybe_unused]] static raw_ostream &operator<<(raw_ostream &os,
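To read the enumeration above: i5 element k of the source occupies bits [5*k, 5*k+5) of the underlying storage, and i8 element j of the target occupies bits [8*j, 8*j+8); intersecting the two ranges gives each table entry. For example, target element 1 (bits [8..16)) receives bits [3..5) of source element 1 at position 0, all five bits of source element 2 shifted left by 2, and bit 0 of source element 3 shifted left by 7 — exactly `{ 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}`.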
@@ -256,7 +356,7 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
   LDBG("targetVectorType: " << targetVectorType);
 
   int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
-  (void) mostMinorSourceDim;
+  (void)mostMinorSourceDim;
   assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
          "source and target bitwidths must match");
 
@@ -275,79 +375,107 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
   }
 }
 
+BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
+                                 VectorType targetVectorType)
+    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
+  LDBG("\n" << enumerator.sourceElementRanges);
+}
+
+LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
+                                            VectorType precondition,
+                                            Operation *op) {
+  if (precondition.getRank() != 1 || precondition.isScalable())
+    return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+
+  // TODO: consider relaxing this restriction in the future if we find ways
+  // to really work with subbyte elements across the MLIR/LLVM boundary.
+  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+  if (resultBitwidth % 8 != 0)
+    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
+
+  return success();
+}
+
+SmallVector<BitCastRewriter::Metadata>
+BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
+  SmallVector<BitCastRewriter::Metadata> result;
+  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
+       shuffleIdx < e; ++shuffleIdx) {
+    SmallVector<int64_t> shuffles;
+    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+
+    // Create the attribute quantities for the shuffle / mask / shift ops.
+    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
+      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
+                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
+                                  : 0;
+      shuffles.push_back(sourceElement);
+
+      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
+                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
+                          : 0;
+      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
+                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
+                          : 0;
+      IntegerAttr mask = IntegerAttr::get(
+          shuffledElementType,
+          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
+                                  bitLo, bitHi));
+      masks.push_back(mask);
+
+      int64_t shiftRight = bitLo;
+      shiftRightAmounts.push_back(
+          IntegerAttr::get(shuffledElementType, shiftRight));
+
+      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
+      shiftLeftAmounts.push_back(
+          IntegerAttr::get(shuffledElementType, shiftLeft));
+    }
+
+    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
+  }
+  return result;
+}
+
+Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
+                                   Value initialValue, Value runningResult,
+                                   const BitCastRewriter::Metadata &metadata) {
+  // Create vector.shuffle from the metadata.
+  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, initialValue, initialValue, metadata.shuffles);
+
+  // Intersect with the mask.
+  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
+  auto constOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
+  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
+
+  // Align right on 0.
+  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
+  Value shiftedRight =
+      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
+
+  // Shift bits left into their final position.
+  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
+  Value shiftedLeft =
+      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
+
+  runningResult =
+      runningResult
+          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
+          : shiftedLeft;
+
+  return runningResult;
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
 /// peephole optimizations.
-
-// BitCastBitsEnumerator encodes for each element of the target vector the
-// provenance of the bits in the source vector. We can "transpose" this
-// information to build a sequence of shuffles and bitwise ops that will
-// produce the desired result.
-//
-// Let's take the following motivating example to explain the algorithm:
-// ```
-//   %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
-//   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
-// ```
-//
-// BitCastBitsEnumerator contains the following information:
-// ```
-//   { 0: b@[0..5) lshl: 0}{1: b@[0..3) lshl: 5 }
-//   { 1: b@[3..5) lshl: 0}{2: b@[0..5) lshl: 2}{3: b@[0..1) lshl: 7 }
-//   { 3: b@[1..5) lshl: 0}{4: b@[0..4) lshl: 4 }
-//   { 4: b@[4..5) lshl: 0}{5: b@[0..5) lshl: 1}{6: b@[0..2) lshl: 6 }
-//   { 6: b@[2..5) lshl: 0}{7: b@[0..5) lshl: 3 }
-//   { 8: b@[0..5) lshl: 0}{9: b@[0..3) lshl: 5 }
-//   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7 }
-//   { 11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4 }
-//   { 12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6 }
-//   { 14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
-//   { 16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
-//   { 17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
-//   { 19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
-//   { 20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1 }{22: b@[0..2) lshl: 6}
-//   { 22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3 }
-//   { 24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5 }
-//   { 25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7 }
-//   { 27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
-//   { 28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
-//   { 30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3 }
-// ```
-//
-// In the above, each row represents one target vector element and each
-// column represents one bit contribution from a source vector element.
-// The algorithm creates vector.shuffle operations (in this case there are 3
-// shuffles (i.e. the max number of columns in BitCastBitsEnumerator). The
-// algorithm populates the bits as follows:
-// ```
-//     src bits 0 ...
-// 1st shuffle |xxxxx   |xx      |...
-// 2nd shuffle |     xxx|  xxxxx |...
-// 3rd shuffle |        |       x|...
-// ```
-//
-// The algorithm proceeds as follows:
-// 1. for each vector.shuffle, collect the source vectors that participate in
-//    this shuffle. One source vector per target element of the resulting
-//    vector.shuffle. If there is no source element contributing bits for the
-//    current vector.shuffle, take 0 (i.e. row 0 in the above example has only
-//    2 columns).
-// 2. represent the bitrange in the source vector as a mask. If there is no
-//    source element contributing bits for the current vector.shuffle, take 0.
-// 3. shift right by the proper amount to align the source bitrange at
-//    position 0. This is exactly the low end of the bitrange. For instance,
-//    the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
-//    shift right by 3 to get the bits contributed by the source element #1
-//    into position 0.
-// 4. shift left by the proper amount to to align to the desired position in
-//    the result element vector. For instance, the contribution of the second
-//    source element for the first row needs to be shifted by `5` to form the
-//    first i8 result element.
-// Eventually, we end up building the sequence
-// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update the
-// result vector (i.e. the `shiftright -> shiftleft -> or` part) with the bits
-// extracted from the source vector (i.e. the `shuffle -> and` part).
 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -359,93 +487,93 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     if (!truncOp)
      return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
 
-    VectorType targetVectorType = bitCastOp.getResultVectorType();
-    if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
-      return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
-    // TODO: consider relaxing this restriction in the future if we find ways
-    // to really work with subbyte elements across the MLIR/LLVM boundary.
-    int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
-    if (resultBitwidth % 8 != 0)
-      return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
-
+    // Set up the BitCastRewriter and verify the precondition.
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
-    BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
-    LDBG("\n" << be.sourceElementRanges);
+    VectorType targetVectorType = bitCastOp.getResultVectorType();
+    BitCastRewriter bcr(sourceVectorType, targetVectorType);
+    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+      return failure();
 
-    Value initialValue = truncOp.getIn();
-    auto initalVectorType = initialValue.getType().cast<VectorType>();
-    auto initalElementType = initalVectorType.getElementType();
-    auto initalElementBitWidth = initalElementType.getIntOrFloatBitWidth();
-
-    Value res;
-    for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e;
-         ++shuffleIdx) {
-      SmallVector<int64_t> shuffles;
-      SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
-
-      // Create the attribute quantities for the shuffle / mask / shift ops.
-      for (auto &srcEltRangeList : be.sourceElementRanges) {
-        bool idxContributesBits =
-            (shuffleIdx < (int64_t)srcEltRangeList.size());
-        int64_t sourceElementIdx =
-            idxContributesBits ? srcEltRangeList[shuffleIdx].sourceElementIdx
-                               : 0;
-        shuffles.push_back(sourceElementIdx);
-
-        int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
-                            ? srcEltRangeList[shuffleIdx].sourceBitBegin
-                            : 0;
-        int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
-                            ? srcEltRangeList[shuffleIdx].sourceBitEnd
-                            : 0;
-        IntegerAttr mask = IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth),
-            llvm::APInt::getBitsSet(initalElementBitWidth, bitLo, bitHi));
-        masks.push_back(mask);
-
-        int64_t shiftRight = bitLo;
-        shiftRightAmounts.push_back(IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth), shiftRight));
-
-        int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
-        shiftLeftAmounts.push_back(IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth), shiftLeft));
-      }
-
-      // Create vector.shuffle #shuffleIdx.
-      auto shuffleOp = rewriter.create<vector::ShuffleOp>(
-          bitCastOp.getLoc(), initialValue, initialValue, shuffles);
-      // And with the mask.
-      VectorType vt = VectorType::Builder(initalVectorType)
-                          .setDim(initalVectorType.getRank() - 1, masks.size());
-      auto constOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, masks));
-      Value andValue = rewriter.create<arith::AndIOp>(bitCastOp.getLoc(),
-                                                      shuffleOp, constOp);
-      // Align right on 0.
-      auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftRightAmounts));
-      Value shiftedRight = rewriter.create<arith::ShRUIOp>(
-          bitCastOp.getLoc(), andValue, shiftRightConstantOp);
-
-      auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftLeftAmounts));
-      Value shiftedLeft = rewriter.create<arith::ShLIOp>(
-          bitCastOp.getLoc(), shiftedRight, shiftLeftConstantOp);
-
-      res = res ? rewriter.create<arith::OrIOp>(bitCastOp.getLoc(), res,
-                                                shiftedLeft)
-                : shiftedLeft;
+    // Perform the rewrite.
+    Value truncValue = truncOp.getIn();
+    auto shuffledElementType =
+        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
+    Value runningResult;
+    for (const BitCastRewriter::Metadata &metadata :
+         bcr.precomputeMetadata(shuffledElementType)) {
+      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
+                                      runningResult, metadata);
     }
 
-    bool narrowing = resultBitwidth <= initalElementBitWidth;
+    // Finalize the rewrite.
+    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
+                     shuffledElementType.getIntOrFloatBitWidth();
     if (narrowing) {
       rewriter.replaceOpWithNewOp<arith::TruncIOp>(
-          bitCastOp, bitCastOp.getResultVectorType(), res);
+          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
     } else {
       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
-          bitCastOp, bitCastOp.getResultVectorType(), res);
+          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
     }
 
     return success();
   }
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
+/// take advantage of high-level information to avoid leaving LLVM to scramble
+/// with peephole optimizations.
+template <typename ExtOpType>
+struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
+  using OpRewritePattern<ExtOpType>::OpRewritePattern;
+
+  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
+      : OpRewritePattern<ExtOpType>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(ExtOpType extOp,
+                                PatternRewriter &rewriter) const override {
+    // The source must be a bitcast op.
+    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
+    if (!bitCastOp)
+      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
+
+    // Set up the BitCastRewriter and verify the precondition.
+    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+    VectorType targetVectorType = bitCastOp.getResultVectorType();
+    BitCastRewriter bcr(sourceVectorType, targetVectorType);
+    if (failed(bcr.precondition(
+            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value runningResult;
+    Value sourceValue = bitCastOp.getSource();
+    auto shuffledElementType =
+        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
+    for (const BitCastRewriter::Metadata &metadata :
+         bcr.precomputeMetadata(shuffledElementType)) {
+      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
+                                      sourceValue, runningResult, metadata);
+    }
+
+    // Finalize the rewrite.
+    bool narrowing =
+        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
+        shuffledElementType.getIntOrFloatBitWidth();
+    if (narrowing) {
+      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
+          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
+    } else {
+      rewriter.replaceOpWithNewOp<ExtOpType>(
+          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
+    }
+
+    return success();
+  }
+};
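For concreteness, `precomputeMetadata` computed at i8 granularity for the `@f1ext` test below (`vector.bitcast %a : vector<5xi8> to vector<8xi5>`) yields exactly the constants the test checks: the first shuffle's masks `[31, -32, 124, ...]` select bits [0..5), [5..8), [2..7), ... of the shuffled source bytes, and the shift-right amounts `[0, 5, 2, ...]` move each selected bitrange down to position 0.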
@@ -466,5 +594,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
 void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<RewriteBitCastOfTruncI>(patterns.getContext(), benefit);
+  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
+               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
+                                                    benefit);
 }
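As a usage sketch (not part of this commit; `MyPass` is a hypothetical pass), the populated patterns are meant to be applied with the standard greedy driver:

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical pass body: gather the narrow-type rewrite patterns and run
// them to a fixpoint over the current operation.
void MyPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  vector::populateVectorNarrowTypeRewritePatterns(patterns);
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
```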
@@ -146,6 +146,53 @@ func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
   return %1 : vector<8xi6>
 }
 
+// CHECK-LABEL: func.func @f1ext(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi16> {
+func.func @f1ext(%a: vector<5xi8>) -> vector<8xi16> {
+  // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[31, -32, 124, -128, -16, 62, -64, -8]> : vector<8xi8>
+  // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<[0, 3, 0, 15, 1, 0, 7, 0]> : vector<8xi8>
+  // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 5, 2, 7, 4, 1, 6, 3]> : vector<8xi8>
+  // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[5, 3, 5, 1, 4, 5, 2, 5]> : vector<8xi8>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1, 2, 3, 3, 4] : vector<5xi8>, vector<5xi8>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<8xi8>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<8xi8>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [0, 1, 0, 2, 3, 0, 4, 0] : vector<5xi8>, vector<5xi8>
+  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<8xi8>
+  // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<8xi8>
+  // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<8xi8>
+  // CHECK: %[[RES:.*]] = arith.extsi %[[O1]] : vector<8xi8> to vector<8xi16>
+  // return %[[RES]] : vector<8xi16>
+
+  %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+  %1 = arith.extsi %0 : vector<8xi5> to vector<8xi16>
+  return %1 : vector<8xi16>
+}
+
+// CHECK-LABEL: func.func @f2ext(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi16> {
+func.func @f2ext(%a: vector<5xi8>) -> vector<8xi16> {
+  // CHECK-NOT: arith.extsi {{.*}} : vector<8xi8> to vector<8xi16>
+  // CHECK: %[[RES:.*]] = arith.extui {{.*}} : vector<8xi8> to vector<8xi16>
+  // return %[[RES]] : vector<8xi16>
+
+  %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+  %1 = arith.extui %0 : vector<8xi5> to vector<8xi16>
+  return %1 : vector<8xi16>
+}
+
+// CHECK-LABEL: func.func @f3ext(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi17> {
+func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
+  // CHECK: bitcast
+  // CHECK: extsi
+  // CHECK-NOT: shuffle
+  // CHECK-NOT: andi
+  // CHECK-NOT: ori
+  %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+  %1 = arith.extsi %0 : vector<8xi5> to vector<8xi17>
+  return %1 : vector<8xi17>
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
   %f = transform.structured.match ops{["func.func"]} in %module_op
@@ -124,6 +124,47 @@ func.func @f3(%v: vector<2xi48>) {
   return
 }
 
+func.func @print_as_i1_8xi5(%v : vector<8xi5>) {
+  %bitsi40 = vector.bitcast %v : vector<8xi5> to vector<40xi1>
+  vector.print %bitsi40 : vector<40xi1>
+  return
+}
+
+func.func @print_as_i1_8xi16(%v : vector<8xi16>) {
+  %bitsi128 = vector.bitcast %v : vector<8xi16> to vector<128xi1>
+  vector.print %bitsi128 : vector<128xi1>
+  return
+}
+
+func.func @fext(%a: vector<5xi8>) {
+  %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+  func.call @print_as_i1_8xi5(%0) : (vector<8xi5>) -> ()
+  // CHECK: (
+  // CHECK-SAME: 1, 1, 1, 1, 0,
+  // CHECK-SAME: 1, 1, 1, 0, 1,
+  // CHECK-SAME: 1, 1, 0, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 1, 1,
+  // CHECK-SAME: 0, 1, 1, 1, 0,
+  // CHECK-SAME: 0, 1, 1, 0, 1,
+  // CHECK-SAME: 1, 1, 1, 1, 0,
+  // CHECK-SAME: 1, 0, 1, 1, 1 )
+
+  %1 = arith.extui %0 : vector<8xi5> to vector<8xi16>
+  func.call @print_as_i1_8xi16(%1) : (vector<8xi16>) -> ()
+  // CHECK: (
+  // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+  return
+}
+
+
 func.func @entry() {
   %v = arith.constant dense<[
     0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8,
@@ -141,6 +182,11 @@ func.func @entry() {
   ]> : vector<2xi48>
   func.call @f3(%v3) : (vector<2xi48>) -> ()
 
+  %v4 = arith.constant dense<[
+    0xef, 0xee, 0xed, 0xec, 0xeb
+  ]> : vector<5xi8>
+  func.call @fext(%v4) : (vector<5xi8>) -> ()
+
   return
 }
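A note on reading the expected output of `@fext`: on a little-endian target, the `vector<40xi1>` print shows the packed bits in index order, i.e. LSB-first within each source byte. The first input byte `0xef = 0b11101111` therefore prints as `1, 1, 1, 1, 0` for the first i5 element, and the second i5 element combines the top three bits of `0xef` with the low two bits of `0xee = 0b11101110`, printing `1, 1, 1, 0, 1` — matching the CHECK lines above.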