[mlir][vector] Implement Workaround Lowerings for Masked fm**imum Reductions

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

Within LLVM, there are no masked reduction counterparts for vector reductions such as `fmaximum` and `fminimum`.
More information can be found here: https://github.com/llvm/llvm-project/issues/64940#issuecomment-1690694156.

To address this issue in MLIR, where we need to generate appropriate lowerings for these cases, we employ regular non-masked intrinsics.
However, we modify the input vector using the `arith.select` operation to effectively deactivate undesired elements using a "neutral mask value".
The neutral mask value is the smallest possible value for the `fmaximum` reduction and the largest possible value for the `fminimum` reduction.

Depends on D158618

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D158773
This commit is contained in:
Daniil Dudkin 2023-09-13 22:18:36 +00:00 committed by Diego Caballero
parent 709b27427b
commit 8f5d519458
2 changed files with 89 additions and 4 deletions

View File

@ -15,13 +15,17 @@
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/Casting.h"
#include <optional>
@ -603,6 +607,51 @@ createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
return result;
}
/// Reduction neutral classes for overloading
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};
/// Get the mask neutral floating point maximum value
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMaximum,
const llvm::fltSemantics &floatSemantics) {
return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
}
/// Get the mask neutral floating point minimum value
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMinimum,
const llvm::fltSemantics &floatSemantics) {
return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
}
/// Create the mask neutral floating point MLIR vector constant
template <typename MaskNeutral>
static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Type vectorType) {
const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
auto denseValue =
DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
}
/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
/// `fmaximum`/`fminimum`.
/// More information: https://github.com/llvm/llvm-project/issues/64940
template <class LLVMRedIntrinOp, class MaskNeutral>
static Value lowerMaskedReductionWithRegular(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, Value mask) {
const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
rewriter, loc, llvmType, vectorOperand.getType());
const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
loc, mask, vectorOperand, vectorMaskNeutral);
return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
rewriter, loc, llvmType, selectedVectorByMask, accumulator);
}
/// Overloaded methods to lower a reduction to an llvm instrinsic that requires
/// a start value. This start value format spans across fp reductions without
/// mask and all the masked reduction intrinsics.
@ -903,10 +952,16 @@ public:
ReductionNeutralFPMin>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
default:
return rewriter.notifyMatchFailure(
maskOp,
"lowering to LLVM is not implemented for this masked operation");
case CombiningKind::MAXIMUMF:
result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
MaskNeutralFMaximum>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
case CombiningKind::MINIMUMF:
result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
MaskNeutralFMinimum>(
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
break;
}
// Replace `vector.mask` operation altogether.

View File

@ -101,6 +101,36 @@ func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>)
// -----
func.func @masked_reduce_maximumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
%0 = vector.mask %mask { vector.reduction <maximumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
return %0 : f32
}
// CHECK-LABEL: func.func @masked_reduce_maximumf_f32(
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
// CHECK: %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<-1.401300e-45> : vector<16xf32>) : vector<16xf32>
// CHECK: %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
// CHECK: %[[RESULT:.*]] = llvm.intr.vector.reduce.fmaximum(%[[MASKED]]) : (vector<16xf32>) -> f32
// CHECK: return %[[RESULT]]
// -----
func.func @masked_reduce_minimumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
%0 = vector.mask %mask { vector.reduction <minimumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
return %0 : f32
}
// CHECK-LABEL: func.func @masked_reduce_minimumf_f32(
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
// CHECK: %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<3.40282347E+38> : vector<16xf32>) : vector<16xf32>
// CHECK: %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
// CHECK: %[[RESULT:.*]] = llvm.intr.vector.reduce.fminimum(%[[MASKED]]) : (vector<16xf32>) -> f32
// CHECK: return %[[RESULT]]
// -----
func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8