mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-01 18:12:44 +00:00
[MLIR] Add canoncalization for shape.is_broadcastable
Canonicalize `is_broadcastable` to constant true if fewer than 2 unique shape operands. Eliminate redundant operands, otherwise. Differential Revision: https://reviews.llvm.org/D98361
This commit is contained in:
parent
2224221fb3
commit
b975e3b5aa
@ -277,9 +277,10 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
|
||||
};
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
}
|
||||
|
||||
def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
|
||||
|
@ -779,6 +779,44 @@ static LogicalResult verify(IsBroadcastableOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct IsBroadcastableCanonicalizationPattern
|
||||
: public OpRewritePattern<IsBroadcastableOp> {
|
||||
using OpRewritePattern<IsBroadcastableOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(IsBroadcastableOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Find unique operands.
|
||||
SmallVector<Value, 2> unique;
|
||||
for (Value v : op.getOperands()) {
|
||||
if (!llvm::is_contained(unique, v))
|
||||
unique.push_back(v);
|
||||
}
|
||||
|
||||
// Can always broadcast fewer than two shapes.
|
||||
if (unique.size() < 2) {
|
||||
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op,
|
||||
rewriter.getBoolAttr(true));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Reduce op to equivalent with unique operands.
|
||||
if (unique.size() < op.getNumOperands()) {
|
||||
rewriter.replaceOpWithNewOp<IsBroadcastableOp>(op, rewriter.getI1Type(),
|
||||
unique);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void IsBroadcastableOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
patterns.insert<IsBroadcastableCanonicalizationPattern>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RankOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1069,3 +1069,28 @@ func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xind
|
||||
%1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
|
||||
return %1 : tensor<?xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @is_broadcastable_on_same_shape
|
||||
func @is_broadcastable_on_same_shape(%shape : !shape.shape) -> i1 {
|
||||
// CHECK-NOT: is_broadcastable
|
||||
// CHECK: %[[RES:.*]] = constant true
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = shape.is_broadcastable %shape, %shape, %shape
|
||||
: !shape.shape, !shape.shape, !shape.shape
|
||||
return %0 : i1
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @is_broadcastable_on_duplicate_shapes
|
||||
// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
|
||||
func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
|
||||
-> i1 {
|
||||
// CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]]
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape,
|
||||
!shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape
|
||||
return %0 : i1
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user