mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-15 12:09:51 +00:00
[mlir][vector] Implement Workaround Lowerings for Masked fm**imum
Reductions
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. Within LLVM, there are no masked reduction counterparts for vector reductions such as `fmaximum` and `fminimum`. More information can be found here: https://github.com/llvm/llvm-project/issues/64940#issuecomment-1690694156. To address this issue in MLIR, where we need to generate appropriate lowerings for these cases, we employ regular non-masked intrinsics. However, we modify the input vector using the `arith.select` operation to effectively deactivate undesired elements using a "neutral mask value". The neutral mask value is the smallest possible value for the `fmaximum` reduction and the largest possible value for the `fminimum` reduction. Depends on D158618 Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D158773
This commit is contained in:
parent
709b27427b
commit
8f5d519458
@ -15,13 +15,17 @@
|
||||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include <optional>
|
||||
|
||||
@ -603,6 +607,51 @@ createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Empty tag types used only for overload resolution: passing an instance of
/// one of them picks the matching `getMaskNeutralValue` overload.
struct MaskNeutralFMaximum {};
struct MaskNeutralFMinimum {};
|
||||
|
||||
/// Get the mask neutral floating point maximum value
|
||||
static llvm::APFloat
|
||||
getMaskNeutralValue(MaskNeutralFMaximum,
|
||||
const llvm::fltSemantics &floatSemantics) {
|
||||
return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
|
||||
}
|
||||
/// Get the mask neutral floating point minimum value
|
||||
static llvm::APFloat
|
||||
getMaskNeutralValue(MaskNeutralFMinimum,
|
||||
const llvm::fltSemantics &floatSemantics) {
|
||||
return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
|
||||
}
|
||||
|
||||
/// Create the mask neutral floating point MLIR vector constant
|
||||
template <typename MaskNeutral>
|
||||
static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type llvmType,
|
||||
Type vectorType) {
|
||||
const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
|
||||
auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
|
||||
auto denseValue =
|
||||
DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
|
||||
}
|
||||
|
||||
/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
|
||||
/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
|
||||
/// `fmaximum`/`fminimum`.
|
||||
/// More information: https://github.com/llvm/llvm-project/issues/64940
|
||||
template <class LLVMRedIntrinOp, class MaskNeutral>
|
||||
static Value lowerMaskedReductionWithRegular(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
|
||||
Value vectorOperand, Value accumulator, Value mask) {
|
||||
const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
|
||||
rewriter, loc, llvmType, vectorOperand.getType());
|
||||
const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
|
||||
loc, mask, vectorOperand, vectorMaskNeutral);
|
||||
return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
|
||||
rewriter, loc, llvmType, selectedVectorByMask, accumulator);
|
||||
}
|
||||
|
||||
/// Overloaded methods to lower a reduction to an llvm intrinsic that requires
|
||||
/// a start value. This start value format spans across fp reductions without
|
||||
/// mask and all the masked reduction intrinsics.
|
||||
@ -903,10 +952,16 @@ public:
|
||||
ReductionNeutralFPMin>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
default:
|
||||
return rewriter.notifyMatchFailure(
|
||||
maskOp,
|
||||
"lowering to LLVM is not implemented for this masked operation");
|
||||
case CombiningKind::MAXIMUMF:
|
||||
result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
|
||||
MaskNeutralFMaximum>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
case CombiningKind::MINIMUMF:
|
||||
result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
|
||||
MaskNeutralFMinimum>(
|
||||
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
||||
break;
|
||||
}
|
||||
|
||||
// Replace `vector.mask` operation altogether.
|
||||
|
@ -101,6 +101,36 @@ func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>)
|
||||
|
||||
// -----
|
||||
|
||||
// Masked `maximumf` reduction of a 16-element f32 vector down to a scalar.
func.func @masked_reduce_maximumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
  %reduced = vector.mask %mask { vector.reduction <maximumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
  return %reduced : f32
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_maximumf_f32(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
|
||||
// CHECK: %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<-1.401300e-45> : vector<16xf32>) : vector<16xf32>
|
||||
// CHECK: %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
|
||||
// CHECK: %[[RESULT:.*]] = llvm.intr.vector.reduce.fmaximum(%[[MASKED]]) : (vector<16xf32>) -> f32
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
// Masked `minimumf` reduction of a 16-element f32 vector down to a scalar.
func.func @masked_reduce_minimumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
  %reduced = vector.mask %mask { vector.reduction <minimumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
  return %reduced : f32
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_reduce_minimumf_f32(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
|
||||
// CHECK: %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<3.40282347E+38> : vector<16xf32>) : vector<16xf32>
|
||||
// CHECK: %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
|
||||
// CHECK: %[[RESULT:.*]] = llvm.intr.vector.reduce.fminimum(%[[MASKED]]) : (vector<16xf32>) -> f32
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
|
||||
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
|
||||
return %0 : i8
|
||||
|
Loading…
Reference in New Issue
Block a user