mirror of https://github.com/capstone-engine/llvm-capstone.git
[mlir][NFC] Remove Standard dialect dependency on MemRef dialect
* Remove dependency: Standard --> MemRef
* Add dependencies: GPUToNVVMTransforms --> MemRef, Linalg --> MemRef, MemRef --> Tensor
* Note: The `subtensor_insert_propagate_dest_cast` test case in MemRef/canonicalize.mlir will be moved to Tensor/canonicalize.mlir in a subsequent commit, which moves over the remaining Tensor ops from the Standard dialect to the Tensor dialect.

Differential Revision: https://reviews.llvm.org/D104506
parent 80e0424b2c
commit 66f878cee9
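Because the Standard dialect no longer declares MemRef as a dependent dialect, code that previously picked up MemRef transitively through Standard must now depend on it directly, which is what the hunks below do. As illustration only (the pass name and body are hypothetical and not part of this commit; only `getDependentDialects` and `registry.insert` mirror the real MLIR API), a minimal C++ sketch of how a pass that creates memref ops declares the dependency itself:

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"

namespace {
// Hypothetical pass, shown only to illustrate the dependency change.
struct ExampleMemRefProducingPass
    : public mlir::PassWrapper<ExampleMemRefProducingPass,
                               mlir::OperationPass<>> {
  // Declare every dialect whose ops this pass may create so it is loaded
  // before the pass runs; Standard no longer pulls in MemRef for us.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::memref::MemRefDialect>();
  }

  void runOnOperation() override {
    // ... rewrite logic that may introduce memref.load / memref.store ...
  }
};
} // namespace

TableGen-defined dialects state the same requirement through their dependentDialects list, which is exactly what the Linalg and MemRef dialect definitions gain below.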
@@ -33,7 +33,8 @@ def Linalg_Dialect : Dialect {
   }];
   let cppNamespace = "::mlir::linalg";
   let dependentDialects = [
-    "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
+    "AffineDialect", "memref::MemRefDialect", "StandardOpsDialect",
+    "tensor::TensorDialect"
   ];
   let hasCanonicalizer = 1;
   let hasOperationAttrVerify = 1;
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_LINALGOPS_H_

 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_LINALGTYPES_H_

 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
@@ -10,6 +10,7 @@
 #define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_

 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Utils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Identifier.h"
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_
 #define MLIR_DIALECT_MEMREF_IR_MEMREF_H_

+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
@@ -19,6 +19,7 @@ def MemRef_Dialect : Dialect {
     manipulation ops, which are not strongly associated with any particular
     other dialect or domain abstraction.
   }];
+  let dependentDialects = ["tensor::TensorDialect"];
   let hasConstantMaterializer = 1;
 }

@@ -14,7 +14,6 @@
 #ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
 #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H

-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -27,9 +27,6 @@ def StandardOps_Dialect : Dialect {
   let name = "std";
   let cppNamespace = "::mlir";
   let hasConstantMaterializer = 1;
-  // TODO: This dependency is needed to handle memref ops in the
-  // canonicalize pass and should be resolved.
-  let dependentDialects = ["memref::MemRefDialect"];
 }

 // Base class for Standard dialect ops.
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -17,6 +17,7 @@
 #include "../PassDetail.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
@@ -17,6 +17,7 @@
 #include "../PassDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
@@ -23,6 +23,7 @@
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopUtils.h"
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
@@ -8,6 +8,7 @@

 #include "mlir/Transforms/Bufferize.h"
 #include "PassDetail.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Passes.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SCF/Transforms.h"
@@ -8,6 +8,7 @@

 #include "mlir/Transforms/Bufferize.h"
 #include "PassDetail.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Pass/Pass.h"

@@ -43,6 +43,7 @@

 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -217,3 +217,177 @@ func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
  return
}

// -----

// Test case: Folding of memref.load(memref.buffer_cast(%v, %idxs))
// -> tensor.extract(%v, %idx)
// CHECK-LABEL: func @load_from_buffer_cast(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK: %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
// CHECK-NOT: memref.load
// CHECK: return %[[RES]] : f32
func @load_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
  %0 = memref.buffer_cast %arg2 : memref<?x?xf32>
  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
  return %1 : f32
}

// -----

// Test case: Basic folding of memref.dim(memref.tensor_load(m)) -> memref.dim(m).
// CHECK-LABEL: func @dim_of_tensor_load(
// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
// CHECK: %[[C0:.*]] = constant 0
// CHECK: %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
// CHECK: return %[[D]] : index
func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
  %c0 = constant 0 : index
  %0 = memref.tensor_load %arg0 : memref<?xf32>
  %1 = memref.dim %0, %c0 : tensor<?xf32>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(tensor.generate %idx) -> %idx
// CHECK-LABEL: func @dim_of_tensor.generate(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
// CHECK-NOT: memref.dim
// CHECK: return %[[IDX1]] : index
func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
  %c3 = constant 3 : index
  %0 = tensor.generate %arg0, %arg1 {
  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
    tensor.yield %c3 : index
  } : tensor<2x?x4x?x5xindex>
  %1 = memref.dim %0, %c3 : tensor<2x?x4x?x5xindex>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
// CHECK-LABEL: func @dim_of_alloca(
// CHECK-SAME: %[[SIZE:[0-9a-z]+]]: index
// CHECK-NEXT: return %[[SIZE]] : index
func @dim_of_alloca(%size: index) -> index {
  %0 = memref.alloca(%size) : memref<?xindex>
  %c0 = constant 0 : index
  %1 = memref.dim %0, %c0 : memref<?xindex>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>
// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
// CHECK-NEXT: return %[[RANK]] : index
func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
  %0 = rank %arg0 : memref<*xf32>
  %1 = memref.alloca(%0) : memref<?xindex>
  %c0 = constant 0 : index
  %2 = memref.dim %1, %c0 : memref<?xindex>
  return %2 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
// CHECK-NEXT: %[[IDX:.*]] = constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: memref.store
// CHECK-NOT: memref.dim
// CHECK: return %[[DIM]] : index
func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
    -> index {
  %c3 = constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // Update the shape to test that the load ends up in the right place.
  memref.store %c3, %arg1[%c3] : memref<?xindex>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
// CHECK-NEXT: %[[IDX:.*]] = constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: %[[CAST:.*]] = index_cast %[[DIM]]
// CHECK-NOT: memref.dim
// CHECK: return %[[CAST]] : index
func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
    -> index {
  %c3 = constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor.cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-NEXT: return %[[C4]], %[[T0]]
func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
  return %1, %2: index, index
}

// -----

// CHECK-LABEL: func @tensor_cast_to_memref
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
// CHECK: %[[M:.+]] = memref.buffer_cast %[[ARG0]] : memref<4x6x16x32xi8>
// CHECK: %[[M1:.+]] = memref.cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
    memref<?x?x16x32xi8> {
  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
  %1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
  return %1 : memref<?x?x16x32xi8>
}

// -----

// TODO: Move this test to Tensor/canonicalize.mlir.
func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
    %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %c8 = constant 8 : index
  %0 = memref.dim %arg0, %c1 : tensor<2x?xi32>
  %1 = tensor.extract %arg1[] : tensor<i32>
  %2 = tensor.generate %arg2, %c8 {
  ^bb0(%arg4: index, %arg5: index):
    tensor.yield %1 : i32
  } : tensor<?x?xi32>
  %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
  return %3 : tensor<?x?xi32>
}
// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
// CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
// CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
// CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
// CHECK: return %[[CAST]]
@@ -1,53 +1,5 @@
// RUN: mlir-opt %s -canonicalize --split-input-file | FileCheck %s

// Test case: Basic folding of memref.dim(memref.tensor_load(m)) -> memref.dim(m).
// CHECK-LABEL: func @dim_of_tensor_load(
// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
// CHECK: %[[C0:.*]] = constant 0
// CHECK: %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
// CHECK: return %[[D]] : index
func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
  %c0 = constant 0 : index
  %0 = memref.tensor_load %arg0 : memref<?xf32>
  %1 = memref.dim %0, %c0 : tensor<?xf32>
  return %1 : index
}

// -----

// Test case: Folding of memref.load(memref.buffer_cast(%v, %idxs))
// -> tensor.extract(%v, %idx)
// CHECK-LABEL: func @load_from_buffer_cast(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK: %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
// CHECK-NOT: memref.load
// CHECK: return %[[RES]] : f32
func @load_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
  %0 = memref.buffer_cast %arg2 : memref<?x?xf32>
  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
  return %1 : f32
}

// -----

// Test case: Folding of memref.dim(tensor.generate %idx) -> %idx
// CHECK-LABEL: func @dim_of_tensor.generate(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
// CHECK-NOT: memref.dim
// CHECK: return %[[IDX1]] : index
func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
  %c3 = constant 3 : index
  %0 = tensor.generate %arg0, %arg1 {
  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
    tensor.yield %c3 : index
  } : tensor<2x?x4x?x5xindex>
  %1 = memref.dim %0, %c3 : tensor<2x?x4x?x5xindex>
  return %1 : index
}

// -----

// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = constant true
@@ -72,108 +24,6 @@ func @cmpi_equal_operands(%arg0: i64)

// -----

// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
// CHECK-LABEL: func @dim_of_alloca(
// CHECK-SAME: %[[SIZE:[0-9a-z]+]]: index
// CHECK-NEXT: return %[[SIZE]] : index
func @dim_of_alloca(%size: index) -> index {
  %0 = memref.alloca(%size) : memref<?xindex>
  %c0 = constant 0 : index
  %1 = memref.dim %0, %c0 : memref<?xindex>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>
// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
// CHECK-NEXT: return %[[RANK]] : index
func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
  %0 = rank %arg0 : memref<*xf32>
  %1 = memref.alloca(%0) : memref<?xindex>
  %c0 = constant 0 : index
  %2 = memref.dim %1, %c0 : memref<?xindex>
  return %2 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
// CHECK-NEXT: %[[IDX:.*]] = constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: memref.store
// CHECK-NOT: memref.dim
// CHECK: return %[[DIM]] : index
func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
    -> index {
  %c3 = constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // Update the shape to test that the load ends up in the right place.
  memref.store %c3, %arg1[%c3] : memref<?xindex>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
// CHECK-NEXT: %[[IDX:.*]] = constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: %[[CAST:.*]] = index_cast %[[DIM]]
// CHECK-NOT: memref.dim
// CHECK: return %[[CAST]] : index
func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
    -> index {
  %c3 = constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor.cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-NEXT: return %[[C4]], %[[T0]]
func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
  return %1, %2: index, index
}

// -----

// CHECK-LABEL: func @tensor_cast_to_memref
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
// CHECK: %[[M:.+]] = memref.buffer_cast %[[ARG0]] : memref<4x6x16x32xi8>
// CHECK: %[[M1:.+]] = memref.cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
    memref<?x?x16x32xi8> {
  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
  %1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
  return %1 : memref<?x?x16x32xi8>
}

// -----

func @subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> tensor<?x?x?xf32>
{
@@ -345,29 +195,6 @@ func @rank_reducing_subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor<?x

// -----

func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
    %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %c8 = constant 8 : index
  %0 = memref.dim %arg0, %c1 : tensor<2x?xi32>
  %1 = tensor.extract %arg1[] : tensor<i32>
  %2 = tensor.generate %arg2, %c8 {
  ^bb0(%arg4: index, %arg5: index):
    tensor.yield %1 : i32
  } : tensor<?x?xi32>
  %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
  return %3 : tensor<?x?xi32>
}
// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
// CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
// CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
// CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
// CHECK: return %[[CAST]]

// -----

func @subtensor_insert_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index