[MLIR][TOSA] Add lowering from TOSA to Linalg for math-based and elementwise ops
This patch adds lowering to Linalg for the following TOSA ops: negate, rsqrt, mul, select, clamp and reluN, and includes support for signless integer and floating-point types.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D96924
commit 25b4a6a7f0 (parent eb2eeeb76f)
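Conceptually, clamp and reluN are lowered by materializing the bounds as constants and emitting a compare-and-select pair per bound inside the generated linalg.generic body; that is what the clampHelper added in the diff below does. As an illustration only (not part of the patch; the function name is invented for this sketch), the scalar behaviour that clampHelper expresses as cmp + select pairs looks like this in C++:

    // Illustration only: the scalar behaviour clampHelper emits as cmp + select
    // pairs inside the lowered linalg.generic body. Note that, as in the helper,
    // both comparisons are against the original input x.
    template <typename T>
    T referenceClamp(T x, T min, T max) {
      T minOrArg = (x < min) ? min : x;  // select(x < min, min, x)
      return (max < x) ? max : minOrArg; // select(max < x, max, minOrArg)
    }

reluN follows the same pattern with the lower bound fixed at zero and the upper bound taken from max_fp / max_int.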
@@ -24,6 +24,28 @@ static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
 }
 
+template <typename T>
+static mlir::ConstantOp
+createConstFromIntAttribute(Operation *op, std::string attrName,
+                            Type requiredAttrType, PatternRewriter &rewriter) {
+  auto castedN = static_cast<T>(
+      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
+  return rewriter.create<mlir::ConstantOp>(
+      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
+}
+
+template <typename T, typename P>
+static mlir::SelectOp clampHelper(Operation *op, ValueRange args,
+                                  mlir::ConstantOp min, mlir::ConstantOp max,
+                                  P pred, PatternRewriter &rewriter) {
+  Location loc = op->getLoc();
+  auto smallerThanMin = rewriter.create<T>(loc, pred, args[0], min);
+  auto minOrArg =
+      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, args[0]);
+  auto largerThanMax = rewriter.create<T>(loc, pred, max, args[0]);
+  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
+}
+
 static Value
 createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                             ArrayRef<Type> resultTypes,
@@ -43,6 +65,42 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
     return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);
 
+  // tosa::SubOp
+  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);
+
+  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);
+
+  // tosa::MulOp
+  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
+    if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
+      (void)rewriter.notifyMatchFailure(op,
+                                        "Cannot have shift value for float");
+      return nullptr;
+    }
+    return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
+  }
+
+  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
+    auto mul =
+        rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], args[1]);
+    auto constant =
+        rewriter.create<mlir::ConstantOp>(loc, elementTy, op->getAttr("shift"));
+    return rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, mul,
+                                                     constant);
+  }
+
+  // tosa::NegateOp
+  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>()) {
+    auto constant =
+        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, -1));
+    return rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], constant);
+  }
+
+  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);
+
   // tosa::BitwiseAndOp
   if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
     return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
@@ -67,6 +125,10 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
 
+  // tosa::RsqrtOp
+  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
+
   // tosa::LogOp
   if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
@@ -75,13 +137,6 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
 
-  // tosa::SubOp
-  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
-    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);
-
-  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
-    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);
-
   // tosa::TanhOp
   if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
@@ -104,6 +159,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                          args[1]);
 
+  // tosa::SelectOp
+  if (isa<tosa::SelectOp>(op)) {
+    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
+    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
+      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
+  }
+
   // tosa::MaximumOp
   if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
     auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
@@ -138,6 +200,44 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
     return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);
 
+  // tosa::ClampOp
+  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
+    auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
+                                                 op->getAttr("min_fp"));
+    auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
+                                                 op->getAttr("max_fp"));
+    return clampHelper<mlir::CmpFOp>(op, args, min, max, CmpFPredicate::OLT,
+                                     rewriter);
+  }
+
+  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
+    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
+                                                    rewriter);
+    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
+                                                    rewriter);
+    return clampHelper<mlir::CmpIOp>(op, args, min, max, CmpIPredicate::slt,
+                                     rewriter);
+  }
+
+  // tosa::ReluNOp
+  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
+    auto zero =
+        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
+    auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
+                                               op->getAttr("max_fp"));
+    return clampHelper<mlir::CmpFOp>(op, args, zero, n, CmpFPredicate::OLT,
+                                     rewriter);
+  }
+
+  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
+    auto zero =
+        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
+                                                  rewriter);
+    return clampHelper<mlir::CmpIOp>(op, args, zero, n, CmpIPredicate::slt,
+                                     rewriter);
+  }
+
   (void)rewriter.notifyMatchFailure(
       op, "unhandled op for linalg body calculation for elementwise op");
   return nullptr;
@@ -245,16 +345,19 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   patterns->insert<
       PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
-      PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::LogOp>,
-      PointwiseConverter<tosa::ExpOp>, PointwiseConverter<tosa::AbsOp>,
-      PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
+      PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::NegateOp>,
+      PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::RsqrtOp>,
+      PointwiseConverter<tosa::LogOp>, PointwiseConverter<tosa::ExpOp>,
+      PointwiseConverter<tosa::AbsOp>, PointwiseConverter<tosa::TanhOp>,
+      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
-      PointwiseConverter<tosa::GreaterOp>,
+      PointwiseConverter<tosa::SelectOp>, PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
-      PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>>(
+      PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
+      PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>>(
      context);
 }
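For context, a rough sketch of how a conversion pass might drive the patterns registered above. This is an assumption for illustration, not code from this commit: the pass name is made up, and the dialect list and exact conversion API for this MLIR revision are approximations.

    // Sketch only: a hypothetical function pass that applies the TOSA-to-Linalg
    // patterns populated by populateTosaToLinalgOnTensorsConversionPatterns.
    struct TosaToLinalgSketchPass
        : public PassWrapper<TosaToLinalgSketchPass, FunctionPass> {
      void runOnFunction() override {
        OwningRewritePatternList patterns;
        ConversionTarget target(getContext());
        // The lowered IR uses the linalg, std, and math dialects.
        target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
                               math::MathDialect>();
        target.addIllegalDialect<tosa::TosaDialect>();
        mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(&getContext(),
                                                                    &patterns);
        if (failed(
                applyFullConversion(getFunction(), target, std::move(patterns))))
          signalPassFailure();
      }
    };

The in-tree TOSA-to-Linalg conversion pass plays this role; the sketch only shows the general shape of pattern population and dialect conversion.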
@@ -116,43 +116,69 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: subf
   %3 = "tosa.sub"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
+  // CHECK: linalg.generic
+  // CHECK: mulf
+  %4 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: negf
+  %5 = "tosa.negate"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
   // CHECK: linalg.generic
   // CHECK: pow
-  %4 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %6 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: rsqrt
+  %7 = "tosa.rsqrt"(%1) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: log
-  %5 = "tosa.log"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  %8 = "tosa.log"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: exp
-  %6 = "tosa.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  %9 = "tosa.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
-  %7 = "tosa.greater"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+  %10 = "tosa.greater"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
-  %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+  %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+
+  // CHECK: linalg.generic
+  // CHECK: select
+  %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
   // CHECK: select
-  %9 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %13 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
   // CHECK: select
-  %10 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %14 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: ceil
-  %11 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %15 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: floor
-  %12 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %16 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpf
+  // CHECK: select
+  %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpf
+  // CHECK: select
+  %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
 
   return
 }
@@ -169,44 +195,65 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: subi
   %1 = "tosa.sub"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK: linalg.generic
+  // CHECK: muli
+  %2 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: muli
+  %3 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+
   // CHECK: linalg.generic
   // CHECK: and
-  %2 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %4 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: or
-  %3 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %5 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: xor
-  %4 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %6 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_left
-  %5 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %7 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_right_unsigned
-  %6 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %8 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %7 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %9 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %10 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+
+  // CHECK: linalg.generic
+  // CHECK: select
+  %11 = "tosa.select"(%9, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %9 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %12 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %10 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %13 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpi
+  // CHECK: select
+  %14 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: cmpi
+  // CHECK: select
+  %15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   return
 }