mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-02 18:58:15 +00:00
[ODS] Use Adaptor Trait for Shaped Type Inference
Author inferReturnTypeComponents methods with the Op Adaptor by using the InferShapedTypeOpAdaptor. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D155243
This commit is contained in:
parent
04cc892eed
commit
057fc8e7d8
@ -32,10 +32,7 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: argmax
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_ArgMaxOp : Tosa_Op<"argmax", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Perform argmax on the input.";
|
||||
|
||||
let description = [{
|
||||
@ -62,10 +59,7 @@ def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: avg_pool2d
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Performs max pooling on the input.";
|
||||
|
||||
let description = [{
|
||||
@ -95,10 +89,7 @@ def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: conv2d
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_Conv2DOp : Tosa_Op<"conv2d", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "2D Convolution Operator";
|
||||
|
||||
let description = [{
|
||||
@ -128,10 +119,7 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: conv3d
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_Conv3DOp : Tosa_Op<"conv3d", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "3D Convolution operator";
|
||||
|
||||
let description = [{
|
||||
@ -160,10 +148,8 @@ def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: depthwise_conv2d
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d",
|
||||
[InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Depthwise 2D Convolution operator";
|
||||
|
||||
let description = [{
|
||||
@ -193,10 +179,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: fft2d
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_FFT2dOp : Tosa_Op<"fft2d", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Performs FFT2D operation on the input.";
|
||||
|
||||
let description = [{
|
||||
@ -224,9 +207,7 @@ def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
|
||||
// Operator: fully_connected
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Fully Connected operator";
|
||||
|
||||
let description = [{
|
||||
@ -251,10 +232,7 @@ def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: matmul
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_MatMulOp : Tosa_Op<"matmul", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_MatMulOp : Tosa_Op<"matmul", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Matrix multiplication with bias";
|
||||
|
||||
let description = [{
|
||||
@ -279,10 +257,7 @@ def Tosa_MatMulOp : Tosa_Op<"matmul", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: max_pool2d
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Performs max pooling on the input.";
|
||||
|
||||
let description = [{
|
||||
@ -310,10 +285,7 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: rfft2d
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Performs RFFT2D operation on the input.";
|
||||
|
||||
let description = [{
|
||||
@ -338,10 +310,8 @@ def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: transpose_conv2d
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d",
|
||||
[InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Transpose 2D Convolution operator.";
|
||||
|
||||
let description = [{
|
||||
@ -828,10 +798,7 @@ def Tosa_SubOp : Tosa_ElemWiseBinaryOp<"sub"> {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: table
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_TableOp : Tosa_Op<"table", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_TableOp : Tosa_Op<"table", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Table lookup op";
|
||||
|
||||
let description = [{
|
||||
@ -1214,7 +1181,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
|
||||
// Operator: reduce_all
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
|
||||
InferTensorType, Pure]> {
|
||||
InferTensorTypeAdaptor, Pure]> {
|
||||
let summary = "Reduce All operator";
|
||||
|
||||
let description = [{
|
||||
@ -1243,7 +1210,7 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
|
||||
// Operator: reduce_any
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
|
||||
InferTensorType, Pure]> {
|
||||
InferTensorTypeAdaptor, Pure]> {
|
||||
let summary = "Reduce Any operator";
|
||||
|
||||
let description = [{
|
||||
@ -1272,7 +1239,7 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
|
||||
// Operator: reduce_max
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
|
||||
InferTensorType, Pure]> {
|
||||
InferTensorTypeAdaptor, Pure]> {
|
||||
let summary = "Reduce Max operator";
|
||||
|
||||
let description = [{
|
||||
@ -1301,7 +1268,7 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
|
||||
// Operator: reduce_min
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
|
||||
InferTensorType, Pure]> {
|
||||
InferTensorTypeAdaptor, Pure]> {
|
||||
let summary = "Reduce Min operator";
|
||||
|
||||
let description = [{
|
||||
@ -1330,7 +1297,7 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
|
||||
// Operator: reduce_prod
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
|
||||
InferTensorType, Pure]> {
|
||||
InferTensorTypeAdaptor, Pure]> {
|
||||
let summary = "Reduce Prod operator";
|
||||
|
||||
let description = [{
|
||||
@ -1359,7 +1326,7 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
|
||||
// Operator: reduce_sum
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
|
||||
InferTensorType, Pure]> {
|
||||
InferTensorTypeAdaptor, Pure]> {
|
||||
let summary = "Reduce Sum operator";
|
||||
|
||||
let description = [{
|
||||
@ -1393,7 +1360,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
|
||||
// Operator: concat
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_ConcatOp : Tosa_Op<"concat", [
|
||||
InferTensorType, Pure]> {
|
||||
InferTensorTypeAdaptor, Pure]> {
|
||||
let summary = "Concatenates tensors along one dimension.";
|
||||
|
||||
let description = [{
|
||||
@ -1423,10 +1390,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: pad
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_PadOp : Tosa_Op<"pad", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_PadOp : Tosa_Op<"pad", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Pads a tensor with value specified.";
|
||||
|
||||
let description = [{
|
||||
@ -1471,7 +1435,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
|
||||
// Operator: reshape
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_ReshapeOp: Tosa_Op<"reshape", [
|
||||
InferTensorType, Pure]> {
|
||||
InferTensorTypeAdaptor, Pure]> {
|
||||
let summary = "Reshape operator";
|
||||
|
||||
let description = [{
|
||||
@ -1528,9 +1492,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: slice
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_SliceOp: Tosa_Op<"slice", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>, Pure]> {
|
||||
def Tosa_SliceOp: Tosa_Op<"slice", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Slice operator";
|
||||
|
||||
let description = [{
|
||||
@ -1556,10 +1518,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: tile
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_TileOp: Tosa_Op<"tile", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_TileOp: Tosa_Op<"tile", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Tile operator";
|
||||
|
||||
let description = [{
|
||||
@ -1580,10 +1539,7 @@ def Tosa_TileOp: Tosa_Op<"tile", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: transpose
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_TransposeOp : Tosa_Op<"transpose", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_TransposeOp : Tosa_Op<"transpose", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Transpose operator";
|
||||
|
||||
let description = [{
|
||||
@ -1615,10 +1571,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: gather
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_GatherOp : Tosa_Op<"gather", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_GatherOp : Tosa_Op<"gather", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Gather operation,";
|
||||
|
||||
let description = [{
|
||||
@ -1639,10 +1592,7 @@ def Tosa_GatherOp : Tosa_Op<"gather", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: scatter
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_ScatterOp : Tosa_Op<"scatter", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_ScatterOp : Tosa_Op<"scatter", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
let summary = "Scatter operation,";
|
||||
|
||||
let description = [{
|
||||
@ -1669,10 +1619,7 @@ def Tosa_ScatterOp : Tosa_Op<"scatter", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator: resize
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_ResizeOp : Tosa_Op<"resize", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
Pure]> {
|
||||
def Tosa_ResizeOp : Tosa_Op<"resize", [InferShapedTypeOpAdaptor, Pure]> {
|
||||
|
||||
let summary = "Resize operation, supports various resize/upsample modes";
|
||||
|
||||
@ -1898,9 +1845,8 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Further described in docs/Rationale/RationaleTOSADialect.md .
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_IfOp : Tosa_Op<"cond_if", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
def Tosa_IfOp : Tosa_Op<"cond_if",
|
||||
[InferShapedTypeOpAdaptor,
|
||||
SingleBlockImplicitTerminator<"YieldOp">,
|
||||
RecursiveMemoryEffects]> {
|
||||
let summary = "Conditional if operator";
|
||||
@ -1933,8 +1879,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if", [
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_WhileOp : Tosa_Op<"while_loop", [
|
||||
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
InferShapedTypeOpAdaptor,
|
||||
SingleBlockImplicitTerminator<"YieldOp">,
|
||||
RecursiveMemoryEffects]> {
|
||||
let summary = "output = input; While (Cond(output)) {output = Body(output)}";
|
||||
|
@ -262,6 +262,10 @@ template <typename ConcreteType>
|
||||
class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
|
||||
};
|
||||
|
||||
template <typename ConcreteType>
|
||||
class InferShapedTypeOpAdaptor
|
||||
: public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};
|
||||
|
||||
/// Tensor type inference trait that constructs a tensor from the inferred
|
||||
/// shape and elemental types.
|
||||
/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
|
||||
|
@ -222,6 +222,42 @@ def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
|
||||
}]
|
||||
>;
|
||||
|
||||
// Convenient trait to define a wrapper to inferReturnTypeComponents that passes
|
||||
// in the Op Adaptor directly
|
||||
class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
|
||||
[
|
||||
// Op implements infer type op interface.
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
|
||||
NativeOpTrait<
|
||||
/*name=*/"InferShapedTypeOpAdaptor",
|
||||
/*traits=*/[],
|
||||
/*extraOpDeclaration=*/[{
|
||||
static ::mlir::LogicalResult
|
||||
inferReturnTypeComponents(::mlir::MLIRContext *context,
|
||||
std::optional<::mlir::Location> location,
|
||||
Adaptor adaptor,
|
||||
::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes);
|
||||
}],
|
||||
/*extraOpDefinition=*/[{
|
||||
::mlir::LogicalResult
|
||||
$cppClass::inferReturnTypeComponents(::mlir::MLIRContext *context,
|
||||
std::optional<::mlir::Location> location,
|
||||
::mlir::ValueShapeRange operands, ::mlir::DictionaryAttr attributes,
|
||||
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes) {
|
||||
$cppClass::Adaptor adaptor(operands, attributes, properties, regions);
|
||||
return $cppClass::inferReturnTypeComponents(context,
|
||||
location, adaptor, inferredReturnShapes);
|
||||
}
|
||||
}]
|
||||
>
|
||||
]>;
|
||||
|
||||
def InferShapedTypeOpAdaptor : InferShapedTypeOpAdaptorBase<[
|
||||
"inferReturnTypeComponents"]>;
|
||||
def InferShapedTypeOpAdaptorWithReify : InferShapedTypeOpAdaptorBase<[
|
||||
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
|
||||
|
||||
// Convenience class grouping together type and shaped type op interfaces for
|
||||
// ops that have tensor return types.
|
||||
class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
|
||||
@ -260,6 +296,44 @@ def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>;
|
||||
def InferTensorTypeWithReify: InferTensorTypeBase<[
|
||||
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
|
||||
|
||||
// Convenience class grouping together type and shaped type op interfaces for
|
||||
// ops that have tensor return types.
|
||||
class InferTensorTypeAdaptorBase<list<string> overridenMethods = []> : TraitList<
|
||||
[
|
||||
// Op implements infer type op interface.
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
// The op will have methods implementing the ShapedType type inference
|
||||
// interface.
|
||||
InferShapedTypeOpAdaptorBase<overridenMethods>,
|
||||
// The op produces tensors and will use the ShapedType type infer interface
|
||||
// along with knowledge that it is producing Tensors to infer the type.
|
||||
NativeOpTrait<
|
||||
/*name=*/"InferTensorType",
|
||||
/*traits=*/[],
|
||||
/*extraOpDeclaration=*/[{}],
|
||||
/*extraOpDefinition=*/[{
|
||||
LogicalResult
|
||||
$cppClass::inferReturnTypes(::mlir::MLIRContext *context,
|
||||
std::optional<::mlir::Location> location,
|
||||
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
|
||||
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
||||
SmallVector<ShapedTypeComponents, 2> retComponents;
|
||||
if (failed($cppClass::inferReturnTypeComponents(context, location,
|
||||
operands, attributes, properties, regions,
|
||||
retComponents)))
|
||||
return failure();
|
||||
return ::mlir::detail::inferReturnTensorTypes(retComponents,
|
||||
inferredReturnTypes);
|
||||
}
|
||||
}]
|
||||
>
|
||||
]>;
|
||||
|
||||
def InferTensorTypeAdaptor : InferTensorTypeAdaptorBase<["inferReturnTypeComponents"]>;
|
||||
def InferTensorTypeAdaptorWithReify: InferTensorTypeAdaptorBase<[
|
||||
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
|
||||
|
||||
def ReifyRankedShapedTypeOpInterface :
|
||||
OpInterface<"ReifyRankedShapedTypeOpInterface"> {
|
||||
let description = [{
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
@ -404,12 +405,10 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
|
||||
|
||||
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
ArgMaxOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
ShapeAdaptor inputShape = operands.getShape(0);
|
||||
auto *prop = properties.as<Properties *>();
|
||||
IntegerAttr axis = prop->axis;
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
IntegerAttr axis = adaptor.getProperties().axis;
|
||||
int32_t axisVal = axis.getValue().getSExtValue();
|
||||
|
||||
if (!inputShape.hasRank()) {
|
||||
@ -431,10 +430,9 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
|
||||
|
||||
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
RFFT2dOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
ShapeAdaptor inputShape = operands.getShape(0);
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
|
||||
if (!inputShape.hasRank())
|
||||
return failure();
|
||||
@ -458,26 +456,26 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
|
||||
|
||||
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
FFT2dOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(0)));
|
||||
inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(1)));
|
||||
inferredReturnShapes.push_back(
|
||||
ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
|
||||
inferredReturnShapes.push_back(
|
||||
ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
ConcatOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
// Infer all dimension sizes by reducing based on inputs.
|
||||
auto *prop = properties.as<Properties *>();
|
||||
int32_t axis = prop->axis.getValue().getSExtValue();
|
||||
const Properties &prop = adaptor.getProperties();
|
||||
int32_t axis = prop.axis.getValue().getSExtValue();
|
||||
llvm::SmallVector<int64_t> outputShape;
|
||||
bool hasRankedInput = false;
|
||||
for (auto operand : operands) {
|
||||
ShapeAdaptor operandShape = operands.getShape(operand);
|
||||
for (auto operand : adaptor.getOperands()) {
|
||||
ShapeAdaptor operandShape(operand.getType());
|
||||
if (!operandShape.hasRank())
|
||||
continue;
|
||||
|
||||
@ -501,7 +499,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
|
||||
hasRankedInput = true;
|
||||
}
|
||||
Type inputType =
|
||||
llvm::cast<TensorType>(operands.getType()[0]).getElementType();
|
||||
llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
|
||||
if (!hasRankedInput) {
|
||||
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
|
||||
return success();
|
||||
@ -509,8 +507,8 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
|
||||
|
||||
// Determine the dimension size along the concatenation axis.
|
||||
int64_t concatDimSize = 0;
|
||||
for (auto operand : operands) {
|
||||
ShapeAdaptor operandShape = operands.getShape(operand);
|
||||
for (auto operand : adaptor.getOperands()) {
|
||||
ShapeAdaptor operandShape(operand.getType());
|
||||
|
||||
// We need to know the length of the concatenation axis of all inputs to
|
||||
// determine the dimension size of the output shape.
|
||||
@ -553,12 +551,11 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
|
||||
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
FullyConnectedOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
ShapeAdaptor inputShape = operands.getShape(0);
|
||||
ShapeAdaptor weightShape = operands.getShape(1);
|
||||
ShapeAdaptor biasShape = operands.getShape(2);
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
ShapeAdaptor weightShape(adaptor.getWeight().getType());
|
||||
ShapeAdaptor biasShape(adaptor.getBias().getType());
|
||||
|
||||
// All shapes are dynamic.
|
||||
SmallVector<int64_t> outShape;
|
||||
@ -585,11 +582,10 @@ LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
|
||||
|
||||
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
MatMulOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
ShapeAdaptor lhsShape = operands.getShape(0);
|
||||
ShapeAdaptor rhsShape = operands.getShape(1);
|
||||
ShapeAdaptor lhsShape(adaptor.getA().getType());
|
||||
ShapeAdaptor rhsShape(adaptor.getB().getType());
|
||||
|
||||
// All shapes are dynamic.
|
||||
SmallVector<int64_t> outShape;
|
||||
@ -612,11 +608,10 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
|
||||
|
||||
LogicalResult tosa::PadOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
PadOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
ShapeAdaptor inputShape = operands.getShape(0);
|
||||
ShapeAdaptor paddingShape = operands.getShape(1);
|
||||
ShapeAdaptor inputShape(adaptor.getInput1().getType());
|
||||
ShapeAdaptor paddingShape(adaptor.getPadding().getType());
|
||||
SmallVector<int64_t> outputShape;
|
||||
|
||||
// If both inputs have unknown shape, we cannot determine the shape of the
|
||||
@ -641,7 +636,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
|
||||
|
||||
DenseIntElementsAttr paddings;
|
||||
// If the paddings value is not a constant, all dimensions must be dynamic.
|
||||
if (!matchPattern(operands[1], m_Constant(&paddings))) {
|
||||
if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
|
||||
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
|
||||
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
|
||||
return success();
|
||||
@ -675,22 +670,18 @@ static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
|
||||
|
||||
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
SliceOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
inferredReturnShapes.push_back(ShapedTypeComponents(
|
||||
convertToMlirShape(SliceOpAdaptor(operands, attributes,
|
||||
*properties.as<Properties *>(), regions)
|
||||
.getSize())));
|
||||
inferredReturnShapes.push_back(
|
||||
ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult tosa::TableOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
TableOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
ShapeAdaptor inputShape = operands.getShape(0);
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
|
||||
if (!inputShape.hasRank()) {
|
||||
inferredReturnShapes.push_back(ShapedTypeComponents());
|
||||
@ -704,13 +695,10 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
|
||||
|
||||
LogicalResult tosa::TileOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
TileOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
TileOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
|
||||
regions);
|
||||
ArrayRef<int64_t> multiples = adaptor.getMultiples();
|
||||
ShapeAdaptor inputShape = operands.getShape(0);
|
||||
ShapeAdaptor inputShape(adaptor.getInput1().getType());
|
||||
SmallVector<int64_t> outputShape;
|
||||
if (!inputShape.hasRank()) {
|
||||
outputShape.resize(multiples.size(), ShapedType::kDynamic);
|
||||
@ -739,13 +727,10 @@ bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
|
||||
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
ReshapeOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
ReshapeOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
|
||||
regions);
|
||||
ShapeAdaptor inputShape = operands.getShape(0);
|
||||
Type inputType = getElementTypeOrSelf(operands.getType()[0]);
|
||||
ShapeAdaptor inputShape(adaptor.getInput1().getType());
|
||||
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
|
||||
llvm::SmallVector<int64_t> newShapeValue =
|
||||
convertToMlirShape(adaptor.getNewShape());
|
||||
|
||||
@ -814,11 +799,10 @@ LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
|
||||
|
||||
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
TransposeOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
ShapeAdaptor inputShape = operands.getShape(0);
|
||||
ShapeAdaptor permsShape = operands.getShape(1);
|
||||
ShapeAdaptor inputShape(adaptor.getInput1().getType());
|
||||
ShapeAdaptor permsShape(adaptor.getPerms().getType());
|
||||
|
||||
// If input rank and permutation length is unknown, the output rank is
|
||||
// unknown.
|
||||
@ -869,7 +853,10 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
|
||||
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
|
||||
// If the permuations are a constant we can directly determine the output
|
||||
// shape.
|
||||
if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
|
||||
DenseIntElementsAttr attr;
|
||||
if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
|
||||
attr.getType().getRank() == 1) {
|
||||
ShapeAdaptor permShape = attr;
|
||||
outputShape.reserve(inputShape.getRank());
|
||||
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
|
||||
outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
|
||||
@ -882,19 +869,18 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
|
||||
|
||||
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
GatherOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
llvm::SmallVector<int64_t> outputShape;
|
||||
outputShape.resize(3, ShapedType::kDynamic);
|
||||
|
||||
ShapeAdaptor valuesShape = operands.getShape(0);
|
||||
ShapeAdaptor valuesShape(adaptor.getValues().getType());
|
||||
if (valuesShape.hasRank()) {
|
||||
outputShape[0] = valuesShape.getDimSize(0);
|
||||
outputShape[2] = valuesShape.getDimSize(2);
|
||||
}
|
||||
|
||||
ShapeAdaptor indicesShape = operands.getShape(1);
|
||||
ShapeAdaptor indicesShape(adaptor.getIndices().getType());
|
||||
if (indicesShape.hasRank()) {
|
||||
if (outputShape[0] == ShapedType::kDynamic)
|
||||
outputShape[0] = indicesShape.getDimSize(0);
|
||||
@ -908,15 +894,12 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
|
||||
|
||||
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
ResizeOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
ResizeOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
|
||||
regions);
|
||||
llvm::SmallVector<int64_t, 4> outputShape;
|
||||
outputShape.resize(4, ShapedType::kDynamic);
|
||||
|
||||
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
if (!inputShape.hasRank())
|
||||
return failure();
|
||||
|
||||
@ -950,26 +933,25 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
|
||||
|
||||
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
ScatterOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
llvm::SmallVector<int64_t> outputShape;
|
||||
outputShape.resize(3, ShapedType::kDynamic);
|
||||
|
||||
ShapeAdaptor valuesInShape = operands.getShape(0);
|
||||
ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
|
||||
if (valuesInShape.hasRank()) {
|
||||
outputShape[0] = valuesInShape.getDimSize(0);
|
||||
outputShape[1] = valuesInShape.getDimSize(1);
|
||||
outputShape[2] = valuesInShape.getDimSize(2);
|
||||
}
|
||||
|
||||
ShapeAdaptor indicesShape = operands.getShape(1);
|
||||
ShapeAdaptor indicesShape(adaptor.getIndices().getType());
|
||||
if (indicesShape.hasRank()) {
|
||||
if (outputShape[0] == ShapedType::kDynamic)
|
||||
outputShape[0] = indicesShape.getDimSize(0);
|
||||
}
|
||||
|
||||
ShapeAdaptor inputShape = operands.getShape(2);
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
if (inputShape.hasRank()) {
|
||||
if (outputShape[0] == ShapedType::kDynamic)
|
||||
outputShape[0] = inputShape.getDimSize(0);
|
||||
@ -1009,13 +991,13 @@ static LogicalResult ReduceInferReturnTypes(
|
||||
#define REDUCE_SHAPE_INFER(OP) \
|
||||
LogicalResult OP::inferReturnTypeComponents( \
|
||||
MLIRContext *context, ::std::optional<Location> location, \
|
||||
ValueShapeRange operands, DictionaryAttr attributes, \
|
||||
OpaqueProperties properties, RegionRange regions, \
|
||||
OP::Adaptor adaptor, \
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
|
||||
Type inputType = \
|
||||
llvm::cast<TensorType>(operands.getType()[0]).getElementType(); \
|
||||
return ReduceInferReturnTypes(operands.getShape(0), inputType, \
|
||||
properties.as<Properties *>()->axis, \
|
||||
llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType()); \
|
||||
const Properties &prop = adaptor.getProperties(); \
|
||||
return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
|
||||
inferredReturnShapes); \
|
||||
} \
|
||||
COMPATIBLE_RETURN_TYPES(OP)
|
||||
@ -1092,10 +1074,9 @@ NARY_SHAPE_INFER(tosa::SigmoidOp)
|
||||
#undef PRED_SHAPE_INFER
|
||||
|
||||
static LogicalResult poolingInferReturnTypes(
|
||||
const ValueShapeRange &operands, DictionaryAttr attributes,
|
||||
ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride, ArrayRef<int64_t> pad,
|
||||
ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
|
||||
ArrayRef<int64_t> pad,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
ShapeAdaptor inputShape = operands.getShape(0);
|
||||
llvm::SmallVector<int64_t> outputShape;
|
||||
outputShape.resize(4, ShapedType::kDynamic);
|
||||
|
||||
@ -1128,12 +1109,9 @@ static LogicalResult poolingInferReturnTypes(
|
||||
|
||||
LogicalResult Conv2DOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
Conv2DOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
|
||||
Conv2DOp::Adaptor adaptor(operands, attributes,
|
||||
*properties.as<Properties *>(), regions);
|
||||
|
||||
int64_t inputWidth = ShapedType::kDynamic;
|
||||
int64_t inputHeight = ShapedType::kDynamic;
|
||||
@ -1142,7 +1120,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
|
||||
|
||||
// Input shape describes input width/height and batch.
|
||||
|
||||
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
if (inputShape.hasRank()) {
|
||||
outputShape[0] = inputShape.getDimSize(0);
|
||||
inputHeight = inputShape.getDimSize(1);
|
||||
@ -1150,7 +1128,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
|
||||
}
|
||||
|
||||
// Weight shapes describes the filter width/height and the output channels.
|
||||
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
|
||||
ShapeAdaptor weightShape(adaptor.getWeight().getType());
|
||||
if (weightShape.hasRank()) {
|
||||
outputShape[3] = weightShape.getDimSize(0);
|
||||
weightHeight = weightShape.getDimSize(1);
|
||||
@ -1158,7 +1136,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
|
||||
}
|
||||
|
||||
// Bias shape can describe the output channels.
|
||||
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
|
||||
ShapeAdaptor biasShape(adaptor.getBias().getType());
|
||||
if (biasShape.hasRank()) {
|
||||
outputShape[3] = ShapedType::isDynamic(outputShape[3])
|
||||
? biasShape.getDimSize(0)
|
||||
@ -1193,12 +1171,9 @@ LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
|
||||
|
||||
LogicalResult Conv3DOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
Conv3DOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
|
||||
Conv3DOp::Adaptor adaptor(operands, attributes,
|
||||
*properties.as<Properties *>(), regions);
|
||||
|
||||
int64_t inputWidth = ShapedType::kDynamic;
|
||||
int64_t inputHeight = ShapedType::kDynamic;
|
||||
@ -1209,7 +1184,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
|
||||
int64_t weightDepth = ShapedType::kDynamic;
|
||||
|
||||
// Input shape describes input width/height and batch.
|
||||
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
if (inputShape.hasRank()) {
|
||||
outputShape[0] = inputShape.getDimSize(0);
|
||||
inputDepth = inputShape.getDimSize(1);
|
||||
@ -1218,7 +1193,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
|
||||
}
|
||||
|
||||
// Weight shapes describes the filter width/height and the output channels.
|
||||
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
|
||||
ShapeAdaptor weightShape(adaptor.getWeight().getType());
|
||||
if (weightShape.hasRank()) {
|
||||
outputShape[4] = weightShape.getDimSize(0);
|
||||
weightDepth = weightShape.getDimSize(1);
|
||||
@ -1227,7 +1202,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
|
||||
}
|
||||
|
||||
// Bias shape can describe the output channels.
|
||||
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
|
||||
ShapeAdaptor biasShape(adaptor.getBias().getType());
|
||||
if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
|
||||
outputShape[4] = biasShape.getDimSize(0);
|
||||
}
|
||||
@ -1268,32 +1243,29 @@ LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
|
||||
|
||||
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
AvgPool2dOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
Properties &prop = *properties.as<Properties *>();
|
||||
return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
|
||||
prop.pad, inferredReturnShapes);
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
const Properties &prop = adaptor.getProperties();
|
||||
return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
|
||||
inferredReturnShapes);
|
||||
}
|
||||
|
||||
LogicalResult MaxPool2dOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
MaxPool2dOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
Properties &prop = *properties.as<Properties *>();
|
||||
return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
|
||||
prop.pad, inferredReturnShapes);
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
const Properties &prop = adaptor.getProperties();
|
||||
return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
|
||||
inferredReturnShapes);
|
||||
}
|
||||
|
||||
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
DepthwiseConv2DOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
|
||||
DepthwiseConv2DOp::Adaptor adaptor(operands, attributes,
|
||||
*properties.as<Properties *>(), regions);
|
||||
|
||||
int64_t inputWidth = ShapedType::kDynamic;
|
||||
int64_t inputHeight = ShapedType::kDynamic;
|
||||
@ -1304,7 +1276,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
|
||||
int64_t depthChannels = ShapedType::kDynamic;
|
||||
|
||||
// Input shape describes input width/height and batch.
|
||||
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
if (inputShape.hasRank()) {
|
||||
outputShape[0] = inputShape.getDimSize(0);
|
||||
inputHeight = inputShape.getDimSize(1);
|
||||
@ -1313,7 +1285,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
|
||||
}
|
||||
|
||||
// Weight shapes describes the filter width/height and the output channels.
|
||||
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
|
||||
ShapeAdaptor weightShape(adaptor.getWeight().getType());
|
||||
if (weightShape.hasRank()) {
|
||||
weightHeight = weightShape.getDimSize(0);
|
||||
weightWidth = weightShape.getDimSize(1);
|
||||
@ -1331,7 +1303,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
|
||||
}
|
||||
|
||||
// Bias shape can describe the output channels.
|
||||
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
|
||||
ShapeAdaptor biasShape(adaptor.getBias().getType());
|
||||
if (biasShape.hasRank()) {
|
||||
outputShape[3] = ShapedType::isDynamic(outputShape[3])
|
||||
? biasShape.getDimSize(0)
|
||||
@ -1366,11 +1338,8 @@ LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
|
||||
|
||||
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
TransposeConv2DOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
TransposeConv2DOp::Adaptor adaptor(operands, attributes,
|
||||
*properties.as<Properties *>(), regions);
|
||||
// outputShape is mutable.
|
||||
llvm::SmallVector<int64_t> outputShape =
|
||||
convertToMlirShape(adaptor.getOutShape());
|
||||
@ -1381,7 +1350,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
|
||||
int64_t weightHeight = ShapedType::kDynamic;
|
||||
|
||||
// Input shape describes input width/height and batch.
|
||||
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
|
||||
ShapeAdaptor inputShape(adaptor.getInput().getType());
|
||||
if (inputShape.hasRank()) {
|
||||
outputShape[0] = ShapedType::isDynamic(outputShape[0])
|
||||
? inputShape.getDimSize(0)
|
||||
@ -1391,7 +1360,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
|
||||
}
|
||||
|
||||
// Weight shapes describes the filter width/height and the output channels.
|
||||
ShapeAdaptor weightShape = operands.getShape(adaptor.getFilter());
|
||||
ShapeAdaptor weightShape(adaptor.getFilter().getType());
|
||||
if (weightShape.hasRank()) {
|
||||
outputShape[3] = ShapedType::isDynamic(outputShape[3])
|
||||
? weightShape.getDimSize(0)
|
||||
@ -1401,7 +1370,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
|
||||
}
|
||||
|
||||
// Bias shape can describe the output channels.
|
||||
ShapeAdaptor biasShape = operands.getShape(adaptor.getInput());
|
||||
ShapeAdaptor biasShape(adaptor.getInput().getType());
|
||||
if (biasShape.hasRank()) {
|
||||
outputShape[3] = ShapedType::isDynamic(outputShape[3])
|
||||
? biasShape.getDimSize(0)
|
||||
@ -1433,11 +1402,10 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
|
||||
|
||||
LogicalResult IfOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
IfOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
llvm::SmallVector<tosa::YieldOp> yieldOps;
|
||||
for (Region *region : regions) {
|
||||
for (Region *region : adaptor.getRegions()) {
|
||||
for (auto &block : *region)
|
||||
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
|
||||
yieldOps.push_back(returnOp);
|
||||
@ -1478,11 +1446,10 @@ LogicalResult IfOp::inferReturnTypeComponents(
|
||||
|
||||
LogicalResult WhileOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::std::optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
WhileOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
llvm::SmallVector<tosa::YieldOp> yieldOps;
|
||||
for (auto &block : *regions[1])
|
||||
for (auto &block : adaptor.getBody())
|
||||
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
|
||||
yieldOps.push_back(returnOp);
|
||||
|
||||
|
@ -1437,9 +1437,8 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
|
||||
// Create return type consisting of the last element of the first operand.
|
||||
auto operandType = operands.front().getType();
|
||||
auto sval = dyn_cast<ShapedType>(operandType);
|
||||
if (!sval) {
|
||||
if (!sval)
|
||||
return emitOptionalError(location, "only shaped type operands allowed");
|
||||
}
|
||||
int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
|
||||
auto type = IntegerType::get(context, 17);
|
||||
|
||||
@ -1458,6 +1457,35 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
// Create return type consisting of the last element of the first operand.
|
||||
auto operandType = adaptor.getOperand1().getType();
|
||||
auto sval = dyn_cast<ShapedType>(operandType);
|
||||
if (!sval)
|
||||
return emitOptionalError(location, "only shaped type operands allowed");
|
||||
int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
|
||||
auto type = IntegerType::get(context, 17);
|
||||
|
||||
Attribute encoding;
|
||||
if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
|
||||
encoding = rankedTy.getEncoding();
|
||||
inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
|
||||
OpBuilder &builder, ValueRange operands,
|
||||
llvm::SmallVectorImpl<Value> &shapes) {
|
||||
shapes = SmallVector<Value, 1>{
|
||||
builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
|
||||
OpBuilder &builder, ValueRange operands,
|
||||
llvm::SmallVectorImpl<Value> &shapes) {
|
||||
|
@ -780,6 +780,13 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def OpWithShapedTypeInferTypeAdaptorInterfaceOp :
|
||||
TEST_Op<"op_with_shaped_type_infer_type_adaptor_if",
|
||||
[InferTensorTypeAdaptorWithReify]> {
|
||||
let arguments = (ins AnyTensor:$operand1, AnyTensor:$operand2);
|
||||
let results = (outs AnyTensor:$result);
|
||||
}
|
||||
|
||||
def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
|
||||
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapes"]>]> {
|
||||
|
Loading…
Reference in New Issue
Block a user