[mlir] Move std.tensor_cast -> tensor.cast.

This is almost entirely mechanical.

Differential Revision: https://reviews.llvm.org/D93357
parent a555ca8b3d
commit 129d6e554e
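The sketch below is an illustrative before/after summary of the rename, not part of the diff; the value and type names in it are made up for the example, and the op's semantics are unchanged by the move.

```mlir
// Before this commit: the cast is a standard-dialect op spelled tensor_cast.
%dynamic = tensor_cast %input : tensor<8x16xf32> to tensor<?x?xf32>

// After this commit: the same cast is a tensor-dialect op spelled tensor.cast.
%dynamic = tensor.cast %input : tensor<8x16xf32> to tensor<?x?xf32>
```

The C++ side follows the same pattern: `TensorCastOp` becomes `tensor::CastOp`, and `canFoldIntoConsumerOp(TensorCastOp)` moves into the `mlir::tensor` namespace as `canFoldIntoConsumerOp(CastOp)`.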
@@ -34,7 +34,7 @@ namespace linalg {
class LinalgDependenceGraph;

/// A struct containing the Linalg producer before and after fusion.
/// When operating on tensors, `fusedProducer` may feed into a `tensor_cast` op
/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
/// before the consumer Linalg op, until enough canonicalizations have applied.
struct FusionInfo {
LinalgOp originalProducer;
@@ -354,31 +354,6 @@ computeRankReductionMask(ArrayRef<int64_t> originalShape,
/// ```
bool canFoldIntoConsumerOp(MemRefCastOp castOp);

/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
/// Determines whether TensorCastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor_cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor_cast operations. Such foldable tensor_cast
/// operations are typically inserted as `subtensor` ops and are canonicalized,
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the tensor type has more static information than the result
///
/// Example:
/// ```mlir
/// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool canFoldIntoConsumerOp(TensorCastOp castOp);

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
@@ -62,7 +62,7 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
let printer = [{
return printStandardCastOp(this->getOperation(), p);
}];
let verifier = [{ return ::verifyCastOp(*this); }];
let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }];

let hasFolder = 1;
}
@@ -3428,56 +3428,6 @@ def TanhOp : FloatUnaryOp<"tanh"> {
}];
}

//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//

def TensorCastOp : CastOp<"tensor_cast"> {
let summary = "tensor cast operation";
let description = [{
Syntax:

```
operation ::= ssa-id `=` `std.tensor_cast` ssa-use `:` type `to` type
```

Convert a tensor from one type to an equivalent type without changing any
data elements. The source and destination types must both be tensor types
with the same element type. If both are ranked, then the rank should be the
same and static dimensions should match. The operation is invalid if
converting to a mismatching constant dimension.

Example:

```mlir
// Convert from unknown rank to rank 2 with unknown dimension sizes.
%2 = "std.tensor_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
%2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>

// Convert to a type with more known dimensions.
%3 = "std.tensor_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>

// Discard static dimension and rank information.
%4 = "std.tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
%5 = "std.tensor_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
```
}];

let arguments = (ins AnyTensor:$source);
let results = (outs AnyTensor);

let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);

/// The result of a tensor_cast is always a tensor.
TensorType getType() { return getResult().getType().cast<TensorType>(); }
}];

let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//
@@ -28,4 +28,38 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"

//===----------------------------------------------------------------------===//
// Tensor Dialect Helpers
//===----------------------------------------------------------------------===//

namespace mlir {
namespace tensor {

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `subtensor` ops and are canonicalized,
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the tensor type has more static information than the result
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool canFoldIntoConsumerOp(CastOp castOp);

} // namespace tensor
} // namespace mlir

#endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_
@@ -19,6 +19,52 @@ class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
let parser = [{ return ::parse$cppClass(parser, result); }];
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
let summary = "tensor cast operation";
let description = [{
Convert a tensor from one type to an equivalent type without changing any
data elements. The source and destination types must both be tensor types
with the same element type. If both are ranked, then the rank should be the
same and static dimensions should match. The operation is invalid if
converting to a mismatching constant dimension.

Example:

```mlir
// Convert from unknown rank to rank 2 with unknown dimension sizes.
%2 = tensor.cast %1 : tensor<*xf32> to tensor<?x?xf32>

// Convert to a type with more known dimensions.
%3 = tensor.cast %2 : tensor<?x?xf32> to tensor<4x?xf32>

// Discard static dimension and rank information.
%4 = tensor.cast %3 : tensor<4x?xf32> to tensor<?x?xf32>
%5 = tensor.cast %4 : tensor<?x?xf32> to tensor<*xf32>
```
}];

let arguments = (ins AnyTensor:$source);
let results = (outs AnyTensor:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";

let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);

/// The result of a tensor.cast is always a tensor.
TensorType getType() { return getResult().getType().cast<TensorType>(); }
}];

let hasFolder = 1;
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
@@ -1775,11 +1775,18 @@ void printOneResultOp(Operation *op, OpAsmPrinter &p);
// These functions are out-of-line implementations of the methods in CastOp,
// which avoids them being template instantiated/duplicated.
namespace impl {
// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
// need for them, but some older ODS code in `std` still depends on them).
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
Type destType);
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
void printCastOp(Operation *op, OpAsmPrinter &p);
// TODO: Create a CastOpInterface with a method areCastCompatible.
// Also, consider adding functionality to CastOpInterface to be able to perform
// the ChainedTensorCast canonicalization generically.
Value foldCastOp(Operation *op);
LogicalResult verifyCastOp(Operation *op,
function_ref<bool(Type, Type)> areCastCompatible);
} // namespace impl
} // end namespace mlir
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -tensor-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
@@ -8,7 +8,7 @@ func @main() {
%b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>

%addf = addf %a, %b : tensor<3xf32>
%addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
%addf_unranked = tensor.cast %addf : tensor<3xf32> to tensor<*xf32>
call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
// CHECK-NEXT: [11, 22, 33]
@@ -1,4 +1,6 @@
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
// RUN: -finalizing-bufferize \
// RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -15,14 +17,14 @@ func @main() {
%inserted_at_position_0 = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
%inserted_at_position_1 = subtensor_insert %insert_val into %const[1][1][1] : tensor<1xf32> into tensor<2xf32>

%unranked_at_position_0 = tensor_cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32>
%unranked_at_position_0 = tensor.cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32>
call @print_memref_f32(%unranked_at_position_0) : (tensor<*xf32>) -> ()

// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
// CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [1] data =
// CHECK-NEXT: [20, 10]

%unranked_at_position_1 = tensor_cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32>
%unranked_at_position_1 = tensor.cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32>
call @print_memref_f32(%unranked_at_position_1) : (tensor<*xf32>) -> ()

// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
@@ -1,4 +1,6 @@
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
// RUN: -finalizing-bufferize \
// RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -9,7 +11,7 @@ func @main() {
%insert_val = constant dense<20.0> : tensor<1xf32>
%inserted = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>

%unranked = tensor_cast %inserted : tensor<2xf32> to tensor<*xf32>
%unranked = tensor.cast %inserted : tensor<2xf32> to tensor<*xf32>
call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()

// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -tensor-constant-bufferize -std-bufferize -linalg-bufferize \
// RUN: -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
// RUN: -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -19,7 +19,7 @@ func @main() {
// Note that this is skipping a step and we would need at least some function
// attribute to declare that this conversion is valid (e.g. when we statically
// know that things will play nicely at the C ABI boundary).
%unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
%unranked = tensor.cast %0 : tensor<4xf32> to tensor<*xf32>
call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()

// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
@@ -1,12 +1,13 @@
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize \
// RUN: -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
// RUN: -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s

// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,2,3" -linalg-bufferize \
// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -tensor-bufferize \
// RUN: -func-bufferize \
// RUN: -finalizing-bufferize -convert-linalg-to-loops -convert-scf-to-std \
// RUN: -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
@@ -23,7 +24,7 @@ func @main() {
%D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>

%unranked = tensor_cast %D : tensor<2x4xf32> to tensor<*xf32>
%unranked = tensor.cast %D : tensor<2x4xf32> to tensor<*xf32>
call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()

// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
@@ -103,9 +103,9 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
auto erasedRankType =
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
Value rankErasedLhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
Value rankErasedRhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
Value lesserRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
Value greaterRankOperand =
@@ -186,7 +186,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
Value tensor =
rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
return success();
}
@@ -246,9 +246,9 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
auto erasedRankType =
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
Value rankErasedLhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
Value rankErasedRhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
Value lesserRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
Value greaterRankOperand =
@@ -528,8 +528,8 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
// Materialize extent tensor.
Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
loc, rewriter.getIndexType(), extentValues);
rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
staticExtentTensor);
return success();
}
@@ -561,8 +561,8 @@ public:
if (!adaptor.input().getType().isa<RankedTensorType>())
return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
adaptor.input());
return success();
}
};
@@ -16,4 +16,5 @@ add_mlir_dialect_library(MLIRLinalg
MLIRSideEffectInterfaces
MLIRViewLikeInterface
MLIRStandard
MLIRTensor
)
@@ -651,7 +651,7 @@ struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
auto newOp =
rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
rewriter.getI64ArrayAttr(staticSizes));
rewriter.replaceOpWithNewOp<TensorCastOp>(op, op.getType(), newOp);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
return success();
}
};
@@ -1815,12 +1815,12 @@ struct FoldTensorCastOp : public RewritePattern {
if (!linalgOp)
return failure();

// If no operand comes from a TensorCastOp and can be folded then fail.
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
if (v.isa<BlockArgument>())
return false;
auto castOp = v.getDefiningOp<TensorCastOp>();
auto castOp = v.getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});
if (!hasTensorCastOperand)
@@ -1832,7 +1832,7 @@ struct FoldTensorCastOp : public RewritePattern {
newOperands.reserve(op->getNumOperands());
// Inputs may fold.
for (Value v : linalgOp.getInputs()) {
auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
newOperands.push_back(
canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
}
@@ -1841,7 +1841,7 @@ struct FoldTensorCastOp : public RewritePattern {
linalgOp.getOutputBuffers().end());
// Init tensors may fold, in which case the resultType must also change.
for (Value v : linalgOp.getInitTensors()) {
auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
newResultTypes.push_back(newOperands.back().getType());
@@ -59,6 +59,7 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {

void mlir::linalg::LinalgDialect::initialize() {
getContext()->getOrLoadDialect("std");
getContext()->getOrLoadDialect("tensor");

addTypes<RangeType>();
addOperations<
@@ -36,6 +36,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRStandard
MLIRStandardOpsTransforms
MLIRStandardToLLVM
MLIRTensor
MLIRTransforms
MLIRTransformUtils
MLIRVector
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
@@ -517,13 +518,13 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
// Replace use.
// Canonicalizations are not guaranteed to have happened before constructing
// `fusedProducer`. In the tensor case this can result in temporary type
// mismatches. Insert a `tensor_cast` op to propagate the transformation
// mismatches. Insert a `tensor.cast` op to propagate the transformation
// invariant that types are compatible.
Value def = fusedProducer->getResult(producerIdx);
OpOperand &use = consumer->getOpOperand(consumerIdx);
Type consumerType = use.get().getType();
if (consumerType != def.getType())
def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
use.set(def);
return FusionInfo{producerOp, fusedProducer};
}
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
@@ -569,7 +570,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
SubViewOp::getCanonicalizationPatterns(patterns, ctx);
TensorCastOp::getCanonicalizationPatterns(patterns, ctx);
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
ViewOp::getCanonicalizationPatterns(patterns, ctx);
CanonicalizationPatternList<
#define GET_OP_LIST
@@ -18,4 +18,5 @@ add_mlir_dialect_library(MLIRShape
MLIRIR
MLIRSideEffectInterfaces
MLIRStandard
MLIRTensor
)
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -1,5 +1,6 @@
include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"

def AllInputShapesEq : Constraint<CPred< [{
llvm::all_of($0, [&](mlir::Value val) {
@@ -32,7 +33,7 @@ def SizeToIndexToSizeCanonicalization : Pat<
(Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
(replaceWithValue $arg)>;

// Fold tensor_cast(const_shape) to const_shape. This changes the type of
// Fold tensor.cast(const_shape) to const_shape. This changes the type of
// const_shape to the destination type of the cast.
def TensorCastConstShape : Pat <
(TensorCastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
(Tensor_CastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
@@ -141,18 +141,6 @@ static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
<< op->getResult(0).getType();
}

/// A custom cast operation verifier.
template <typename T>
static LogicalResult verifyCastOp(T op) {
auto opType = op.getOperand().getType();
auto resType = op.getType();
if (!T::areCastCompatible(opType, resType))
return op.emitError("operand type ") << opType << " and result type "
<< resType << " are cast incompatible";

return success();
}

void StandardOpsDialect::initialize() {
getContext()->loadDialect<tensor::TensorDialect>();
addOperations<DmaStartOp, DmaWaitOp,
@@ -1494,7 +1482,7 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {

void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<DimOfMemRefReshape, DimOfCastOp<TensorCastOp>>(context);
results.insert<DimOfMemRefReshape, DimOfCastOp<tensor::CastOp>>(context);
}

// ---------------------------------------------------------------------------
@@ -1870,8 +1858,8 @@ struct StaticDynamicTensorFromElements
newOperands);
rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
newOp.body().begin());
rewriter.replaceOpWithNewOp<TensorCastOp>(tensorFromElements, resultType,
newOp);
rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
newOp);
return success();
}
};
@@ -1913,7 +1901,7 @@ struct ExtractFromDynamicTensorFromElements

/// Canonicalizes the pattern of the form
///
/// %val = tensor_cast %source : : tensor<?xi32> to tensor<2xi32>
/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
@@ -1924,7 +1912,7 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {

LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
auto tensorCast = extract.tensor().getDefiningOp<TensorCastOp>();
auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
if (!tensorCast)
return failure();
@@ -3395,7 +3383,7 @@ static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op,

static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
SubTensorOp newOp) {
rewriter.replaceOpWithNewOp<TensorCastOp>(op, newOp, op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
}

/// Pattern to rewrite a subview op with constant arguments.
@@ -3536,60 +3524,6 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
return true;
}

/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
/// Determines whether TensorCastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor_cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor_cast operations. Such foldable tensor_cast
/// operations are typically inserted as `subtensor` ops and are canonicalized,
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the tensor type has more static information than the result
///
/// Example:
/// ```mlir
/// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) {
if (!castOp)
return false;

RankedTensorType sourceType =
castOp.source().getType().dyn_cast<RankedTensorType>();
RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

// Requires RankedTensorType.
if (!sourceType || !resultType)
return false;

// Requires same elemental type.
if (sourceType.getElementType() != resultType.getElementType())
return false;

// Requires same rank.
if (sourceType.getRank() != resultType.getRank())
return false;

// If cast is towards more static sizes along any dimension, don't fold.
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st)
if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
return false;
}

return true;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref_cast past its consuming subview when
@@ -3857,107 +3791,6 @@ static LogicalResult verify(SubTensorInsertOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//

bool TensorCastOp::areCastCompatible(Type a, Type b) {
auto aT = a.dyn_cast<TensorType>();
auto bT = b.dyn_cast<TensorType>();
if (!aT || !bT)
return false;

if (aT.getElementType() != bT.getElementType())
return false;

return succeeded(verifyCompatibleShape(aT, bT));
}

OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {
assert(one.getElementType() == two.getElementType());

if (!one.hasRank())
return two;
if (!two.hasRank())
return one;

int64_t rank = one.getRank();
if (rank != two.getRank())
return {};

SmallVector<int64_t, 4> join;
join.reserve(rank);
for (int64_t i = 0; i < rank; ++i) {
if (one.isDynamicDim(i)) {
join.push_back(two.getDimSize(i));
continue;
}
if (two.isDynamicDim(i)) {
join.push_back(one.getDimSize(i));
continue;
}
if (one.getDimSize(i) != two.getDimSize(i))
return {};
join.push_back(one.getDimSize(i));
}
return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor_cast operations by a single tensor_cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<TensorCastOp> {
using OpRewritePattern<TensorCastOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TensorCastOp tensorCast,
PatternRewriter &rewriter) const final {
auto tensorCastOperand =
tensorCast.getOperand().getDefiningOp<TensorCastOp>();

if (!tensorCastOperand)
return failure();

auto sourceType =
tensorCastOperand.getOperand().getType().cast<TensorType>();
auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
auto resultType = tensorCast.getType().cast<TensorType>();

// We can remove the intermediate cast if joining all three produces the
// same result as just joining the source and result shapes.
auto firstJoin =
joinShapes(joinShapes(sourceType, intermediateType), resultType);

// The join might not exist if the cast sequence would fail at runtime.
if (!firstJoin)
return failure();

// The newJoin always exists if the above join exists, it might just contain
// less information. If so, we cannot drop the intermediate cast, as doing
// so would remove runtime checks.
auto newJoin = joinShapes(sourceType, resultType);
if (firstJoin != newJoin)
return failure();

rewriter.replaceOpWithNewOp<TensorCastOp>(tensorCast, resultType,
tensorCastOperand.getOperand());
return success();
}
};

} // namespace

void TensorCastOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//
@@ -117,20 +117,6 @@ public:
};
} // namespace

namespace {
class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
return success();
}
};
} // namespace

namespace {
class BufferizeTensorFromElementsOp
: public OpConversionPattern<TensorFromElementsOp> {
@@ -162,7 +148,6 @@ void mlir::populateStdBufferizePatterns(MLIRContext *context,
BufferizeDimOp,
BufferizeDynamicTensorFromElementsOp,
BufferizeSelectOp,
BufferizeTensorCastOp,
BufferizeTensorFromElementsOp
// clang-format on
>(typeConverter, context);
@@ -180,8 +165,7 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
target.addLegalDialect<scf::SCFDialect>();

populateStdBufferizePatterns(context, typeConverter, patterns);
target.addIllegalOp<DynamicTensorFromElementsOp, TensorCastOp,
TensorFromElementsOp>();
target.addIllegalOp<DynamicTensorFromElementsOp, TensorFromElementsOp>();
// We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to
// touch the data).
@@ -8,12 +8,165 @@

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `subtensor` ops and are canonicalized,
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the tensor type has more static information than the result
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
if (!castOp)
return false;

RankedTensorType sourceType =
castOp.source().getType().dyn_cast<RankedTensorType>();
RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

// Requires RankedTensorType.
if (!sourceType || !resultType)
return false;

// Requires same elemental type.
if (sourceType.getElementType() != resultType.getElementType())
return false;

// Requires same rank.
if (sourceType.getRank() != resultType.getRank())
return false;

// If cast is towards more static sizes along any dimension, don't fold.
for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
if (ShapedType::isDynamic(std::get<0>(t)) &&
!ShapedType::isDynamic(std::get<1>(t)))
return false;
}

return true;
}

bool CastOp::areCastCompatible(Type a, Type b) {
auto aT = a.dyn_cast<TensorType>();
auto bT = b.dyn_cast<TensorType>();
if (!aT || !bT)
return false;

if (aT.getElementType() != bT.getElementType())
return false;

return succeeded(verifyCompatibleShape(aT, bT));
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {
assert(one.getElementType() == two.getElementType());

if (!one.hasRank())
return two;
if (!two.hasRank())
return one;

int64_t rank = one.getRank();
if (rank != two.getRank())
return {};

SmallVector<int64_t, 4> join;
join.reserve(rank);
for (int64_t i = 0; i < rank; ++i) {
if (one.isDynamicDim(i)) {
join.push_back(two.getDimSize(i));
continue;
}
if (two.isDynamicDim(i)) {
join.push_back(one.getDimSize(i));
continue;
}
if (one.getDimSize(i) != two.getDimSize(i))
return {};
join.push_back(one.getDimSize(i));
}
return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
using OpRewritePattern<CastOp>::OpRewritePattern;

LogicalResult matchAndRewrite(CastOp tensorCast,
PatternRewriter &rewriter) const final {
auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

if (!tensorCastOperand)
return failure();

auto sourceType =
tensorCastOperand.getOperand().getType().cast<TensorType>();
auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
auto resultType = tensorCast.getType().cast<TensorType>();

// We can remove the intermediate cast if joining all three produces the
// same result as just joining the source and result shapes.
auto firstJoin =
joinShapes(joinShapes(sourceType, intermediateType), resultType);

// The join might not exist if the cast sequence would fail at runtime.
if (!firstJoin)
return failure();

// The newJoin always exists if the above join exists, it might just contain
// less information. If so, we cannot drop the intermediate cast, as doing
// so would remove runtime checks.
auto newJoin = joinShapes(sourceType, resultType);
if (firstJoin != newJoin)
return failure();

rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
tensorCastOperand.getOperand());
return success();
}
};

} // namespace

void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
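Aside: the `ChainedTensorCast` pattern added above collapses a pair of casts when dropping the intermediate type loses no runtime constraint. The snippet below is a hypothetical illustration (value names invented for this sketch), not part of the diff:

```mlir
// The fully dynamic intermediate type adds no constraint of its own, so the
// two casts join into one; the remaining runtime check (trailing dim == 2)
// is preserved by the single cast.
%a = tensor.cast %src : tensor<2x?xf32> to tensor<?x?xf32>
%b = tensor.cast %a : tensor<?x?xf32> to tensor<?x2xf32>

// After canonicalization:
%b = tensor.cast %src : tensor<2x?xf32> to tensor<?x2xf32>
```

A chain whose endpoints disagree on a static size (for example `tensor<4xf32>` to `tensor<?xf32>` to `tensor<8xf32>`) is left untouched, since the shapes have no common join and removing the intermediate cast would drop a runtime check.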
@@ -19,6 +19,20 @@

using namespace mlir;

namespace {
class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
return success();
}
};
} // namespace

namespace {
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
public:
@@ -37,7 +51,7 @@ public:
void mlir::populateTensorBufferizePatterns(
MLIRContext *context, BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
patterns.insert<BufferizeExtractOp>(typeConverter, context);
patterns.insert<BufferizeCastOp, BufferizeExtractOp>(typeConverter, context);
}

namespace {
@@ -49,7 +63,7 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
ConversionTarget target(*context);

populateTensorBufferizePatterns(context, typeConverter, patterns);
target.addIllegalOp<tensor::ExtractOp>();
target.addIllegalOp<tensor::CastOp, tensor::ExtractOp>();
target.addLegalDialect<StandardOpsDialect>();

if (failed(
@@ -1213,6 +1213,19 @@ Value impl::foldCastOp(Operation *op) {
return nullptr;
}

LogicalResult
impl::verifyCastOp(Operation *op,
function_ref<bool(Type, Type)> areCastCompatible) {
auto opType = op->getOperand(0).getType();
auto resType = op->getResult(0).getType();
if (!areCastCompatible(opType, resType))
return op->emitError("operand type ")
<< opType << " and result type " << resType
<< " are cast incompatible";

return success();
}

//===----------------------------------------------------------------------===//
// Misc. utils
//===----------------------------------------------------------------------===//
@@ -95,7 +95,7 @@ func @const_shape() -> tensor<?xindex> {
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[C3:.*]] = constant 3 : index
// CHECK: %[[TENSOR3:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]]
// CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
return %shape : tensor<?xindex>
@@ -108,7 +108,7 @@ func @const_shape() -> tensor<?xindex> {
// CHECK-SAME: () -> tensor<?xindex>
func @const_shape_zero_elements() -> tensor<?xindex> {
// CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex>
// CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
// CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%shape = shape.const_shape [] : tensor<?xindex>
return %shape : tensor<?xindex>
@@ -152,13 +152,13 @@ func @const_size() -> index {

// -----

// Lower `to_extent_tensor` to `std.tensor_cast`
// Lower `to_extent_tensor` to `tensor.cast`
// Fold to_extent_tensor when already on tensor.
// CHECK-LABEL: @to_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xindex>
func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
// CHECK-NOT: to_extent_tensor
// CHECK: %[[RES:.*]] = tensor_cast %[[ARG]] : tensor<?xindex> to tensor<3xindex
// CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<?xindex> to tensor<3xindex
%casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<3xindex>
// CHECK: return %[[RES]]
return %casted : tensor<3xindex>
@@ -316,8 +316,8 @@ func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
@@ -356,8 +356,8 @@ func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xinde
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
@@ -400,8 +400,8 @@ func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
@@ -438,8 +438,8 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
@@ -317,20 +317,20 @@ func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>

// -----

// CHECK-LABEL: func @tensor_cast(
func @tensor_cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
// CHECK-LABEL: func @tensor.cast(
func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
-> tensor<3x?xf32>
{
%ta = tensor_cast %a : tensor<3x4xf32> to tensor<?x?xf32>
%tb = tensor_cast %b : tensor<4x?xf32> to tensor<?x?xf32>
%tc = tensor_cast %c : tensor<3x?xf32> to tensor<?x?xf32>
%ta = tensor.cast %a : tensor<3x4xf32> to tensor<?x?xf32>
%tb = tensor.cast %b : tensor<4x?xf32> to tensor<?x?xf32>
%tc = tensor.cast %c : tensor<3x?xf32> to tensor<?x?xf32>

// CHECK: linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
// CHECK-SAME: init({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
%0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
init(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>

%1 = tensor_cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
%1 = tensor.cast %0 : tensor<?x?xf32> to tensor<3x?xf32>

return %1: tensor<3x?xf32>
}
@@ -360,7 +360,7 @@ func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
}
// CHECK: func @init_tensor_canonicalize
// CHECK: %[[T0:.+]] = linalg.init_tensor [4, 5, 6] : tensor<4x5x6xf32>
// CHECK: %[[T1:.+]] = tensor_cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// CHECK: return %[[T1]]

// -----
@@ -872,24 +872,24 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {

// -----

// Verify that tensor_cast folding uses the correct type
// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned
func @fold_tensor_cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
// Verify that tensor.cast folding uses the correct type
// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned
func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
// CHECK: constant dense<2> : tensor<1xindex>
// CHECK-NOT: tensor_cast
// CHECK-NOT: tensor.cast
%0 = shape.const_shape [2] : tensor<?xindex>
%1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
%1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
return %1 : tensor<1xindex>
}

// -----

// Verify that tensor_cast folding uses the correct type
// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned_dynamic
func @fold_tensor_cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
// Verify that tensor.cast folding uses the correct type
// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned_dynamic
func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
// CHECK: shape.const_shape [2] : tensor<?xindex>
// CHECK-NOT: tensor_cast
// CHECK-NOT: tensor.cast
%0 = shape.const_shape [2] : tensor<1xindex>
%1 = tensor_cast %0 : tensor<1xindex> to tensor<?xindex>
%1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
return %1 : tensor<?xindex>
}
@ -75,39 +75,6 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor_cast(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
|
||||
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
|
||||
// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
|
||||
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED]]
|
||||
// CHECK: return %[[RET]] : tensor<2xindex>
|
||||
func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
|
||||
%0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
|
||||
return %0 : tensor<2xindex>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor_cast_from_unranked(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
|
||||
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32>
|
||||
// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
|
||||
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32>
|
||||
// CHECK: return %[[RET]] : tensor<2xf32>
|
||||
func @tensor_cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
|
||||
%0 = tensor_cast %arg0 : tensor<*xf32> to tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor_cast_to_unranked(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32>
|
||||
// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
|
||||
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32>
|
||||
// CHECK: return %[[RET]] : tensor<*xf32>
|
||||
func @tensor_cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
|
||||
%0 = tensor_cast %arg0 : tensor<2xf32> to tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor_from_elements(
|
||||
// CHECK-SAME: %[[ELEM0:.*]]: index,
|
||||
// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {
|
||||
|
@ -116,17 +116,17 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
return %1 : index
}

// Test case: Folding dim(tensor_cast %0, %idx) -> dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor_cast
// Test case: Folding dim(tensor.cast %0, %idx) -> dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor.cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK-NEXT: return %[[C4]], %[[T0]]
func @fold_dim_of_tensor_cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = tensor_cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
%0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
%1 = dim %0, %c0 : tensor<?x?xf32>
%2 = dim %0, %c1 : tensor<?x?xf32>
return %1, %2: index, index

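The fold exercised by this test bypasses the cast when querying dimensions: the statically known extent becomes a constant and the dynamic extent is taken from the original operand. A rough sketch of the IR the CHECK lines above expect after canonicalization (illustrative only, not part of this commit; SSA names are made up):

```mlir
func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  // dim of the statically sized dimension folds to the constant 4;
  // dim of the dynamic dimension is queried directly on %arg0.
  %0 = dim %arg0, %c1 : tensor<4x?xf32>
  return %c4, %0 : index, index
}
```
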
@ -1,5 +1,38 @@
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s

// CHECK-LABEL: func @tensor.cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED]]
// CHECK: return %[[RET]] : tensor<2xindex>
func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
%0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
return %0 : tensor<2xindex>
}

// CHECK-LABEL: func @tensor.cast_from_unranked(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32>
// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32>
// CHECK: return %[[RET]] : tensor<2xf32>
func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
%0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @tensor.cast_to_unranked(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32>
// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32>
// CHECK: return %[[RET]] : tensor<*xf32>
func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
%0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32>
return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> f32 {

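For reference, the bufferized form these CHECK lines describe for the ranked case looks roughly as follows (a sketch assembled from the patterns above, not part of the diff; SSA names are illustrative):

```mlir
func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
  // tensor.cast bufferizes to a cast on the underlying memref.
  %0 = tensor_to_memref %arg0 : memref<?xindex>
  %1 = memref_cast %0 : memref<?xindex> to memref<2xindex>
  %2 = tensor_load %1 : memref<2xindex>
  return %2 : tensor<2xindex>
}
```
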
@ -1,4 +1,66 @@
// RUN: mlir-opt %s -canonicalize | FileCheck %s
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s

// Checks that NOP casts are removed.
// CHECK-LABEL: cast_values
func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {
// NOP cast
%0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
// CHECK-NEXT: %[[RET:.*]] = tensor.cast %arg0 : tensor<*xi32> to tensor<2xi32>
%2 = tensor.cast %0 : tensor<*xi32> to tensor<2xi32>
// NOP cast
%4 = tensor.cast %2 : tensor<2xi32> to tensor<2xi32>
// CHECK-NEXT: return %[[RET]] : tensor<2xi32>
return %4 : tensor<2xi32>
}

// -----

// CHECK-LABEL: @tensor.cast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
func @tensor.cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
// CHECK-NEXT: %[[RES:.*]] = tensor.cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
%0 = tensor.cast %input : tensor<*xi32> to tensor<4x?xi32>
%1 = tensor.cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
// CHECK-NEXT: return %[[RES]]
return %1 : tensor<4x8xi32>
}

// -----

// CHECK-LABEL: @tensor.cast_chain_regain
// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
func @tensor.cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
%0 = tensor.cast %input : tensor<4xi32> to tensor<?xi32>
%1 = tensor.cast %0 : tensor<?xi32> to tensor<4xi32>
// CHECK-NEXT: return %[[IN]]
return %1 : tensor<4xi32>
}

// -----

// CHECK-LABEL: @tensor.cast_chain_keep
// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
func @tensor.cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
// CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
%0 = tensor.cast %input : tensor<?x?xi32> to tensor<4x?xi32>
// CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
%1 = tensor.cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
// CHECK-NEXT: return %[[C2]]
return %1 : tensor<?x8xi32>
}

// -----

// CHECK-LABEL: @tensor.cast_chain_invalid
// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
// CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
%0 = tensor.cast %input : tensor<4x8xi32> to tensor<?x?xi32>
// CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
%1 = tensor.cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
// CHECK-NEXT: return %[[C2]]
return %1 : tensor<8x4xi32>
}

// -----

@ -31,3 +93,17 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]]
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
}

// -----

// CHECK-LABEL: func @extract_from_tensor.cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
%c0 = constant 0 : index
// CHECK-NOT: tensor.cast
%casted = tensor.cast %tensor : tensor<*xf32> to tensor<?xf32>
// CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
%result = tensor.extract %casted[%c0] : tensor<?xf32>
return %result : f32
}

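Taken together, the canonicalizations above collapse NOP casts and joinable cast chains. A sketch of what @cast_values reduces to under -canonicalize, reconstructed from its CHECK-NEXT lines (illustrative only, not part of the diff):

```mlir
func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {
  // The two NOP casts fold away; only the shape-refining cast survives.
  %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<2xi32>
  return %0 : tensor<2xi32>
}
```
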
@ -1,4 +1,10 @@
// RUN: mlir-opt <%s -verify-diagnostics
// RUN: mlir-opt <%s -split-input-file -verify-diagnostics

func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
// expected-error@+1 {{operand type 'tensor<1xf32>' and result type 'tensor<2xf32>' are cast incompatible}}
%0 = tensor.cast %arg0 : tensor<1xf32> to tensor<2xf32>
return
}

// -----

@ -1,5 +1,18 @@
// RUN: mlir-opt <%s | mlir-opt | FileCheck %s

// CHECK-LABEL: func @cast(
func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
// CHECK: tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
%0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
// CHECK: tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
%1 = tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
// CHECK: tensor.cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
%2 = tensor.cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
// CHECK: tensor.cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
%3 = tensor.cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
return
}

// CHECK-LABEL: func @extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?x?x?xf32>,
// CHECK-SAME: %[[INDEX:.*]]: index) {

@ -696,23 +696,6 @@ func @tensor_from_elements() {
return
}

// CHECK-LABEL: func @tensor_cast(%arg0
func @tensor_cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
// CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
%0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>

// CHECK: %1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
%1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>

// CHECK: %2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
%2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>

// CHECK: %3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
%3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>

return
}

// CHECK-LABEL: func @memref_cast(%arg0
func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>) {
// CHECK: %0 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>

@ -661,23 +661,15 @@ func @lowered_affine_ceildiv() -> (index, index) {

// Checks that NOP casts are removed.
// CHECK-LABEL: cast_values
func @cast_values(%arg0: tensor<*xi32>, %arg1: memref<?xi32>) -> (tensor<2xi32>, memref<2xi32>) {

// NOP casts
%0 = tensor_cast %arg0 : tensor<*xi32> to tensor<*xi32>
%1 = memref_cast %arg1 : memref<?xi32> to memref<?xi32>

// CHECK-NEXT: %0 = tensor_cast %arg0 : tensor<*xi32> to tensor<2xi32>
// CHECK-NEXT: %1 = memref_cast %arg1 : memref<?xi32> to memref<2xi32>
%2 = tensor_cast %0 : tensor<*xi32> to tensor<2xi32>
func @cast_values(%arg0: memref<?xi32>) -> memref<2xi32> {
// NOP cast
%1 = memref_cast %arg0 : memref<?xi32> to memref<?xi32>
// CHECK-NEXT: %[[RET:.*]] = memref_cast %arg0 : memref<?xi32> to memref<2xi32>
%3 = memref_cast %1 : memref<?xi32> to memref<2xi32>

// NOP casts
%4 = tensor_cast %2 : tensor<2xi32> to tensor<2xi32>
// NOP cast
%5 = memref_cast %3 : memref<2xi32> to memref<2xi32>

// CHECK-NEXT: return %0, %1 : tensor<2xi32>, memref<2xi32>
return %4, %5 : tensor<2xi32>, memref<2xi32>
// CHECK-NEXT: return %[[RET]] : memref<2xi32>
return %5 : memref<2xi32>
}

// -----

@ -1121,61 +1113,12 @@ func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tenso
yield %1 : index
// CHECK: : tensor<3x?x5x7x?xindex>
} : tensor<3x?x?x7x?xindex>
// CHECK: tensor_cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
// CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
return %0 : tensor<3x?x?x7x?xindex>
}

// -----

// CHECK-LABEL: @tensor_cast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
// CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
%0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32>
%1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
// CHECK-NEXT: return %[[RES]]
return %1 : tensor<4x8xi32>
}

// -----

// CHECK-LABEL: @tensor_cast_chain_regain
// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
%0 = tensor_cast %input : tensor<4xi32> to tensor<?xi32>
%1 = tensor_cast %0 : tensor<?xi32> to tensor<4xi32>
// CHECK-NEXT: return %[[IN]]
return %1 : tensor<4xi32>
}

// -----

// CHECK-LABEL: @tensor_cast_chain_keep
// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
func @tensor_cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
// CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
%0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32>
// CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
%1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
// CHECK-NEXT: return %[[C2]]
return %1 : tensor<?x8xi32>
}

// -----

// CHECK-LABEL: @tensor_cast_chain_invalid
// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
// CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
%0 = tensor_cast %input : tensor<4x8xi32> to tensor<?x?xi32>
// CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
%1 = tensor_cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
// CHECK-NEXT: return %[[C2]]
return %1 : tensor<8x4xi32>
}

// -----

// CHECK-LABEL: func @subtensor
// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)

@ -1189,30 +1132,16 @@ func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)

// CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] :
// CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32>
// CHECK: tensor_cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
// CHECK: tensor.cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
%1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
: tensor<8x16x4xf32> to tensor<?x?x?xf32>

// Test: subtensor with one dynamic operand can also be folded.
// CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] :
// CHECK-SAME: tensor<?x?x?xf32> to tensor<2x?x2xf32>
// CHECK: tensor_cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
// CHECK: tensor.cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
%2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1]
: tensor<?x?x?xf32> to tensor<?x?x?xf32>

return %2 : tensor<?x?x?xf32>
}

// -----

// CHECK-LABEL: func @extract_from_tensor_cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
func @extract_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 {
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
%c0 = constant 0 : index
// CHECK-NOT: tensor_cast
%casted = tensor_cast %tensor : tensor<*xf32> to tensor<?xf32>
// CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
%result = tensor.extract %casted[%c0] : tensor<?xf32>
return %result : f32
}

@ -68,10 +68,10 @@ func @different_ops() -> (i32, i32) {
/// types.
// CHECK-LABEL: @different_results
func @different_results(%arg0: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<4x?xf32>) {
// CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
// CHECK-NEXT: %1 = tensor_cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
%0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
%1 = tensor_cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
// CHECK: %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
// CHECK-NEXT: %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
%0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
%1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32>

// CHECK-NEXT: return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>

@ -40,7 +40,7 @@ syn keyword mlirOps alloc alloca addf addi and call call_indirect cmpf cmpi
syn keyword mlirOps constant dealloc divf dma_start dma_wait dim exp
syn keyword mlirOps getTensor index_cast load log memref_cast
syn keyword mlirOps memref_shape_cast mulf muli negf powf prefetch rsqrt sitofp
syn keyword mlirOps splat store select sqrt subf subi subview tanh tensor_cast
syn keyword mlirOps splat store select sqrt subf subi subview tanh
syn keyword mlirOps view

" Affine ops.