[ODS] Use Adaptor Trait for Shaped Type Inference

Author inferReturnTypeComponents methods with the Op Adaptor by using the InferShapedTypeOpAdaptor trait.
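For example, an op that previously declared the interface method explicitly in ODS (the argmax change below) now opts in with a single trait:

    // Before: explicit interface-method declaration.
    def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
        DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                  ["inferReturnTypeComponents"]>,
        Pure]> {

    // After: the trait declares and wraps the method.
    def Tosa_ArgMaxOp : Tosa_Op<"argmax", [InferShapedTypeOpAdaptor, Pure]> {

The generated wrapper forwards the unpacked interface arguments (operands, attributes, properties, regions) into the op adaptor, so the C++ implementation is authored against the adaptor-based overload:

    LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
        MLIRContext *context, ::std::optional<Location> location,
        ArgMaxOp::Adaptor adaptor,
        SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes);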

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D155243
Amanda Tang 2023-07-13 22:54:30 +00:00
parent 04cc892eed
commit 057fc8e7d8
6 changed files with 240 additions and 215 deletions

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

@@ -32,10 +32,7 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
//===----------------------------------------------------------------------===//
// Operator: argmax
//===----------------------------------------------------------------------===//
def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_ArgMaxOp : Tosa_Op<"argmax", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Perform argmax on the input.";
let description = [{
@@ -62,10 +59,7 @@ def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
//===----------------------------------------------------------------------===//
// Operator: avg_pool2d
//===----------------------------------------------------------------------===//
def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Performs max pooling on the input.";
let description = [{
@@ -95,10 +89,7 @@ def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
//===----------------------------------------------------------------------===//
// Operator: conv2d
//===----------------------------------------------------------------------===//
def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_Conv2DOp : Tosa_Op<"conv2d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "2D Convolution Operator";
let description = [{
@@ -128,10 +119,7 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
//===----------------------------------------------------------------------===//
// Operator: conv3d
//===----------------------------------------------------------------------===//
def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_Conv3DOp : Tosa_Op<"conv3d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "3D Convolution operator";
let description = [{
@@ -160,10 +148,8 @@ def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
//===----------------------------------------------------------------------===//
// Operator: depthwise_conv2d
//===----------------------------------------------------------------------===//
def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d",
[InferShapedTypeOpAdaptor, Pure]> {
let summary = "Depthwise 2D Convolution operator";
let description = [{
@@ -193,10 +179,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
//===----------------------------------------------------------------------===//
// Operator: fft2d
//===----------------------------------------------------------------------===//
def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_FFT2dOp : Tosa_Op<"fft2d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Performs FFT2D operation on the input.";
let description = [{
@@ -224,9 +207,7 @@ def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
// Operator: fully_connected
//===----------------------------------------------------------------------===//
def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
InferShapedTypeOpAdaptor, Pure]> {
let summary = "Fully Connected operator";
let description = [{
@@ -251,10 +232,7 @@ def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
//===----------------------------------------------------------------------===//
// Operator: matmul
//===----------------------------------------------------------------------===//
def Tosa_MatMulOp : Tosa_Op<"matmul", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_MatMulOp : Tosa_Op<"matmul", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Matrix multiplication with bias";
let description = [{
@@ -279,10 +257,7 @@ def Tosa_MatMulOp : Tosa_Op<"matmul", [
//===----------------------------------------------------------------------===//
// Operator: max_pool2d
//===----------------------------------------------------------------------===//
def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Performs max pooling on the input.";
let description = [{
@@ -310,10 +285,7 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
//===----------------------------------------------------------------------===//
// Operator: rfft2d
//===----------------------------------------------------------------------===//
def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Performs RFFT2D operation on the input.";
let description = [{
@@ -338,10 +310,8 @@ def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
//===----------------------------------------------------------------------===//
// Operator: transpose_conv2d
//===----------------------------------------------------------------------===//
def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d",
[InferShapedTypeOpAdaptor, Pure]> {
let summary = "Transpose 2D Convolution operator.";
let description = [{
@@ -828,10 +798,7 @@ def Tosa_SubOp : Tosa_ElemWiseBinaryOp<"sub"> {
//===----------------------------------------------------------------------===//
// Operator: table
//===----------------------------------------------------------------------===//
def Tosa_TableOp : Tosa_Op<"table", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_TableOp : Tosa_Op<"table", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Table lookup op";
let description = [{
@@ -1214,7 +1181,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
// Operator: reduce_all
//===----------------------------------------------------------------------===//
def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
InferTensorType, Pure]> {
InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce All operator";
let description = [{
@@ -1243,7 +1210,7 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
// Operator: reduce_any
//===----------------------------------------------------------------------===//
def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
InferTensorType, Pure]> {
InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce Any operator";
let description = [{
@@ -1272,7 +1239,7 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
// Operator: reduce_max
//===----------------------------------------------------------------------===//
def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
InferTensorType, Pure]> {
InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce Max operator";
let description = [{
@@ -1301,7 +1268,7 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
// Operator: reduce_min
//===----------------------------------------------------------------------===//
def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
InferTensorType, Pure]> {
InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce Min operator";
let description = [{
@@ -1330,7 +1297,7 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
// Operator: reduce_prod
//===----------------------------------------------------------------------===//
def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
InferTensorType, Pure]> {
InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce Prod operator";
let description = [{
@@ -1359,7 +1326,7 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
// Operator: reduce_sum
//===----------------------------------------------------------------------===//
def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
InferTensorType, Pure]> {
InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce Sum operator";
let description = [{
@@ -1393,7 +1360,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
// Operator: concat
//===----------------------------------------------------------------------===//
def Tosa_ConcatOp : Tosa_Op<"concat", [
InferTensorType, Pure]> {
InferTensorTypeAdaptor, Pure]> {
let summary = "Concatenates tensors along one dimension.";
let description = [{
@@ -1423,10 +1390,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
//===----------------------------------------------------------------------===//
// Operator: pad
//===----------------------------------------------------------------------===//
def Tosa_PadOp : Tosa_Op<"pad", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_PadOp : Tosa_Op<"pad", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Pads a tensor with value specified.";
let description = [{
@@ -1471,7 +1435,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
// Operator: reshape
//===----------------------------------------------------------------------===//
def Tosa_ReshapeOp: Tosa_Op<"reshape", [
InferTensorType, Pure]> {
InferTensorTypeAdaptor, Pure]> {
let summary = "Reshape operator";
let description = [{
@@ -1528,9 +1492,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
//===----------------------------------------------------------------------===//
// Operator: slice
//===----------------------------------------------------------------------===//
def Tosa_SliceOp: Tosa_Op<"slice", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>, Pure]> {
def Tosa_SliceOp: Tosa_Op<"slice", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Slice operator";
let description = [{
@@ -1556,10 +1518,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [
//===----------------------------------------------------------------------===//
// Operator: tile
//===----------------------------------------------------------------------===//
def Tosa_TileOp: Tosa_Op<"tile", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_TileOp: Tosa_Op<"tile", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Tile operator";
let description = [{
@@ -1580,10 +1539,7 @@ def Tosa_TileOp: Tosa_Op<"tile", [
//===----------------------------------------------------------------------===//
// Operator: transpose
//===----------------------------------------------------------------------===//
def Tosa_TransposeOp : Tosa_Op<"transpose", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_TransposeOp : Tosa_Op<"transpose", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Transpose operator";
let description = [{
@@ -1615,10 +1571,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
//===----------------------------------------------------------------------===//
// Operator: gather
//===----------------------------------------------------------------------===//
def Tosa_GatherOp : Tosa_Op<"gather", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_GatherOp : Tosa_Op<"gather", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Gather operation,";
let description = [{
@@ -1639,10 +1592,7 @@ def Tosa_GatherOp : Tosa_Op<"gather", [
//===----------------------------------------------------------------------===//
// Operator: scatter
//===----------------------------------------------------------------------===//
def Tosa_ScatterOp : Tosa_Op<"scatter", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_ScatterOp : Tosa_Op<"scatter", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Scatter operation,";
let description = [{
@@ -1669,10 +1619,7 @@ def Tosa_ScatterOp : Tosa_Op<"scatter", [
//===----------------------------------------------------------------------===//
// Operator: resize
//===----------------------------------------------------------------------===//
def Tosa_ResizeOp : Tosa_Op<"resize", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
def Tosa_ResizeOp : Tosa_Op<"resize", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Resize operation, supports various resize/upsample modes";
@@ -1898,9 +1845,8 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
//===----------------------------------------------------------------------===//
// Further described in docs/Rationale/RationaleTOSADialect.md .
//===----------------------------------------------------------------------===//
def Tosa_IfOp : Tosa_Op<"cond_if", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
def Tosa_IfOp : Tosa_Op<"cond_if",
[InferShapedTypeOpAdaptor,
SingleBlockImplicitTerminator<"YieldOp">,
RecursiveMemoryEffects]> {
let summary = "Conditional if operator";
@@ -1933,8 +1879,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if", [
//===----------------------------------------------------------------------===//
def Tosa_WhileOp : Tosa_Op<"while_loop", [
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
InferShapedTypeOpAdaptor,
SingleBlockImplicitTerminator<"YieldOp">,
RecursiveMemoryEffects]> {
let summary = "output = input; While (Cond(output)) {output = Body(output)}";

mlir/include/mlir/Interfaces/InferTypeOpInterface.h

@@ -262,6 +262,10 @@ template <typename ConcreteType>
class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
};
template <typename ConcreteType>
class InferShapedTypeOpAdaptor
: public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};
/// Tensor type inference trait that constructs a tensor from the inferred
/// shape and elemental types.
/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.

mlir/include/mlir/Interfaces/InferTypeOpInterface.td

@@ -222,6 +222,42 @@ def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
}]
>;
// Convenient trait to define a wrapper to inferReturnTypeComponents that passes
// in the Op Adaptor directly
class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
[
// Op implements infer type op interface.
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
NativeOpTrait<
/*name=*/"InferShapedTypeOpAdaptor",
/*traits=*/[],
/*extraOpDeclaration=*/[{
static ::mlir::LogicalResult
inferReturnTypeComponents(::mlir::MLIRContext *context,
std::optional<::mlir::Location> location,
Adaptor adaptor,
::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes);
}],
/*extraOpDefinition=*/[{
::mlir::LogicalResult
$cppClass::inferReturnTypeComponents(::mlir::MLIRContext *context,
std::optional<::mlir::Location> location,
::mlir::ValueShapeRange operands, ::mlir::DictionaryAttr attributes,
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes) {
$cppClass::Adaptor adaptor(operands, attributes, properties, regions);
return $cppClass::inferReturnTypeComponents(context,
location, adaptor, inferredReturnShapes);
}
}]
>
]>;
def InferShapedTypeOpAdaptor : InferShapedTypeOpAdaptorBase<[
"inferReturnTypeComponents"]>;
def InferShapedTypeOpAdaptorWithReify : InferShapedTypeOpAdaptorBase<[
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
// Convenience class grouping together type and shaped type op interfaces for
// ops that have tensor return types.
class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
@@ -260,6 +296,44 @@ def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>;
def InferTensorTypeWithReify: InferTensorTypeBase<[
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
// Convenience class grouping together type and shaped type op interfaces for
// ops that have tensor return types.
class InferTensorTypeAdaptorBase<list<string> overridenMethods = []> : TraitList<
[
// Op implements infer type op interface.
DeclareOpInterfaceMethods<InferTypeOpInterface>,
// The op will have methods implementing the ShapedType type inference
// interface.
InferShapedTypeOpAdaptorBase<overridenMethods>,
// The op produces tensors and will use the ShapedType type infer interface
// along with knowledge that it is producing Tensors to infer the type.
NativeOpTrait<
/*name=*/"InferTensorType",
/*traits=*/[],
/*extraOpDeclaration=*/[{}],
/*extraOpDefinition=*/[{
LogicalResult
$cppClass::inferReturnTypes(::mlir::MLIRContext *context,
std::optional<::mlir::Location> location,
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
SmallVector<ShapedTypeComponents, 2> retComponents;
if (failed($cppClass::inferReturnTypeComponents(context, location,
operands, attributes, properties, regions,
retComponents)))
return failure();
return ::mlir::detail::inferReturnTensorTypes(retComponents,
inferredReturnTypes);
}
}]
>
]>;
def InferTensorTypeAdaptor : InferTensorTypeAdaptorBase<["inferReturnTypeComponents"]>;
def InferTensorTypeAdaptorWithReify: InferTensorTypeAdaptorBase<[
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
def ReifyRankedShapedTypeOpInterface :
OpInterface<"ReifyRankedShapedTypeOpInterface"> {
let description = [{

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

@@ -22,6 +22,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -404,12 +405,10 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
ArgMaxOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
auto *prop = properties.as<Properties *>();
IntegerAttr axis = prop->axis;
ShapeAdaptor inputShape(adaptor.getInput().getType());
IntegerAttr axis = adaptor.getProperties().axis;
int32_t axisVal = axis.getValue().getSExtValue();
if (!inputShape.hasRank()) {
@@ -431,10 +430,9 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
RFFT2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor inputShape(adaptor.getInput().getType());
if (!inputShape.hasRank())
return failure();
@@ -458,26 +456,26 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
FFT2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(0)));
inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(1)));
inferredReturnShapes.push_back(
ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
inferredReturnShapes.push_back(
ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
return success();
}
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
ConcatOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Infer all dimension sizes by reducing based on inputs.
auto *prop = properties.as<Properties *>();
int32_t axis = prop->axis.getValue().getSExtValue();
const Properties &prop = adaptor.getProperties();
int32_t axis = prop.axis.getValue().getSExtValue();
llvm::SmallVector<int64_t> outputShape;
bool hasRankedInput = false;
for (auto operand : operands) {
ShapeAdaptor operandShape = operands.getShape(operand);
for (auto operand : adaptor.getOperands()) {
ShapeAdaptor operandShape(operand.getType());
if (!operandShape.hasRank())
continue;
@@ -501,7 +499,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
hasRankedInput = true;
}
Type inputType =
llvm::cast<TensorType>(operands.getType()[0]).getElementType();
llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
if (!hasRankedInput) {
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
@@ -509,8 +507,8 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
// Determine the dimension size along the concatenation axis.
int64_t concatDimSize = 0;
for (auto operand : operands) {
ShapeAdaptor operandShape = operands.getShape(operand);
for (auto operand : adaptor.getOperands()) {
ShapeAdaptor operandShape(operand.getType());
// We need to know the length of the concatenation axis of all inputs to
// determine the dimension size of the output shape.
@@ -553,12 +551,11 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
FullyConnectedOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor weightShape = operands.getShape(1);
ShapeAdaptor biasShape = operands.getShape(2);
ShapeAdaptor inputShape(adaptor.getInput().getType());
ShapeAdaptor weightShape(adaptor.getWeight().getType());
ShapeAdaptor biasShape(adaptor.getBias().getType());
// All shapes are dynamic.
SmallVector<int64_t> outShape;
@@ -585,11 +582,10 @@ LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
MatMulOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor lhsShape = operands.getShape(0);
ShapeAdaptor rhsShape = operands.getShape(1);
ShapeAdaptor lhsShape(adaptor.getA().getType());
ShapeAdaptor rhsShape(adaptor.getB().getType());
// All shapes are dynamic.
SmallVector<int64_t> outShape;
@@ -612,11 +608,10 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
PadOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor paddingShape = operands.getShape(1);
ShapeAdaptor inputShape(adaptor.getInput1().getType());
ShapeAdaptor paddingShape(adaptor.getPadding().getType());
SmallVector<int64_t> outputShape;
// If both inputs have unknown shape, we cannot determine the shape of the
@@ -641,7 +636,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
DenseIntElementsAttr paddings;
// If the paddings value is not a constant, all dimensions must be dynamic.
if (!matchPattern(operands[1], m_Constant(&paddings))) {
if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
@@ -675,22 +670,18 @@ static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SliceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
inferredReturnShapes.push_back(ShapedTypeComponents(
convertToMlirShape(SliceOpAdaptor(operands, attributes,
*properties.as<Properties *>(), regions)
.getSize())));
inferredReturnShapes.push_back(
ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
return success();
}
LogicalResult tosa::TableOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
TableOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor inputShape(adaptor.getInput().getType());
if (!inputShape.hasRank()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
@@ -704,13 +695,10 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
LogicalResult tosa::TileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
TileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
TileOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
regions);
ArrayRef<int64_t> multiples = adaptor.getMultiples();
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor inputShape(adaptor.getInput1().getType());
SmallVector<int64_t> outputShape;
if (!inputShape.hasRank()) {
outputShape.resize(multiples.size(), ShapedType::kDynamic);
@@ -739,13 +727,10 @@ bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
ReshapeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ReshapeOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
regions);
ShapeAdaptor inputShape = operands.getShape(0);
Type inputType = getElementTypeOrSelf(operands.getType()[0]);
ShapeAdaptor inputShape(adaptor.getInput1().getType());
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
llvm::SmallVector<int64_t> newShapeValue =
convertToMlirShape(adaptor.getNewShape());
@@ -814,11 +799,10 @@ LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
TransposeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor permsShape = operands.getShape(1);
ShapeAdaptor inputShape(adaptor.getInput1().getType());
ShapeAdaptor permsShape(adaptor.getPerms().getType());
// If input rank and permutation length is unknown, the output rank is
// unknown.
@@ -869,7 +853,10 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
// If the permuations are a constant we can directly determine the output
// shape.
if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
DenseIntElementsAttr attr;
if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
attr.getType().getRank() == 1) {
ShapeAdaptor permShape = attr;
outputShape.reserve(inputShape.getRank());
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
@@ -882,19 +869,18 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
GatherOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape;
outputShape.resize(3, ShapedType::kDynamic);
ShapeAdaptor valuesShape = operands.getShape(0);
ShapeAdaptor valuesShape(adaptor.getValues().getType());
if (valuesShape.hasRank()) {
outputShape[0] = valuesShape.getDimSize(0);
outputShape[2] = valuesShape.getDimSize(2);
}
ShapeAdaptor indicesShape = operands.getShape(1);
ShapeAdaptor indicesShape(adaptor.getIndices().getType());
if (indicesShape.hasRank()) {
if (outputShape[0] == ShapedType::kDynamic)
outputShape[0] = indicesShape.getDimSize(0);
@@ -908,15 +894,12 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
ResizeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ResizeOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
regions);
llvm::SmallVector<int64_t, 4> outputShape;
outputShape.resize(4, ShapedType::kDynamic);
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
ShapeAdaptor inputShape(adaptor.getInput().getType());
if (!inputShape.hasRank())
return failure();
@@ -950,26 +933,25 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
ScatterOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape;
outputShape.resize(3, ShapedType::kDynamic);
ShapeAdaptor valuesInShape = operands.getShape(0);
ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
if (valuesInShape.hasRank()) {
outputShape[0] = valuesInShape.getDimSize(0);
outputShape[1] = valuesInShape.getDimSize(1);
outputShape[2] = valuesInShape.getDimSize(2);
}
ShapeAdaptor indicesShape = operands.getShape(1);
ShapeAdaptor indicesShape(adaptor.getIndices().getType());
if (indicesShape.hasRank()) {
if (outputShape[0] == ShapedType::kDynamic)
outputShape[0] = indicesShape.getDimSize(0);
}
ShapeAdaptor inputShape = operands.getShape(2);
ShapeAdaptor inputShape(adaptor.getInput().getType());
if (inputShape.hasRank()) {
if (outputShape[0] == ShapedType::kDynamic)
outputShape[0] = inputShape.getDimSize(0);
@@ -1009,13 +991,13 @@ static LogicalResult ReduceInferReturnTypes(
#define REDUCE_SHAPE_INFER(OP) \
LogicalResult OP::inferReturnTypeComponents( \
MLIRContext *context, ::std::optional<Location> location, \
ValueShapeRange operands, DictionaryAttr attributes, \
OpaqueProperties properties, RegionRange regions, \
OP::Adaptor adaptor, \
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
Type inputType = \
llvm::cast<TensorType>(operands.getType()[0]).getElementType(); \
return ReduceInferReturnTypes(operands.getShape(0), inputType, \
properties.as<Properties *>()->axis, \
llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
ShapeAdaptor inputShape(adaptor.getInput().getType()); \
const Properties &prop = adaptor.getProperties(); \
return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
inferredReturnShapes); \
} \
COMPATIBLE_RETURN_TYPES(OP)
@@ -1092,10 +1074,9 @@ NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef PRED_SHAPE_INFER
static LogicalResult poolingInferReturnTypes(
const ValueShapeRange &operands, DictionaryAttr attributes,
ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride, ArrayRef<int64_t> pad,
ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
ArrayRef<int64_t> pad,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
llvm::SmallVector<int64_t> outputShape;
outputShape.resize(4, ShapedType::kDynamic);
@@ -1128,12 +1109,9 @@ static LogicalResult poolingInferReturnTypes(
LogicalResult Conv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
Conv2DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
Conv2DOp::Adaptor adaptor(operands, attributes,
*properties.as<Properties *>(), regions);
int64_t inputWidth = ShapedType::kDynamic;
int64_t inputHeight = ShapedType::kDynamic;
@@ -1142,7 +1120,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
// Input shape describes input width/height and batch.
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
ShapeAdaptor inputShape(adaptor.getInput().getType());
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
inputHeight = inputShape.getDimSize(1);
@@ -1150,7 +1128,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
}
// Weight shapes describes the filter width/height and the output channels.
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
ShapeAdaptor weightShape(adaptor.getWeight().getType());
if (weightShape.hasRank()) {
outputShape[3] = weightShape.getDimSize(0);
weightHeight = weightShape.getDimSize(1);
@@ -1158,7 +1136,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
}
// Bias shape can describe the output channels.
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
ShapeAdaptor biasShape(adaptor.getBias().getType());
if (biasShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? biasShape.getDimSize(0)
@@ -1193,12 +1171,9 @@ LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
LogicalResult Conv3DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
Conv3DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
Conv3DOp::Adaptor adaptor(operands, attributes,
*properties.as<Properties *>(), regions);
int64_t inputWidth = ShapedType::kDynamic;
int64_t inputHeight = ShapedType::kDynamic;
@@ -1209,7 +1184,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
int64_t weightDepth = ShapedType::kDynamic;
// Input shape describes input width/height and batch.
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
ShapeAdaptor inputShape(adaptor.getInput().getType());
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
inputDepth = inputShape.getDimSize(1);
@@ -1218,7 +1193,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
}
// Weight shapes describes the filter width/height and the output channels.
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
ShapeAdaptor weightShape(adaptor.getWeight().getType());
if (weightShape.hasRank()) {
outputShape[4] = weightShape.getDimSize(0);
weightDepth = weightShape.getDimSize(1);
@@ -1227,7 +1202,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
}
// Bias shape can describe the output channels.
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
ShapeAdaptor biasShape(adaptor.getBias().getType());
if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
outputShape[4] = biasShape.getDimSize(0);
}
@@ -1268,32 +1243,29 @@ LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
AvgPool2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
Properties &prop = *properties.as<Properties *>();
return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
prop.pad, inferredReturnShapes);
ShapeAdaptor inputShape(adaptor.getInput().getType());
const Properties &prop = adaptor.getProperties();
return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
inferredReturnShapes);
}
LogicalResult MaxPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
MaxPool2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
Properties &prop = *properties.as<Properties *>();
return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
prop.pad, inferredReturnShapes);
ShapeAdaptor inputShape(adaptor.getInput().getType());
const Properties &prop = adaptor.getProperties();
return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
inferredReturnShapes);
}
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
DepthwiseConv2DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
DepthwiseConv2DOp::Adaptor adaptor(operands, attributes,
*properties.as<Properties *>(), regions);
int64_t inputWidth = ShapedType::kDynamic;
int64_t inputHeight = ShapedType::kDynamic;
@@ -1304,7 +1276,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
int64_t depthChannels = ShapedType::kDynamic;
// Input shape describes input width/height and batch.
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
ShapeAdaptor inputShape(adaptor.getInput().getType());
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
inputHeight = inputShape.getDimSize(1);
@@ -1313,7 +1285,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
}
// Weight shapes describes the filter width/height and the output channels.
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
ShapeAdaptor weightShape(adaptor.getWeight().getType());
if (weightShape.hasRank()) {
weightHeight = weightShape.getDimSize(0);
weightWidth = weightShape.getDimSize(1);
@@ -1331,7 +1303,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
}
// Bias shape can describe the output channels.
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
ShapeAdaptor biasShape(adaptor.getBias().getType());
if (biasShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? biasShape.getDimSize(0)
@@ -1366,11 +1338,8 @@ LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
TransposeConv2DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
TransposeConv2DOp::Adaptor adaptor(operands, attributes,
*properties.as<Properties *>(), regions);
// outputShape is mutable.
llvm::SmallVector<int64_t> outputShape =
convertToMlirShape(adaptor.getOutShape());
@@ -1381,7 +1350,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
int64_t weightHeight = ShapedType::kDynamic;
// Input shape describes input width/height and batch.
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
ShapeAdaptor inputShape(adaptor.getInput().getType());
if (inputShape.hasRank()) {
outputShape[0] = ShapedType::isDynamic(outputShape[0])
? inputShape.getDimSize(0)
@@ -1391,7 +1360,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
}
// Weight shapes describes the filter width/height and the output channels.
ShapeAdaptor weightShape = operands.getShape(adaptor.getFilter());
ShapeAdaptor weightShape(adaptor.getFilter().getType());
if (weightShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? weightShape.getDimSize(0)
@@ -1401,7 +1370,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
}
// Bias shape can describe the output channels.
ShapeAdaptor biasShape = operands.getShape(adaptor.getInput());
ShapeAdaptor biasShape(adaptor.getInput().getType());
if (biasShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? biasShape.getDimSize(0)
@@ -1433,11 +1402,10 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
IfOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
for (Region *region : regions) {
for (Region *region : adaptor.getRegions()) {
for (auto &block : *region)
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
yieldOps.push_back(returnOp);
@@ -1478,11 +1446,10 @@ LogicalResult IfOp::inferReturnTypeComponents(
LogicalResult WhileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
WhileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
for (auto &block : *regions[1])
for (auto &block : adaptor.getBody())
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
yieldOps.push_back(returnOp);

mlir/test/lib/Dialect/Test/TestDialect.cpp

@@ -1437,9 +1437,8 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
// Create return type consisting of the last element of the first operand.
auto operandType = operands.front().getType();
auto sval = dyn_cast<ShapedType>(operandType);
if (!sval) {
if (!sval)
return emitOptionalError(location, "only shaped type operands allowed");
}
int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
auto type = IntegerType::get(context, 17);
@@ -1458,6 +1457,35 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
return success();
}
LogicalResult
OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
MLIRContext *context, std::optional<Location> location,
OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Create return type consisting of the last element of the first operand.
auto operandType = adaptor.getOperand1().getType();
auto sval = dyn_cast<ShapedType>(operandType);
if (!sval)
return emitOptionalError(location, "only shaped type operands allowed");
int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
auto type = IntegerType::get(context, 17);
Attribute encoding;
if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
encoding = rankedTy.getEncoding();
inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
return success();
}
LogicalResult
OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
OpBuilder &builder, ValueRange operands,
llvm::SmallVectorImpl<Value> &shapes) {
shapes = SmallVector<Value, 1>{
builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
return success();
}
LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
OpBuilder &builder, ValueRange operands,
llvm::SmallVectorImpl<Value> &shapes) {

mlir/test/lib/Dialect/Test/TestOps.td

@@ -780,6 +780,13 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
let results = (outs AnyTensor);
}
def OpWithShapedTypeInferTypeAdaptorInterfaceOp :
TEST_Op<"op_with_shaped_type_infer_type_adaptor_if",
[InferTensorTypeAdaptorWithReify]> {
let arguments = (ins AnyTensor:$operand1, AnyTensor:$operand2);
let results = (outs AnyTensor:$result);
}
def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapes"]>]> {