[mlir][spirv] add support lowering of extract_slice to scalar type

Differential Revision: https://reviews.llvm.org/D102041
This commit is contained in:
thomasraoux 2021-05-06 16:41:43 -07:00
parent bc302bfbef
commit 565ee6afc7
2 changed files with 9 additions and 3 deletions

View File

@ -113,9 +113,6 @@ struct VectorExtractStridedSliceOpConvert final
if (!dstType)
return failure();
// Extract vector<1xT> not supported yet.
if (dstType.isa<spirv::ScalarType>())
return failure();
uint64_t offset = getFirstIntValue(extractOp.offsets());
uint64_t size = getFirstIntValue(extractOp.sizes());
@ -125,6 +122,13 @@ struct VectorExtractStridedSliceOpConvert final
Value srcVector = operands.front();
// Extract vector<1xT> case.
if (dstType.isa<spirv::ScalarType>()) {
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
srcVector, offset);
return success();
}
SmallVector<int32_t, 2> indices(size);
std::iota(indices.begin(), indices.end(), offset);

View File

@ -91,8 +91,10 @@ func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
// CHECK-LABEL: func @extract_strided_slice
// CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
// CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<4xf32>
func @extract_strided_slice(%arg0: vector<4xf32>) {
%0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%1 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
spv.Return
}