[mlir][tensor] Fix bug in tensor.extract(tensor.from_elements) folder (#75109)

The folder for `tensor.extract` is not operating correctly when it is
consuming the result of a `tensor.from_elements` operation.

The existing unit test named `@extract_from_tensor.from_elements_3d` in
`mlir/test/Dialect/Tensor/canonicalize.mlir` seems an attempt to stress
this code. However, this unit tests creates a `tensor.from_elements` op
exclusively from constants, which gets folded away into a single
constant tensor. Therefore, the buggy code was never executed in unit
tests.

I have added a new unit test named
`@extract_from_tensor.from_elements_variable_3d` that makes sure the
`tensor.from_elements` op is not folded away by having its input
operands come directly from function arguments. The original folder code
would have made this test fail.

This bug was notably affecting the lowering of the `tosa.pad` op in the
`tosa-to-tensor` pass, where the generated code is likely to contain a
`tensor.from_elements` + `tensor.extract` op sequence.
This commit is contained in:
Rafael Ubal 2023-12-12 10:36:52 -05:00 committed by GitHub
parent c873f77e87
commit a8f3860bcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 2 deletions

View File

@ -1116,9 +1116,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
int flatIndex = 0;
int stride = 1;
for (int i = rank - 1; i >= 0; --i) {
if (i < rank - 1)
stride *= tensorType.getDimSize(i);
flatIndex += indices[i] * stride;
stride *= tensorType.getDimSize(i);
}
// Prevent out of bounds accesses. This can happen in invalid code that
// will never execute.

View File

@ -242,6 +242,50 @@ func.func @extract_from_tensor.from_elements_3d()
// -----
// CHECK-LABEL: func @extract_from_tensor.from_elements_variable_3d
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_5:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_6:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_7:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_8:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_9:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_10:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG_11:[a-zA-Z0-9_]+]]: f32
func.func @extract_from_tensor.from_elements_variable_3d(
%f0: f32, %f1: f32, %f2: f32, %f3: f32, %f4: f32, %f5: f32,
%f6: f32, %f7: f32, %f8: f32, %f9: f32, %f10: f32, %f11: f32)
-> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
%tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
: tensor<3x2x2xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
%r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
%r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
%r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
%r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
%r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
%r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
%r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
%r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
%r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
%r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
%r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
: f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
}
// CHECK: return %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]],
// CHECK-SAME: %[[ARG_6]], %[[ARG_7]], %[[ARG_8]], %[[ARG_9]], %[[ARG_10]], %[[ARG_11]]
// -----
// CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
// CHECK-NEXT: %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
// CHECK-NEXT: return %cst : tensor<3xcomplex<i32>>