[MLIR] Add canonicalization for shape.is_broadcastable

Canonicalize `is_broadcastable` to constant true if there are fewer than two
unique shape operands. Otherwise, eliminate redundant operands.
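
As a sketch of the effect (mirroring the new tests below), duplicate shape
operands are dropped, and a single unique operand folds to constant true:

  %0 = shape.is_broadcastable %a, %b, %a : !shape.shape, !shape.shape, !shape.shape
  // canonicalizes to:
  %0 = shape.is_broadcastable %a, %b : !shape.shape, !shape.shape

  %1 = shape.is_broadcastable %a, %a : !shape.shape, !shape.shape
  // canonicalizes to:
  %1 = constant true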

Differential Revision: https://reviews.llvm.org/D98361
Frederik Gossen
2021-03-11 10:09:26 +01:00
parent 2224221fb3
commit b975e3b5aa
3 changed files with 65 additions and 1 deletion

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

@@ -277,9 +277,10 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
};
}];
let hasCanonicalizer = 1;
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
let verifier = [{ return ::verify(*this); }];
}
def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
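Setting `hasCanonicalizer` here makes ODS declare the hook
`static void getCanonicalizationPatterns(OwningRewritePatternList &patterns, MLIRContext *context)`
on the op class; the C++ change below provides its definition.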

mlir/lib/Dialect/Shape/IR/Shape.cpp

@@ -779,6 +779,44 @@ static LogicalResult verify(IsBroadcastableOp op) {
return success();
}
namespace {
struct IsBroadcastableCanonicalizationPattern
: public OpRewritePattern<IsBroadcastableOp> {
using OpRewritePattern<IsBroadcastableOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IsBroadcastableOp op,
PatternRewriter &rewriter) const override {
// Find unique operands.
SmallVector<Value, 2> unique;
for (Value v : op.getOperands()) {
if (!llvm::is_contained(unique, v))
unique.push_back(v);
}
// Can always broadcast fewer than two shapes.
if (unique.size() < 2) {
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op,
rewriter.getBoolAttr(true));
return success();
}
// Reduce op to equivalent with unique operands.
if (unique.size() < op.getNumOperands()) {
rewriter.replaceOpWithNewOp<IsBroadcastableOp>(op, rewriter.getI1Type(),
unique);
return success();
}
return failure();
}
};
} // namespace
void IsBroadcastableOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<IsBroadcastableCanonicalizationPattern>(context);
}
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
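
The registered pattern is picked up by the generic canonicalization pass; as
a usage sketch (input file name assumed), the new behavior can be exercised
with:

  mlir-opt --canonicalize input.mlir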

mlir/test/Dialect/Shape/canonicalize.mlir

@@ -1069,3 +1069,28 @@ func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
%1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
return %1 : tensor<?xindex>
}
// -----
// CHECK-LABEL: @is_broadcastable_on_same_shape
func @is_broadcastable_on_same_shape(%shape : !shape.shape) -> i1 {
// CHECK-NOT: is_broadcastable
// CHECK: %[[RES:.*]] = constant true
// CHECK: return %[[RES]]
%0 = shape.is_broadcastable %shape, %shape, %shape
: !shape.shape, !shape.shape, !shape.shape
return %0 : i1
}
// -----
// CHECK-LABEL: @is_broadcastable_on_duplicate_shapes
// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
-> i1 {
// CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]]
// CHECK: return %[[RES]]
%0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape,
!shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape
return %0 : i1
}