[mlir] Move std.tensor_cast -> tensor.cast.

This is almost entirely mechanical.

Differential Revision: https://reviews.llvm.org/D93357
Sean Silva 2020-12-15 16:47:19 -08:00
parent a555ca8b3d
commit 129d6e554e
39 changed files with 500 additions and 471 deletions
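Since the change is mechanical, the typical downstream update is a rename plus, at some call sites, an argument swap: the op class moves from `TensorCastOp` in `std` to `tensor::CastOp`, and the ODS-generated builder used at the updated call sites takes the result type before the source value. A minimal sketch of such an update, using a hypothetical rewrite pattern that is not part of this commit:

```cpp
// Illustrative only: rename TensorCastOp -> tensor::CastOp and, where a
// builder is involved, pass the result type before the source value.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct EraseRankExample : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getType().isa<UnrankedTensorType>())
      return failure();
    // Before: rewriter.replaceOpWithNewOp<TensorCastOp>(op, op.source(), destTy);
    // After:  result type first, then the source value.
    auto destTy = UnrankedTensorType::get(op.getType().getElementType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, destTy, op.source());
    return success();
  }
};
```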

View File

@ -34,7 +34,7 @@ namespace linalg {
class LinalgDependenceGraph;
/// A struct containing the Linalg producer before and after fusion.
/// When operating on tensors, `fusedProducer` may feed into a `tensor_cast` op
/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
/// before the consumer Linalg op, until enough canonicalizations have applied.
struct FusionInfo {
LinalgOp originalProducer;

View File

@ -354,31 +354,6 @@ computeRankReductionMask(ArrayRef<int64_t> originalShape,
/// ```
bool canFoldIntoConsumerOp(MemRefCastOp castOp);
/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
/// Determines whether TensorCastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor_cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor_cast operations. Such foldable tensor_cast
/// operations are typically inserted as `subtensor` ops and are canonicalized,
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the tensor type has more static information than the result
///
/// Example:
/// ```mlir
/// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool canFoldIntoConsumerOp(TensorCastOp castOp);
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,

View File

@ -62,7 +62,7 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
let printer = [{
return printStandardCastOp(this->getOperation(), p);
}];
let verifier = [{ return ::verifyCastOp(*this); }];
let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }];
let hasFolder = 1;
}
@ -3428,56 +3428,6 @@ def TanhOp : FloatUnaryOp<"tanh"> {
}];
}
//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//
def TensorCastOp : CastOp<"tensor_cast"> {
let summary = "tensor cast operation";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `std.tensor_cast` ssa-use `:` type `to` type
```
Convert a tensor from one type to an equivalent type without changing any
data elements. The source and destination types must both be tensor types
with the same element type. If both are ranked, then the rank should be the
same and static dimensions should match. The operation is invalid if
converting to a mismatching constant dimension.
Example:
```mlir
// Convert from unknown rank to rank 2 with unknown dimension sizes.
%2 = "std.tensor_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
%2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>
// Convert to a type with more known dimensions.
%3 = "std.tensor_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>
// Discard static dimension and rank information.
%4 = "std.tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
%5 = "std.tensor_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
```
}];
let arguments = (ins AnyTensor:$source);
let results = (outs AnyTensor);
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
/// The result of a tensor_cast is always a tensor.
TensorType getType() { return getResult().getType().cast<TensorType>(); }
}];
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

View File

@ -28,4 +28,38 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
//===----------------------------------------------------------------------===//
// Tensor Dialect Helpers
//===----------------------------------------------------------------------===//
namespace mlir {
namespace tensor {
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `subtensor` ops and are canonicalized,
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the tensor type has more static information than the result
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool canFoldIntoConsumerOp(CastOp castOp);
} // namespace tensor
} // namespace mlir
#endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_
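The helper keeps the same semantics but now lives in the `tensor` namespace. A short sketch of how a consuming op's canonicalization can use it; the free function name here is hypothetical, and the Linalg `FoldTensorCastOp` pattern updated later in this commit does the same per operand:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// If `operand` is produced by a tensor.cast that only erases static shape
// information, return the cast's source so the consumer can use it directly.
static Value foldOperandThroughCast(Value operand) {
  if (auto castOp = operand.getDefiningOp<tensor::CastOp>())
    if (tensor::canFoldIntoConsumerOp(castOp))
      return castOp.source();
  return operand;
}
```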

View File

@ -19,6 +19,52 @@ class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
let parser = [{ return ::parse$cppClass(parser, result); }];
}
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
let summary = "tensor cast operation";
let description = [{
Convert a tensor from one type to an equivalent type without changing any
data elements. The source and destination types must both be tensor types
with the same element type. If both are ranked, then the rank should be the
same and static dimensions should match. The operation is invalid if
converting to a mismatching constant dimension.
Example:
```mlir
// Convert from unknown rank to rank 2 with unknown dimension sizes.
%2 = tensor.cast %1 : tensor<*xf32> to tensor<?x?xf32>
// Convert to a type with more known dimensions.
%3 = tensor.cast %2 : tensor<?x?xf32> to tensor<4x?xf32>
// Discard static dimension and rank information.
%4 = tensor.cast %3 : tensor<4x?xf32> to tensor<?x?xf32>
%5 = tensor.cast %4 : tensor<?x?xf32> to tensor<*xf32>
```
}];
let arguments = (ins AnyTensor:$source);
let results = (outs AnyTensor:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
/// The result of a tensor.cast is always a tensor.
TensorType getType() { return getResult().getType().cast<TensorType>(); }
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

View File

@ -1775,11 +1775,18 @@ void printOneResultOp(Operation *op, OpAsmPrinter &p);
// These functions are out-of-line implementations of the methods in CastOp,
// which avoids them being template instantiated/duplicated.
namespace impl {
// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
// need for them, but some older ODS code in `std` still depends on them).
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
Type destType);
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
void printCastOp(Operation *op, OpAsmPrinter &p);
// TODO: Create a CastOpInterface with a method areCastCompatible.
// Also, consider adding functionality to CastOpInterface to be able to perform
// the ChainedTensorCast canonicalization generically.
Value foldCastOp(Operation *op);
LogicalResult verifyCastOp(Operation *op,
function_ref<bool(Type, Type)> areCastCompatible);
} // namespace impl
} // end namespace mlir
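The cast verifier moves from a file-local template in the std dialect to this shared out-of-line helper, which takes the op's `areCastCompatible` hook as a `function_ref`. A sketch of what the `let verifier = "return impl::verifyCastOp(*this, areCastCompatible);"` strings in the .td files boil down to for a hypothetical cast-like op (the include path and the placeholder predicate are assumptions, not from this commit):

```cpp
#include "mlir/IR/OpDefinition.h"

using namespace mlir;

static bool exampleAreCastCompatible(Type a, Type b) {
  return a == b; // placeholder: real ops implement dialect-specific rules
}

// Roughly what `return impl::verifyCastOp(*this, areCastCompatible);` expands
// to inside the generated verify() of such an op.
static LogicalResult verifyExampleCastOp(Operation *op) {
  return impl::verifyCastOp(op, exampleAreCastCompatible);
}
```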

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -tensor-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
@ -8,7 +8,7 @@ func @main() {
%b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
%addf = addf %a, %b : tensor<3xf32>
%addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
%addf_unranked = tensor.cast %addf : tensor<3xf32> to tensor<*xf32>
call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
// CHECK-NEXT: [11, 22, 33]

View File

@ -1,4 +1,6 @@
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
// RUN: -finalizing-bufferize \
// RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@ -15,14 +17,14 @@ func @main() {
%inserted_at_position_0 = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
%inserted_at_position_1 = subtensor_insert %insert_val into %const[1][1][1] : tensor<1xf32> into tensor<2xf32>
%unranked_at_position_0 = tensor_cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32>
%unranked_at_position_0 = tensor.cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32>
call @print_memref_f32(%unranked_at_position_0) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
// CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [1] data =
// CHECK-NEXT: [20, 10]
%unranked_at_position_1 = tensor_cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32>
%unranked_at_position_1 = tensor.cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32>
call @print_memref_f32(%unranked_at_position_1) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}

View File

@ -1,4 +1,6 @@
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
// RUN: -finalizing-bufferize \
// RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@ -9,7 +11,7 @@ func @main() {
%insert_val = constant dense<20.0> : tensor<1xf32>
%inserted = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
%unranked = tensor_cast %inserted : tensor<2xf32> to tensor<*xf32>
%unranked = tensor.cast %inserted : tensor<2xf32> to tensor<*xf32>
call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}

View File

@ -1,5 +1,5 @@
// RUN: mlir-opt %s -tensor-constant-bufferize -std-bufferize -linalg-bufferize \
// RUN: -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
// RUN: -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@ -19,7 +19,7 @@ func @main() {
// Note that this is skipping a step and we would need at least some function
// attribute to declare that this conversion is valid (e.g. when we statically
// know that things will play nicely at the C ABI boundary).
%unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
%unranked = tensor.cast %0 : tensor<4xf32> to tensor<*xf32>
call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}

View File

@ -1,12 +1,13 @@
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize \
// RUN: -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
// RUN: -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,2,3" -linalg-bufferize \
// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -tensor-bufferize \
// RUN: -func-bufferize \
// RUN: -finalizing-bufferize -convert-linalg-to-loops -convert-scf-to-std \
// RUN: -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
@ -23,7 +24,7 @@ func @main() {
%D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
%unranked = tensor_cast %D : tensor<2x4xf32> to tensor<*xf32>
%unranked = tensor.cast %D : tensor<2x4xf32> to tensor<*xf32>
call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}

View File

@ -103,9 +103,9 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
auto erasedRankType =
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
Value rankErasedLhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
Value rankErasedRhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
Value lesserRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
Value greaterRankOperand =
@ -186,7 +186,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
Value tensor =
rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
return success();
}
@ -246,9 +246,9 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
auto erasedRankType =
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
Value rankErasedLhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
Value rankErasedRhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
Value lesserRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
Value greaterRankOperand =
@ -528,8 +528,8 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
// Materialize extent tensor.
Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
loc, rewriter.getIndexType(), extentValues);
rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
staticExtentTensor);
return success();
}
@ -561,8 +561,8 @@ public:
if (!adaptor.input().getType().isa<RankedTensorType>())
return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
adaptor.input());
return success();
}
};

View File

@ -16,4 +16,5 @@ add_mlir_dialect_library(MLIRLinalg
MLIRSideEffectInterfaces
MLIRViewLikeInterface
MLIRStandard
MLIRTensor
)

View File

@ -651,7 +651,7 @@ struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
auto newOp =
rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
rewriter.getI64ArrayAttr(staticSizes));
rewriter.replaceOpWithNewOp<TensorCastOp>(op, op.getType(), newOp);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
return success();
}
};
@ -1815,12 +1815,12 @@ struct FoldTensorCastOp : public RewritePattern {
if (!linalgOp)
return failure();
// If no operand comes from a TensorCastOp and can be folded then fail.
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
if (v.isa<BlockArgument>())
return false;
auto castOp = v.getDefiningOp<TensorCastOp>();
auto castOp = v.getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});
if (!hasTensorCastOperand)
@ -1832,7 +1832,7 @@ struct FoldTensorCastOp : public RewritePattern {
newOperands.reserve(op->getNumOperands());
// Inputs may fold.
for (Value v : linalgOp.getInputs()) {
auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
newOperands.push_back(
canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
}
@ -1841,7 +1841,7 @@ struct FoldTensorCastOp : public RewritePattern {
linalgOp.getOutputBuffers().end());
// Init tensors may fold, in which case the resultType must also change.
for (Value v : linalgOp.getInitTensors()) {
auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
newResultTypes.push_back(newOperands.back().getType());

View File

@ -59,6 +59,7 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
void mlir::linalg::LinalgDialect::initialize() {
getContext()->getOrLoadDialect("std");
getContext()->getOrLoadDialect("tensor");
addTypes<RangeType>();
addOperations<

View File

@ -36,6 +36,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRStandard
MLIRStandardOpsTransforms
MLIRStandardToLLVM
MLIRTensor
MLIRTransforms
MLIRTransformUtils
MLIRVector

View File

@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
@ -517,13 +518,13 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
// Replace use.
// Canonicalizations are not guaranteed to have happened before constructing
// `fusedProducer`. In the tensor case this can result in temporary type
// mismatches. Insert a `tensor_cast` op to propagate the transformation
// mismatches. Insert a `tensor.cast` op to propagate the transformation
// invariant that types are compatible.
Value def = fusedProducer->getResult(producerIdx);
OpOperand &use = consumer->getOpOperand(consumerIdx);
Type consumerType = use.get().getType();
if (consumerType != def.getType())
def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
use.set(def);
return FusionInfo{producerOp, fusedProducer};
}

View File

@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
@ -569,7 +570,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
SubViewOp::getCanonicalizationPatterns(patterns, ctx);
TensorCastOp::getCanonicalizationPatterns(patterns, ctx);
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
ViewOp::getCanonicalizationPatterns(patterns, ctx);
CanonicalizationPatternList<
#define GET_OP_LIST

View File

@ -18,4 +18,5 @@ add_mlir_dialect_library(MLIRShape
MLIRIR
MLIRSideEffectInterfaces
MLIRStandard
MLIRTensor
)

View File

@ -9,6 +9,7 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

View File

@ -1,5 +1,6 @@
include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
def AllInputShapesEq : Constraint<CPred< [{
llvm::all_of($0, [&](mlir::Value val) {
@ -32,7 +33,7 @@ def SizeToIndexToSizeCanonicalization : Pat<
(Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
(replaceWithValue $arg)>;
// Fold tensor_cast(const_shape) to const_shape. This changes the type of
// Fold tensor.cast(const_shape) to const_shape. This changes the type of
// const_shape to the destination type of the cast.
def TensorCastConstShape : Pat <
(TensorCastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
(Tensor_CastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;

View File

@ -141,18 +141,6 @@ static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
<< op->getResult(0).getType();
}
/// A custom cast operation verifier.
template <typename T>
static LogicalResult verifyCastOp(T op) {
auto opType = op.getOperand().getType();
auto resType = op.getType();
if (!T::areCastCompatible(opType, resType))
return op.emitError("operand type ") << opType << " and result type "
<< resType << " are cast incompatible";
return success();
}
void StandardOpsDialect::initialize() {
getContext()->loadDialect<tensor::TensorDialect>();
addOperations<DmaStartOp, DmaWaitOp,
@ -1494,7 +1482,7 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<DimOfMemRefReshape, DimOfCastOp<TensorCastOp>>(context);
results.insert<DimOfMemRefReshape, DimOfCastOp<tensor::CastOp>>(context);
}
// ---------------------------------------------------------------------------
@ -1870,8 +1858,8 @@ struct StaticDynamicTensorFromElements
newOperands);
rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
newOp.body().begin());
rewriter.replaceOpWithNewOp<TensorCastOp>(tensorFromElements, resultType,
newOp);
rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
newOp);
return success();
}
};
@ -1913,7 +1901,7 @@ struct ExtractFromDynamicTensorFromElements
/// Canonicalizes the pattern of the form
///
/// %val = tensor_cast %source : tensor<?xi32> to tensor<2xi32>
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
@ -1924,7 +1912,7 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
auto tensorCast = extract.tensor().getDefiningOp<TensorCastOp>();
auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
if (!tensorCast)
return failure();
@ -3395,7 +3383,7 @@ static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op,
static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
SubTensorOp newOp) {
rewriter.replaceOpWithNewOp<TensorCastOp>(op, newOp, op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
}
/// Pattern to rewrite a subview op with constant arguments.
@ -3536,60 +3524,6 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
return true;
}
/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
/// Determines whether TensorCastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor_cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor_cast operations. Such foldable tensor_cast
/// operations are typically inserted as `subtensor` ops and are canonicalized,
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the tensor type has more static information than the result
///
/// Example:
/// ```mlir
/// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) {
if (!castOp)
return false;
RankedTensorType sourceType =
castOp.source().getType().dyn_cast<RankedTensorType>();
RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
// Requires RankedTensorType.
if (!sourceType || !resultType)
return false;
// Requires same elemental type.
if (sourceType.getElementType() != resultType.getElementType())
return false;
// Requires same rank.
if (sourceType.getRank() != resultType.getRank())
return false;
// If cast is towards more static sizes along any dimension, don't fold.
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st)
if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
return false;
}
return true;
}
namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref_cast past its consuming subview when
@ -3857,107 +3791,6 @@ static LogicalResult verify(SubTensorInsertOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//
bool TensorCastOp::areCastCompatible(Type a, Type b) {
auto aT = a.dyn_cast<TensorType>();
auto bT = b.dyn_cast<TensorType>();
if (!aT || !bT)
return false;
if (aT.getElementType() != bT.getElementType())
return false;
return succeeded(verifyCompatibleShape(aT, bT));
}
OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}
/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {
assert(one.getElementType() == two.getElementType());
if (!one.hasRank())
return two;
if (!two.hasRank())
return one;
int64_t rank = one.getRank();
if (rank != two.getRank())
return {};
SmallVector<int64_t, 4> join;
join.reserve(rank);
for (int64_t i = 0; i < rank; ++i) {
if (one.isDynamicDim(i)) {
join.push_back(two.getDimSize(i));
continue;
}
if (two.isDynamicDim(i)) {
join.push_back(one.getDimSize(i));
continue;
}
if (one.getDimSize(i) != two.getDimSize(i))
return {};
join.push_back(one.getDimSize(i));
}
return RankedTensorType::get(join, one.getElementType());
}
namespace {
/// Replaces chains of two tensor_cast operations by a single tensor_cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<TensorCastOp> {
using OpRewritePattern<TensorCastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorCastOp tensorCast,
PatternRewriter &rewriter) const final {
auto tensorCastOperand =
tensorCast.getOperand().getDefiningOp<TensorCastOp>();
if (!tensorCastOperand)
return failure();
auto sourceType =
tensorCastOperand.getOperand().getType().cast<TensorType>();
auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
auto resultType = tensorCast.getType().cast<TensorType>();
// We can remove the intermediate cast if joining all three produces the
// same result as just joining the source and result shapes.
auto firstJoin =
joinShapes(joinShapes(sourceType, intermediateType), resultType);
// The join might not exist if the cast sequence would fail at runtime.
if (!firstJoin)
return failure();
// The newJoin always exists if the above join exists, it might just contain
// less information. If so, we cannot drop the intermediate cast, as doing
// so would remove runtime checks.
auto newJoin = joinShapes(sourceType, resultType);
if (firstJoin != newJoin)
return failure();
rewriter.replaceOpWithNewOp<TensorCastOp>(tensorCast, resultType,
tensorCastOperand.getOperand());
return success();
}
};
} // namespace
void TensorCastOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ChainedTensorCast>(context);
}
//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

View File

@ -117,20 +117,6 @@ public:
};
} // namespace
namespace {
class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
return success();
}
};
} // namespace
namespace {
class BufferizeTensorFromElementsOp
: public OpConversionPattern<TensorFromElementsOp> {
@ -162,7 +148,6 @@ void mlir::populateStdBufferizePatterns(MLIRContext *context,
BufferizeDimOp,
BufferizeDynamicTensorFromElementsOp,
BufferizeSelectOp,
BufferizeTensorCastOp,
BufferizeTensorFromElementsOp
// clang-format on
>(typeConverter, context);
@ -180,8 +165,7 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
target.addLegalDialect<scf::SCFDialect>();
populateStdBufferizePatterns(context, typeConverter, patterns);
target.addIllegalOp<DynamicTensorFromElementsOp, TensorCastOp,
TensorFromElementsOp>();
target.addIllegalOp<DynamicTensorFromElementsOp, TensorFromElementsOp>();
// We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to
// touch the data).

View File

@ -8,12 +8,165 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::tensor;
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `subtensor` ops and are canonicalized,
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the tensor type has more static information than the result
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
if (!castOp)
return false;
RankedTensorType sourceType =
castOp.source().getType().dyn_cast<RankedTensorType>();
RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
// Requires RankedTensorType.
if (!sourceType || !resultType)
return false;
// Requires same elemental type.
if (sourceType.getElementType() != resultType.getElementType())
return false;
// Requires same rank.
if (sourceType.getRank() != resultType.getRank())
return false;
// If cast is towards more static sizes along any dimension, don't fold.
for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
if (ShapedType::isDynamic(std::get<0>(t)) &&
!ShapedType::isDynamic(std::get<1>(t)))
return false;
}
return true;
}
bool CastOp::areCastCompatible(Type a, Type b) {
auto aT = a.dyn_cast<TensorType>();
auto bT = b.dyn_cast<TensorType>();
if (!aT || !bT)
return false;
if (aT.getElementType() != bT.getElementType())
return false;
return succeeded(verifyCompatibleShape(aT, bT));
}
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}
/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {
assert(one.getElementType() == two.getElementType());
if (!one.hasRank())
return two;
if (!two.hasRank())
return one;
int64_t rank = one.getRank();
if (rank != two.getRank())
return {};
SmallVector<int64_t, 4> join;
join.reserve(rank);
for (int64_t i = 0; i < rank; ++i) {
if (one.isDynamicDim(i)) {
join.push_back(two.getDimSize(i));
continue;
}
if (two.isDynamicDim(i)) {
join.push_back(one.getDimSize(i));
continue;
}
if (one.getDimSize(i) != two.getDimSize(i))
return {};
join.push_back(one.getDimSize(i));
}
return RankedTensorType::get(join, one.getElementType());
}
namespace {
/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
using OpRewritePattern<CastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CastOp tensorCast,
PatternRewriter &rewriter) const final {
auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
if (!tensorCastOperand)
return failure();
auto sourceType =
tensorCastOperand.getOperand().getType().cast<TensorType>();
auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
auto resultType = tensorCast.getType().cast<TensorType>();
// We can remove the intermediate cast if joining all three produces the
// same result as just joining the source and result shapes.
auto firstJoin =
joinShapes(joinShapes(sourceType, intermediateType), resultType);
// The join might not exist if the cast sequence would fail at runtime.
if (!firstJoin)
return failure();
// The newJoin always exists if the above join exists, it might just contain
// less information. If so, we cannot drop the intermediate cast, as doing
// so would remove runtime checks.
auto newJoin = joinShapes(sourceType, resultType);
if (firstJoin != newJoin)
return failure();
rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
tensorCastOperand.getOperand());
return success();
}
};
} // namespace
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ChainedTensorCast>(context);
}
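The `ChainedTensorCast` pattern above hinges on the shape-join rule: the middle cast may be dropped only when joining the source and result shapes directly preserves everything the three-way join knew. A standalone model of that rule on ranked shapes, using `-1` for a dynamic dimension; this mirrors `joinShapes` above but is not the MLIR API:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Per dimension, a static size wins over a dynamic one; two different static
// sizes mean no join exists. Ranked shapes only; -1 stands in for
// ShapedType::kDynamicSize.
static std::optional<std::vector<int64_t>>
joinShapes(const std::vector<int64_t> &a, const std::vector<int64_t> &b) {
  if (a.size() != b.size())
    return std::nullopt; // rank mismatch: no join exists
  std::vector<int64_t> join;
  join.reserve(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] == -1) { join.push_back(b[i]); continue; } // take b's knowledge
    if (b[i] == -1) { join.push_back(a[i]); continue; } // take a's knowledge
    if (a[i] != b[i])
      return std::nullopt; // conflicting static sizes
    join.push_back(a[i]);
  }
  return join;
}

int main() {
  // tensor<?x8> joined with tensor<4x?> gives tensor<4x8>.
  assert((joinShapes({-1, 8}, {4, -1}) == std::vector<int64_t>{4, 8}));
  // tensor<4x8> and tensor<8x8> conflict: no join exists, so a chain through
  // such shapes keeps its intermediate cast.
  assert(!joinShapes({4, 8}, {8, 8}).has_value());
  return 0;
}
```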
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

View File

@ -19,6 +19,20 @@
using namespace mlir;
namespace {
class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
return success();
}
};
} // namespace
namespace {
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
public:
@ -37,7 +51,7 @@ public:
void mlir::populateTensorBufferizePatterns(
MLIRContext *context, BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
patterns.insert<BufferizeExtractOp>(typeConverter, context);
patterns.insert<BufferizeCastOp, BufferizeExtractOp>(typeConverter, context);
}
namespace {
@ -49,7 +63,7 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
ConversionTarget target(*context);
populateTensorBufferizePatterns(context, typeConverter, patterns);
target.addIllegalOp<tensor::ExtractOp>();
target.addIllegalOp<tensor::CastOp, tensor::ExtractOp>();
target.addLegalDialect<StandardOpsDialect>();
if (failed(

View File

@ -1213,6 +1213,19 @@ Value impl::foldCastOp(Operation *op) {
return nullptr;
}
LogicalResult
impl::verifyCastOp(Operation *op,
function_ref<bool(Type, Type)> areCastCompatible) {
auto opType = op->getOperand(0).getType();
auto resType = op->getResult(0).getType();
if (!areCastCompatible(opType, resType))
return op->emitError("operand type ")
<< opType << " and result type " << resType
<< " are cast incompatible";
return success();
}
//===----------------------------------------------------------------------===//
// Misc. utils
//===----------------------------------------------------------------------===//

View File

@ -95,7 +95,7 @@ func @const_shape() -> tensor<?xindex> {
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[C3:.*]] = constant 3 : index
// CHECK: %[[TENSOR3:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]]
// CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
return %shape : tensor<?xindex>
@ -108,7 +108,7 @@ func @const_shape() -> tensor<?xindex> {
// CHECK-SAME: () -> tensor<?xindex>
func @const_shape_zero_elements() -> tensor<?xindex> {
// CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex>
// CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
// CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%shape = shape.const_shape [] : tensor<?xindex>
return %shape : tensor<?xindex>
@ -152,13 +152,13 @@ func @const_size() -> index {
// -----
// Lower `to_extent_tensor` to `std.tensor_cast`
// Lower `to_extent_tensor` to `tensor.cast`
// Fold to_extent_tensor when already on tensor.
// CHECK-LABEL: @to_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xindex>
func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
// CHECK-NOT: to_extent_tensor
// CHECK: %[[RES:.*]] = tensor_cast %[[ARG]] : tensor<?xindex> to tensor<3xindex
// CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<?xindex> to tensor<3xindex
%casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<3xindex>
// CHECK: return %[[RES]]
return %casted : tensor<3xindex>
@ -316,8 +316,8 @@ func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
@ -356,8 +356,8 @@ func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xinde
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
@ -400,8 +400,8 @@ func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
@ -438,8 +438,8 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index

View File

@ -317,20 +317,20 @@ func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
// -----
// CHECK-LABEL: func @tensor_cast(
func @tensor_cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
// CHECK-LABEL: func @tensor.cast(
func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
-> tensor<3x?xf32>
{
%ta = tensor_cast %a : tensor<3x4xf32> to tensor<?x?xf32>
%tb = tensor_cast %b : tensor<4x?xf32> to tensor<?x?xf32>
%tc = tensor_cast %c : tensor<3x?xf32> to tensor<?x?xf32>
%ta = tensor.cast %a : tensor<3x4xf32> to tensor<?x?xf32>
%tb = tensor.cast %b : tensor<4x?xf32> to tensor<?x?xf32>
%tc = tensor.cast %c : tensor<3x?xf32> to tensor<?x?xf32>
// CHECK: linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
// CHECK-SAME: init({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
%0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
init(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = tensor_cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
%1 = tensor.cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
return %1: tensor<3x?xf32>
}
@ -360,7 +360,7 @@ func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
}
// CHECK: func @init_tensor_canonicalize
// CHECK: %[[T0:.+]] = linalg.init_tensor [4, 5, 6] : tensor<4x5x6xf32>
// CHECK: %[[T1:.+]] = tensor_cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// CHECK: return %[[T1]]
// -----

View File

@ -872,24 +872,24 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
// -----
// Verify that tensor_cast folding uses the correct type
// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned
func @fold_tensor_cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
// Verify that tensor.cast folding uses the correct type
// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned
func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
// CHECK: constant dense<2> : tensor<1xindex>
// CHECK-NOT: tensor_cast
// CHECK-NOT: tensor.cast
%0 = shape.const_shape [2] : tensor<?xindex>
%1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
%1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
return %1 : tensor<1xindex>
}
// -----
// Verify that tensor_cast folding uses the correct type
// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned_dynamic
func @fold_tensor_cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
// Verify that tensor.cast folding uses the correct type
// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned_dynamic
func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
// CHECK: shape.const_shape [2] : tensor<?xindex>
// CHECK-NOT: tensor_cast
// CHECK-NOT: tensor.cast
%0 = shape.const_shape [2] : tensor<1xindex>
%1 = tensor_cast %0 : tensor<1xindex> to tensor<?xindex>
%1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
return %1 : tensor<?xindex>
}

View File

@ -75,39 +75,6 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
return %0 : tensor<f32>
}
// CHECK-LABEL: func @tensor_cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED]]
// CHECK: return %[[RET]] : tensor<2xindex>
func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
%0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
return %0 : tensor<2xindex>
}
// CHECK-LABEL: func @tensor_cast_from_unranked(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32>
// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32>
// CHECK: return %[[RET]] : tensor<2xf32>
func @tensor_cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
%0 = tensor_cast %arg0 : tensor<*xf32> to tensor<2xf32>
return %0 : tensor<2xf32>
}
// CHECK-LABEL: func @tensor_cast_to_unranked(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32>
// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32>
// CHECK: return %[[RET]] : tensor<*xf32>
func @tensor_cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
%0 = tensor_cast %arg0 : tensor<2xf32> to tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @tensor_from_elements(
// CHECK-SAME: %[[ELEM0:.*]]: index,
// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {

View File

@ -116,17 +116,17 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
return %1 : index
}
// Test case: Folding dim(tensor_cast %0, %idx) -> dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor_cast
// Test case: Folding dim(tensor.cast %0, %idx) -> dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor.cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK-NEXT: return %[[C4]], %[[T0]]
func @fold_dim_of_tensor_cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = tensor_cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
%0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
%1 = dim %0, %c0 : tensor<?x?xf32>
%2 = dim %0, %c1 : tensor<?x?xf32>
return %1, %2: index, index

View File

@ -1,5 +1,38 @@
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
// CHECK-LABEL: func @tensor.cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED]]
// CHECK: return %[[RET]] : tensor<2xindex>
func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
%0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
return %0 : tensor<2xindex>
}
// CHECK-LABEL: func @tensor.cast_from_unranked(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32>
// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32>
// CHECK: return %[[RET]] : tensor<2xf32>
func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
%0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
return %0 : tensor<2xf32>
}
// CHECK-LABEL: func @tensor.cast_to_unranked(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32>
// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32>
// CHECK: return %[[RET]] : tensor<*xf32>
func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
%0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> f32 {

View File

@ -1,4 +1,66 @@
// RUN: mlir-opt %s -canonicalize | FileCheck %s
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
// Checks that NOP casts are removed.
// CHECK-LABEL: cast_values
func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {
// NOP cast
%0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
// CHECK-NEXT: %[[RET:.*]] = tensor.cast %arg0 : tensor<*xi32> to tensor<2xi32>
%2 = tensor.cast %0 : tensor<*xi32> to tensor<2xi32>
// NOP cast
%4 = tensor.cast %2 : tensor<2xi32> to tensor<2xi32>
// CHECK-NEXT: return %[[RET]] : tensor<2xi32>
return %4 : tensor<2xi32>
}
// -----
// CHECK-LABEL: @tensor.cast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
func @tensor.cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
// CHECK-NEXT: %[[RES:.*]] = tensor.cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
%0 = tensor.cast %input : tensor<*xi32> to tensor<4x?xi32>
%1 = tensor.cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
// CHECK-NEXT: return %[[RES]]
return %1 : tensor<4x8xi32>
}
// -----
// CHECK-LABEL: @tensor.cast_chain_regain
// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
func @tensor.cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
%0 = tensor.cast %input : tensor<4xi32> to tensor<?xi32>
%1 = tensor.cast %0 : tensor<?xi32> to tensor<4xi32>
// CHECK-NEXT: return %[[IN]]
return %1 : tensor<4xi32>
}
// -----
// CHECK-LABEL: @tensor.cast_chain_keep
// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
func @tensor.cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
// CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
%0 = tensor.cast %input : tensor<?x?xi32> to tensor<4x?xi32>
// CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
%1 = tensor.cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
// CHECK-NEXT: return %[[C2]]
return %1 : tensor<?x8xi32>
}
// -----
// CHECK-LABEL: @tensor.cast_chain_invalid
// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
// CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
%0 = tensor.cast %input : tensor<4x8xi32> to tensor<?x?xi32>
// CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
%1 = tensor.cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
// CHECK-NEXT: return %[[C2]]
return %1 : tensor<8x4xi32>
}
// -----
@ -31,3 +93,17 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]]
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
}
// -----
// CHECK-LABEL: func @extract_from_tensor.cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
%c0 = constant 0 : index
// CHECK-NOT: tensor.cast
%casted = tensor.cast %tensor : tensor<*xf32> to tensor<?xf32>
// CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
%result = tensor.extract %casted[%c0] : tensor<?xf32>
return %result : f32
}

View File

@ -1,4 +1,10 @@
// RUN: mlir-opt <%s -verify-diagnostics
// RUN: mlir-opt <%s -split-input-file -verify-diagnostics
func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
// expected-error@+1 {{operand type 'tensor<1xf32>' and result type 'tensor<2xf32>' are cast incompatible}}
%0 = tensor.cast %arg0 : tensor<1xf32> to tensor<2xf32>
return
}
// -----

View File

@ -1,5 +1,18 @@
// RUN: mlir-opt <%s | mlir-opt | FileCheck %s
// CHECK-LABEL: func @cast(
func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
// CHECK: tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
%0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
// CHECK: tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
%1 = tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
// CHECK: tensor.cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
%2 = tensor.cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
// CHECK: tensor.cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
%3 = tensor.cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
return
}
// CHECK-LABEL: func @extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?x?x?xf32>,
// CHECK-SAME: %[[INDEX:.*]]: index) {

View File

@ -696,23 +696,6 @@ func @tensor_from_elements() {
return
}
// CHECK-LABEL: func @tensor_cast(%arg0
func @tensor_cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
// CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
%0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
// CHECK: %1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
%1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
// CHECK: %2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
%2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
// CHECK: %3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
%3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
return
}
// CHECK-LABEL: func @memref_cast(%arg0
func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>) {
// CHECK: %0 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>

View File

@ -661,23 +661,15 @@ func @lowered_affine_ceildiv() -> (index, index) {
// Checks that NOP casts are removed.
// CHECK-LABEL: cast_values
func @cast_values(%arg0: tensor<*xi32>, %arg1: memref<?xi32>) -> (tensor<2xi32>, memref<2xi32>) {
// NOP casts
%0 = tensor_cast %arg0 : tensor<*xi32> to tensor<*xi32>
%1 = memref_cast %arg1 : memref<?xi32> to memref<?xi32>
// CHECK-NEXT: %0 = tensor_cast %arg0 : tensor<*xi32> to tensor<2xi32>
// CHECK-NEXT: %1 = memref_cast %arg1 : memref<?xi32> to memref<2xi32>
%2 = tensor_cast %0 : tensor<*xi32> to tensor<2xi32>
func @cast_values(%arg0: memref<?xi32>) -> memref<2xi32> {
// NOP cast
%1 = memref_cast %arg0 : memref<?xi32> to memref<?xi32>
// CHECK-NEXT: %[[RET:.*]] = memref_cast %arg0 : memref<?xi32> to memref<2xi32>
%3 = memref_cast %1 : memref<?xi32> to memref<2xi32>
// NOP casts
%4 = tensor_cast %2 : tensor<2xi32> to tensor<2xi32>
// NOP cast
%5 = memref_cast %3 : memref<2xi32> to memref<2xi32>
// CHECK-NEXT: return %0, %1 : tensor<2xi32>, memref<2xi32>
return %4, %5 : tensor<2xi32>, memref<2xi32>
// CHECK-NEXT: return %[[RET]] : memref<2xi32>
return %5 : memref<2xi32>
}
// -----
@ -1121,61 +1113,12 @@ func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tenso
yield %1 : index
// CHECK: : tensor<3x?x5x7x?xindex>
} : tensor<3x?x?x7x?xindex>
// CHECK: tensor_cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
// CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
return %0 : tensor<3x?x?x7x?xindex>
}
// -----
// CHECK-LABEL: @tensor_cast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
// CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
%0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32>
%1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
// CHECK-NEXT: return %[[RES]]
return %1 : tensor<4x8xi32>
}
// -----
// CHECK-LABEL: @tensor_cast_chain_regain
// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
%0 = tensor_cast %input : tensor<4xi32> to tensor<?xi32>
%1 = tensor_cast %0 : tensor<?xi32> to tensor<4xi32>
// CHECK-NEXT: return %[[IN]]
return %1 : tensor<4xi32>
}
// -----
// CHECK-LABEL: @tensor_cast_chain_keep
// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
func @tensor_cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
// CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
%0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32>
// CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
%1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
// CHECK-NEXT: return %[[C2]]
return %1 : tensor<?x8xi32>
}
// -----
// CHECK-LABEL: @tensor_cast_chain_invalid
// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
// CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
%0 = tensor_cast %input : tensor<4x8xi32> to tensor<?x?xi32>
// CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
%1 = tensor_cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
// CHECK-NEXT: return %[[C2]]
return %1 : tensor<8x4xi32>
}
// -----
// CHECK-LABEL: func @subtensor
// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
@ -1189,30 +1132,16 @@ func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
// CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] :
// CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32>
// CHECK: tensor_cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
// CHECK: tensor.cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
%1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
: tensor<8x16x4xf32> to tensor<?x?x?xf32>
// Test: subtensor with one dynamic operand can also be folded.
// CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] :
// CHECK-SAME: tensor<?x?x?xf32> to tensor<2x?x2xf32>
// CHECK: tensor_cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
// CHECK: tensor.cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
%2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1]
: tensor<?x?x?xf32> to tensor<?x?x?xf32>
return %2 : tensor<?x?x?xf32>
}
// -----
// CHECK-LABEL: func @extract_from_tensor_cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
func @extract_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 {
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
%c0 = constant 0 : index
// CHECK-NOT: tensor_cast
%casted = tensor_cast %tensor : tensor<*xf32> to tensor<?xf32>
// CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
%result = tensor.extract %casted[%c0] : tensor<?xf32>
return %result : f32
}

View File

@ -68,10 +68,10 @@ func @different_ops() -> (i32, i32) {
/// types.
// CHECK-LABEL: @different_results
func @different_results(%arg0: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<4x?xf32>) {
// CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
// CHECK-NEXT: %1 = tensor_cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
%0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
%1 = tensor_cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
// CHECK: %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
// CHECK-NEXT: %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
%0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
%1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
// CHECK-NEXT: return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>

View File

@ -40,7 +40,7 @@ syn keyword mlirOps alloc alloca addf addi and call call_indirect cmpf cmpi
syn keyword mlirOps constant dealloc divf dma_start dma_wait dim exp
syn keyword mlirOps getTensor index_cast load log memref_cast
syn keyword mlirOps memref_shape_cast mulf muli negf powf prefetch rsqrt sitofp
syn keyword mlirOps splat store select sqrt subf subi subview tanh tensor_cast
syn keyword mlirOps splat store select sqrt subf subi subview tanh
syn keyword mlirOps view
" Affine ops.