From 11175b55072418e148cec4bf0a6e858b2873f58f Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 8 Dec 2022 19:06:19 +0100 Subject: [PATCH] [mlir][linalg] Print broadcast, map, reduce, transpose ins/outs on one line. Differential Revision: https://reviews.llvm.org/D139650 --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 21 +++------ .../Dialect/Linalg/one-shot-bufferize.mlir | 9 ++-- mlir/test/Dialect/Linalg/roundtrip.mlir | 43 ++++++++----------- 3 files changed, 26 insertions(+), 47 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 98b1406d9848..8b0540e10d01 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -178,12 +178,10 @@ static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs) { if (!inputs.empty()) { - p.printNewline(); - p << "ins(" << inputs << " : " << inputs.getTypes() << ")"; + p << " ins(" << inputs << " : " << inputs.getTypes() << ")"; } if (!outputs.empty()) { - p.printNewline(); - p << "outs(" << outputs << " : " << outputs.getTypes() << ")"; + p << " outs(" << outputs << " : " << outputs.getTypes() << ")"; } } //===----------------------------------------------------------------------===// @@ -1041,12 +1039,12 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { } void MapOp::print(OpAsmPrinter &p) { - p.increaseIndent(); printCommonStructuredOpPartsWithNewLine( p, SmallVector(getDpsInputOperands()), SmallVector(getDpsInitOperands())); p.printOptionalAttrDict((*this)->getAttrs()); + p.increaseIndent(); p.printNewline(); p << "("; llvm::interleaveComma(getMapper().getArguments(), p, @@ -1210,19 +1208,18 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef attributeValue) { - p << attributeName << " = [" << attributeValue << "] "; + p << ' ' << attributeName << " = [" << attributeValue << "] "; } void ReduceOp::print(OpAsmPrinter &p) { - p.increaseIndent(); printCommonStructuredOpPartsWithNewLine( p, SmallVector(getDpsInputOperands()), SmallVector(getDpsInitOperands())); - p.printNewline(); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); + p.increaseIndent(); p.printNewline(); p << "("; llvm::interleaveComma(getCombiner().getArguments(), p, @@ -1379,15 +1376,11 @@ void TransposeOp::getAsmResultNames( } void TransposeOp::print(OpAsmPrinter &p) { - p.increaseIndent(); printCommonStructuredOpPartsWithNewLine( p, SmallVector(getDpsInputOperands()), SmallVector(getDpsInitOperands())); - p.printNewline(); - printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation()); p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); - p.decreaseIndent(); } LogicalResult TransposeOp::verify() { @@ -1498,15 +1491,11 @@ void BroadcastOp::getAsmResultNames( } void BroadcastOp::print(OpAsmPrinter &p) { - p.increaseIndent(); printCommonStructuredOpPartsWithNewLine( p, SmallVector(getDpsInputOperands()), SmallVector(getDpsInitOperands())); - p.printNewline(); - printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); - p.decreaseIndent(); } LogicalResult BroadcastOp::verify() { diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir index ef2d218db643..d418a92775cf 100644 --- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir @@ -340,8 +340,7 @@ func.func @op_is_reading_but_following_ops_are_not( // CHECK-SAME: %[[RHS:[0-9a-zA-Z]*]]: memref<64xf32 func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> { - // CHECK: linalg.map - // CHECK-NEXT: ins(%[[LHS]], %[[RHS]] : memref<64xf32 + // CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : memref<64xf32 %add = linalg.map ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>) outs(%init:tensor<64xf32>) @@ -358,8 +357,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, // CHECK-SAME: %[[INPUT:.*]]: memref<16x32x64xf32 func.func @reduce(%input: tensor<16x32x64xf32>, %init: tensor<16x64xf32>) -> tensor<16x64xf32> { - // CHECK: linalg.reduce - // CHECK-NEXT: ins(%[[INPUT]] : memref<16x32x64xf32 + // CHECK: linalg.reduce ins(%[[INPUT]] : memref<16x32x64xf32 %reduce = linalg.reduce ins(%input:tensor<16x32x64xf32>) outs(%init:tensor<16x64xf32>) @@ -377,8 +375,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>, // CHECK-SAME: %[[ARG0:.*]]: memref<16x32x64xf32 func.func @transpose(%input: tensor<16x32x64xf32>, %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { - // CHECK: linalg.transpose - // CHECK-NEXT: ins(%[[ARG0]] : memref<16x32x64xf32 + // CHECK: linalg.transpose ins(%[[ARG0]] : memref<16x32x64xf32 %transpose = linalg.transpose ins(%input:tensor<16x32x64xf32>) outs(%init:tensor<32x64x16xf32>) diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 8f0c83fe202e..b1a614fb768a 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -336,8 +336,7 @@ func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> { func.return %add : tensor<64xf32> } // CHECK-LABEL: func @map_no_inputs -// CHECK: linalg.map -// CHECK-NEXT: outs +// CHECK: linalg.map outs // CHECK-NEXT: () { // CHECK-NEXT: arith.constant // CHECK-NEXT: linalg.yield @@ -357,9 +356,8 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, func.return %add : tensor<64xf32> } // CHECK-LABEL: func @map_binary -// CHECK: linalg.map -// CHECK-NEXT: ins -// CHECK-NEXT: outs +// CHECK: linalg.map ins +// CHECK-SAME: outs // CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) { // CHECK-NEXT: arith.addf // CHECK-NEXT: linalg.yield @@ -426,10 +424,9 @@ func.func @reduce(%input: tensor<16x32x64xf32>, func.return %reduce : tensor<16x64xf32> } // CHECK-LABEL: func @reduce -// CHECK: linalg.reduce -// CHECK-NEXT: ins -// CHECK-NEXT: outs -// CHECK-NEXT: dimensions = [1] +// CHECK: linalg.reduce ins +// CHECK-SAME: outs +// CHECK-SAME: dimensions = [1] // CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) { // CHECK-NEXT: arith.addf // CHECK-NEXT: linalg.yield @@ -501,10 +498,9 @@ func.func @transpose(%input: tensor<16x32x64xf32>, func.return %transpose : tensor<32x64x16xf32> } // CHECK-LABEL: func @transpose -// CHECK: linalg.transpose -// CHECK-NEXT: ins -// CHECK-NEXT: outs -// CHECK-NEXT: permutation +// CHECK: linalg.transpose ins +// CHECK-SAME: outs +// CHECK-SAME: permutation // ----- @@ -529,10 +525,9 @@ func.func @broadcast_static_sizes(%input: tensor<8x32xf32>, func.return %bcast : tensor<8x16x32xf32> } // CHECK-LABEL: func @broadcast_static_sizes -// CHECK: linalg.broadcast -// CHECK-NEXT: ins -// CHECK-NEXT: outs -// CHECK-NEXT: dimensions +// CHECK: linalg.broadcast ins +// CHECK-SAME: outs +// CHECK-SAME: dimensions // ----- @@ -546,10 +541,9 @@ func.func @broadcast_with_dynamic_sizes( func.return %bcast : tensor<8x16x?xf32> } // CHECK-LABEL: func @broadcast_with_dynamic_sizes -// CHECK: linalg.broadcast -// CHECK-NEXT: ins -// CHECK-NEXT: outs -// CHECK-NEXT: dimensions +// CHECK: linalg.broadcast ins +// CHECK-SAME: outs +// CHECK-SAME: dimensions // ----- @@ -563,7 +557,6 @@ func.func @broadcast_memref(%input: memref<8x32xf32>, } // CHECK-LABEL: func @broadcast_memref -// CHECK: linalg.broadcast -// CHECK-NEXT: ins -// CHECK-NEXT: outs -// CHECK-NEXT: dimensions +// CHECK: linalg.broadcast ins +// CHECK-SAME: outs +// CHECK-SAME: dimensions