diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 6916fa78abbb..7212700d641e 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -116,7 +116,9 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> { OpBuilderDAG<(ins "ArrayRef":$staticShape, "Type":$elementType), [{ build($_builder, $_state, ValueRange{}, staticShape, elementType); - }]> + }]>, + OpBuilderDAG<(ins "ArrayRef":$sizes, "Type":$elementType, + CArg<"ArrayRef", "{}">:$attrs)> ]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 989a164007d5..42a7900f23bb 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -87,6 +87,24 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p, template static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); +/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if +/// it is a Value or into `staticVec` if it is an IntegerAttr. +/// In the case of a Value, a copy of the `sentinel` value is also pushed to +/// `staticVec`. This is useful to extract mixed static and dynamic entries that +/// come from an AttrSizedOperandSegments trait. +static void dispatchIndexOpFoldResult(OpFoldResult ofr, + SmallVectorImpl &dynamicVec, + SmallVectorImpl &staticVec, + int64_t sentinel) { + if (auto v = ofr.dyn_cast()) { + dynamicVec.push_back(v); + staticVec.push_back(sentinel); + return; + } + APInt apInt = ofr.dyn_cast().cast().getValue(); + staticVec.push_back(apInt.getSExtValue()); +} + /// This is a common class used for patterns of the form /// ``` /// someop(memrefcast) -> someop @@ -539,6 +557,24 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); } //===----------------------------------------------------------------------===// // InitTensorOp //===----------------------------------------------------------------------===// +void InitTensorOp::build(OpBuilder &b, OperationState &result, + ArrayRef sizes, Type elementType, + ArrayRef attrs) { + unsigned rank = sizes.size(); + SmallVector dynamicSizes; + SmallVector staticSizes; + for (unsigned i = 0; i < rank; ++i) { + // staticLow and staticHigh have full information of the padding config. + // This will grow staticLow and staticHigh with 1 value. If the config is + // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1 + // value as well. + dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes, + ShapedType::kDynamicSize); + } + auto resultType = RankedTensorType ::get(staticSizes, elementType); + build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes)); + result.addAttributes(attrs); +} static LogicalResult verify(InitTensorOp op) { RankedTensorType resultType = op.getType(); @@ -857,24 +893,6 @@ RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType, return RankedTensorType::get(resultShape, sourceType.getElementType()); } -/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if -/// it is a Value or into `staticVec` if it is an IntegerAttr. -/// In the case of a Value, a copy of the `sentinel` value is also pushed to -/// `staticVec`. This is useful to extract mixed static and dynamic entries that -/// come from an AttrSizedOperandSegments trait. -static void dispatchIndexOpFoldResult(OpFoldResult ofr, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - if (auto v = ofr.dyn_cast()) { - dynamicVec.push_back(v); - staticVec.push_back(sentinel); - return; - } - APInt apInt = ofr.dyn_cast().cast().getValue(); - staticVec.push_back(apInt.getSExtValue()); -} - void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef staticLow, ArrayRef staticHigh, ValueRange low,