mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-27 07:31:28 +00:00
[mlir][AMDGPU] Add packed 8-bit float conversion ops and lowering
Define operations that wrap the gfx940's new operations for converting between f32 and registers containing packed sets of four 8-bit floats. Define rocdl operations for the intrinsics and an AMDGPU dialect wrapper around them (to account for the fact that MLIR distinguishes the two float formats at the type level but that the LLVM IR does not). Define an ArithToAMDGPU pass, meant to run before conversion to LLVM, that replaces relevant calls to arith.extf and arith.truncf with the packed operations in the AMDGPU dialect. Note that the conversion currently only handles scalars and vectors of rank <= 1, as we do not have a usecase for multi-dimensional vector support right now. Reviewed By: jsjodin Differential Revision: https://reviews.llvm.org/D152457
This commit is contained in:
parent
0eed8ae7d2
commit
2ebd633f14
27
mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
Normal file
27
mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
Normal file
@ -0,0 +1,27 @@
|
||||
//===- ArithToAMDGPU.h - Arith to AMDGPU dialect conversion ---*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
|
||||
#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class RewritePatternSet;
|
||||
class Pass;
|
||||
|
||||
#define GEN_PASS_DECL_ARITHTOAMDGPUCONVERSIONPASS
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
namespace arith {
|
||||
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
|
||||
} // namespace arith
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
|
@ -11,6 +11,7 @@
|
||||
|
||||
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
|
||||
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
||||
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
|
||||
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
|
||||
|
@ -112,6 +112,21 @@ def ConvertAMDGPUToROCDL : Pass<"convert-amdgpu-to-rocdl"> {
|
||||
"Chipset that these operations will run on">];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArithToAMDGPU
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
|
||||
let summary = "Convert Arith operations to AMDGPU-specific implementations";
|
||||
let description = [{
|
||||
Convert `arith` operations (currently extf and truncf on 8-bit floats)
|
||||
to operations in the `amdgpu` dialect. This pass is done in two steps
|
||||
in order to avoid running a notional arith-to-rocdl and arith-to-llvm
|
||||
simultaniously.
|
||||
}];
|
||||
|
||||
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArithToLLVM
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -38,6 +38,85 @@ def AMDGPU_Dialect : Dialect {
|
||||
class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<AMDGPU_Dialect, mnemonic, traits> {}
|
||||
|
||||
def AMDGPU_ExtPackedFp8Op :
|
||||
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
|
||||
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
|
||||
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
|
||||
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
|
||||
Results<(outs F32:$res)> {
|
||||
let summary = "Extend one of a vector of packed fp8 values to a float";
|
||||
let description = [{
|
||||
Extend the value `source[index]` to a 32-bit float and return it.
|
||||
|
||||
This rather unusual signature arises from the fact that AMD GPUs cannot
|
||||
easily work with sub 32-bit quantities, so the compiler intrinsics for
|
||||
extending 8-bit floats (which are, currently, the only way to work with
|
||||
this operation) take packed vectors of 4 such floats.
|
||||
|
||||
If the passed-in vector has fewer than four elements, or the input is scalar,
|
||||
the remaining values in the <4 x i8> will be filled with with
|
||||
undefined values as needed.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $source `[` $index `]` `:` type($source) `to` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def AMDGPU_PackedTrunc2xFp8Op :
|
||||
AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
|
||||
Arguments<(ins F32:$sourceA,
|
||||
Optional<F32>:$sourceB,
|
||||
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
|
||||
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
|
||||
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
|
||||
let summary = "Round two floats into a packed vector of 8-bit floats";
|
||||
let description = [{
|
||||
Round the inputs `sourceA` and `sourceB` (which is undefined if not
|
||||
specified) into the low or high word (bottom two or top two) elements
|
||||
of the returned vector, keeping the other two elements of `existing`
|
||||
unchanged if present (or undefined if it was not passed in).
|
||||
|
||||
The reason for this odd signature is that AMD GPUs cannot easily work with
|
||||
sub-registers, and so the conversion intrinsics (which are currently the
|
||||
only way to work with 8-bit float types) take packed vectors of 4 8-bit
|
||||
values.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $sourceA `,` ($sourceB^):(`undef`)?
|
||||
`into` ($existing^):(`undef`)? `[` `word` $wordIndex `]`
|
||||
`:` type($sourceA) `to` type($res) (`into` type($existing)^)?
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def AMDGPU_PackedStochRoundFp8Op :
|
||||
AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
|
||||
Arguments<(ins F32:$source,
|
||||
I32:$stochiasticParam,
|
||||
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
|
||||
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
|
||||
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
|
||||
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
|
||||
let description = [{
|
||||
Round the input `source`, adding in `stochiasticParam`, and place it into
|
||||
the `storeIndex`th element of `res`.
|
||||
|
||||
If `existing` is passed in, elements of `res` other than the one at `storeIndex`
|
||||
are copied from `existing`.
|
||||
|
||||
The reason for this odd signature is that AMD GPUs cannot easily work with
|
||||
sub-registers, and so the conversion intrinsics (which are currently the
|
||||
only way to work with 8-bit float types) take packed vectors of 4 8-bit
|
||||
values.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $source `+` $stochiasticParam
|
||||
`into` ($existing^):(`undef`)? `[` $storeIndex `]`
|
||||
`:` type($source) `to` type($res) (`into` type($existing)^)?
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
/// Raw buffer load
|
||||
def AMDGPU_RawBufferLoadOp :
|
||||
AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,
|
||||
|
@ -116,7 +116,7 @@ class ROCDL_MbcntOp<string mnemonic> :
|
||||
def ROCDL_MbcntLoOp : ROCDL_MbcntOp<"lo">;
|
||||
def ROCDL_MbcntHiOp : ROCDL_MbcntOp<"hi">;
|
||||
|
||||
def ROCDL_DsSwizzleOp :
|
||||
def ROCDL_DsSwizzleOp :
|
||||
ROCDL_Op<"ds_swizzle">,
|
||||
Results<(outs I32:$res)>,
|
||||
Arguments<(ins I32:$src,
|
||||
@ -130,7 +130,7 @@ Arguments<(ins I32:$src,
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_DsBpermuteOp :
|
||||
def ROCDL_DsBpermuteOp :
|
||||
ROCDL_Op<"ds_bpermute">,
|
||||
Results<(outs I32:$res)>,
|
||||
Arguments<(ins I32:$index,
|
||||
@ -525,6 +525,85 @@ def ROCDL_RawBufferAtomicUMinOp :
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// 8-bit float intrinsics
|
||||
//===---------------------------------------------------------------------===//
|
||||
def ROCDL_CvtF32Bf8Op :
|
||||
ROCDL_IntrOp<"cvt.f32.bf8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$srcA, I32:$byteSel)> {
|
||||
let summary = "Convert bf8 to f32";
|
||||
let description = [{
|
||||
Convert 8-bit bf8 value from the `byteSel`th bit of `srcA` to fp32.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $srcA `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtF32Fp8Op :
|
||||
ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$srcA, I32:$byteSel)> {
|
||||
let summary = "Convert fp8 to f32";
|
||||
let description = [{
|
||||
Convert 8-bit fp8 value from the `byteSel`th bit of `srcA` to fp32.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $srcA `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtPkBf8F32Op :
|
||||
ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
|
||||
let summary = "Convert two f32's to bf8";
|
||||
let description = [{
|
||||
Convert `srcA` and `srcB` to bf8 and store into the low/high word of
|
||||
`old`, preserving the other word.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtPkFp8F32Op :
|
||||
ROCDL_IntrOp<"cvt.pk.fp8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
|
||||
let summary = "Convert two f32's to fp8";
|
||||
let description = [{
|
||||
Convert `srcA` and `srcB` to fp8 and store into the low/high word of
|
||||
`old`, preserving the other word.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtSrBf8F32Op :
|
||||
ROCDL_IntrOp<"cvt.sr.bf8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
|
||||
let summary = "Convert f32 to bf8, stochiastic rounding";
|
||||
let description = [{
|
||||
Convert `srcA` to bf8, adding the rounding factor from `srcB`,
|
||||
and store into the `byteSel`th byte of `old`, preserving the others.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtSrFp8F32Op :
|
||||
ROCDL_IntrOp<"cvt.sr.fp8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
|
||||
let summary = "Convert f32 to fp8, stochiastic rounding";
|
||||
let description = [{
|
||||
Convert `srcA` to fp8, adding the rounding factor from `srcB`,
|
||||
and store into the `byteSel`th byte of `old`, preserving the others.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ROCDL target attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -612,5 +691,4 @@ def ROCDL_TargettAttr :
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // ROCDLIR_OPS
|
||||
|
@ -10,6 +10,7 @@
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||
@ -639,6 +640,161 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct ExtPackedFp8OpLowering final
|
||||
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
|
||||
ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
|
||||
: ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
|
||||
chipset(chipset) {}
|
||||
Chipset chipset;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct PackedTrunc2xFp8OpLowering final
|
||||
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
|
||||
PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
|
||||
: ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
|
||||
chipset(chipset) {}
|
||||
Chipset chipset;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct PackedStochRoundFp8OpLowering final
|
||||
: public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
|
||||
PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
|
||||
: ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
|
||||
chipset(chipset) {}
|
||||
Chipset chipset;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(PackedStochRoundFp8Op op,
|
||||
PackedStochRoundFp8OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
|
||||
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
|
||||
return rewriter.notifyMatchFailure(
|
||||
loc, "Fp8 conversion instructions are not available on target "
|
||||
"architecture and their emulation is not implemented");
|
||||
Type v4i8 =
|
||||
getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
|
||||
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
|
||||
Type f32 = getTypeConverter()->convertType(op.getResult().getType());
|
||||
|
||||
Value source = adaptor.getSource();
|
||||
auto sourceVecType = op.getSource().getType().dyn_cast<VectorType>();
|
||||
Type sourceElemType = getElementTypeOrSelf(op.getSource());
|
||||
// Extend to a v4i8
|
||||
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
|
||||
Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
|
||||
if (!sourceVecType) {
|
||||
longVec = rewriter.create<LLVM::InsertElementOp>(
|
||||
loc, longVec, source, createI32Constant(rewriter, loc, 0));
|
||||
} else {
|
||||
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
|
||||
Value idx = createI32Constant(rewriter, loc, i);
|
||||
Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
|
||||
longVec =
|
||||
rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
|
||||
}
|
||||
}
|
||||
source = longVec;
|
||||
}
|
||||
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
|
||||
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
|
||||
if (sourceElemType.isFloat8E5M2FNUZ()) {
|
||||
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
|
||||
wordSel);
|
||||
} else if (sourceElemType.isFloat8E4M3FNUZ()) {
|
||||
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
|
||||
wordSel);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
|
||||
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
|
||||
return rewriter.notifyMatchFailure(
|
||||
loc, "Fp8 conversion instructions are not available on target "
|
||||
"architecture and their emulation is not implemented");
|
||||
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
|
||||
|
||||
Type resultType = op.getResult().getType();
|
||||
Type resultElemType = getElementTypeOrSelf(resultType);
|
||||
|
||||
Value sourceA = adaptor.getSourceA();
|
||||
Value sourceB = adaptor.getSourceB();
|
||||
if (!sourceB)
|
||||
sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
|
||||
Value existing = adaptor.getExisting();
|
||||
if (existing)
|
||||
existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
|
||||
else
|
||||
existing = rewriter.create<LLVM::UndefOp>(loc, i32);
|
||||
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
|
||||
|
||||
Value result;
|
||||
if (resultElemType.isFloat8E5M2FNUZ())
|
||||
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
|
||||
existing, wordSel);
|
||||
else if (resultElemType.isFloat8E4M3FNUZ())
|
||||
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
|
||||
existing, wordSel);
|
||||
|
||||
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
|
||||
op, getTypeConverter()->convertType(resultType), result);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
|
||||
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
|
||||
return rewriter.notifyMatchFailure(
|
||||
loc, "Fp8 conversion instructions are not available on target "
|
||||
"architecture and their emulation is not implemented");
|
||||
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
|
||||
|
||||
Type resultType = op.getResult().getType();
|
||||
Type resultElemType = getElementTypeOrSelf(resultType);
|
||||
|
||||
Value source = adaptor.getSource();
|
||||
Value stoch = adaptor.getStochiasticParam();
|
||||
Value existing = adaptor.getExisting();
|
||||
if (existing)
|
||||
existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
|
||||
else
|
||||
existing = rewriter.create<LLVM::UndefOp>(loc, i32);
|
||||
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
|
||||
|
||||
Value result;
|
||||
if (resultElemType.isFloat8E5M2FNUZ())
|
||||
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
|
||||
existing, byteSel);
|
||||
else if (resultElemType.isFloat8E4M3FNUZ())
|
||||
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
|
||||
existing, byteSel);
|
||||
|
||||
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
|
||||
op, getTypeConverter()->convertType(resultType), result);
|
||||
return success();
|
||||
}
|
||||
|
||||
struct ConvertAMDGPUToROCDLPass
|
||||
: public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
|
||||
ConvertAMDGPUToROCDLPass() = default;
|
||||
@ -691,7 +847,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
|
||||
ROCDL::RawPtrBufferAtomicUminOp>,
|
||||
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
|
||||
ROCDL::RawPtrBufferAtomicCmpSwap>,
|
||||
MFMAOpLowering, WMMAOpLowering>(converter, chipset);
|
||||
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
|
||||
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
|
||||
chipset);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
|
||||
|
210
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Normal file
210
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Normal file
@ -0,0 +1,210 @@
|
||||
//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
|
||||
|
||||
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct ArithToAMDGPUConversionPass final
|
||||
: impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
|
||||
using impl::ArithToAMDGPUConversionPassBase<
|
||||
ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
|
||||
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
struct ExtfOnFloat8RewritePattern final
|
||||
: public OpRewritePattern<arith::ExtFOp> {
|
||||
using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult match(arith::ExtFOp op) const override;
|
||||
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct TruncfToFloat8RewritePattern final
|
||||
: public OpRewritePattern<arith::TruncFOp> {
|
||||
using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult match(arith::TruncFOp op) const override;
|
||||
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
static Value castF32To(Type elementType, Value f32, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
if (elementType.isF32())
|
||||
return f32;
|
||||
if (elementType.getIntOrFloatBitWidth() < 32)
|
||||
return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
|
||||
if (elementType.getIntOrFloatBitWidth() > 32)
|
||||
return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
|
||||
llvm_unreachable("The only 32-bit float type is f32");
|
||||
}
|
||||
|
||||
LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
|
||||
Type inType = op.getIn().getType();
|
||||
if (auto inVecType = inType.dyn_cast<VectorType>()) {
|
||||
if (inVecType.isScalable())
|
||||
return failure();
|
||||
if (inVecType.getShape().size() > 1)
|
||||
// Multi-dimensional vectors are currently unsupported.
|
||||
return failure();
|
||||
inType = inVecType.getElementType();
|
||||
}
|
||||
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
|
||||
}
|
||||
|
||||
void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
Value in = op.getIn();
|
||||
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
|
||||
if (!in.getType().isa<VectorType>()) {
|
||||
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
|
||||
loc, rewriter.getF32Type(), in, 0);
|
||||
Value result = castF32To(outElemType, asFloat, loc, rewriter);
|
||||
return rewriter.replaceOp(op, result);
|
||||
}
|
||||
VectorType inType = in.getType().cast<VectorType>();
|
||||
int64_t numElements = inType.getNumElements();
|
||||
Value zero = rewriter.createOrFold<arith::ConstantOp>(
|
||||
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
|
||||
Value result =
|
||||
rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
|
||||
if (inType.getShape().empty()) {
|
||||
Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
|
||||
// Recurse to send the 0-D vector case to the 1-D vector case
|
||||
Value scalarExt =
|
||||
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
|
||||
result = rewriter.create<vector::InsertElementOp>(loc, scalarExt, zero);
|
||||
return rewriter.replaceOp(op, result);
|
||||
}
|
||||
for (int64_t i = 0; i < numElements; i += 4) {
|
||||
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
|
||||
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
|
||||
loc, in, i, elemsThisOp, 1);
|
||||
for (int64_t j = 0; j < elemsThisOp; ++j) {
|
||||
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
|
||||
loc, rewriter.getF32Type(), inSlice, j);
|
||||
Value asType = castF32To(outElemType, asFloat, loc, rewriter);
|
||||
result = rewriter.create<vector::InsertElementOp>(
|
||||
loc, asType, result,
|
||||
rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
}
|
||||
|
||||
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
|
||||
Type type = value.getType();
|
||||
if (type.isF32())
|
||||
return value;
|
||||
if (type.getIntOrFloatBitWidth() < 32)
|
||||
return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
|
||||
if (type.getIntOrFloatBitWidth() > 32)
|
||||
return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
|
||||
llvm_unreachable("The only 32-bit float type is f32");
|
||||
}
|
||||
|
||||
LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
|
||||
Type outType = op.getOut().getType();
|
||||
if (auto outVecType = outType.dyn_cast<VectorType>()) {
|
||||
if (outVecType.isScalable())
|
||||
return failure();
|
||||
if (outVecType.getShape().size() > 1)
|
||||
// Multi-dimensional vectors are currently unsupported.
|
||||
return failure();
|
||||
outType = outVecType.getElementType();
|
||||
}
|
||||
return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
|
||||
}
|
||||
|
||||
void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
Value in = op.getIn();
|
||||
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
|
||||
VectorType truncResType = VectorType::get(4, outElemType);
|
||||
if (!in.getType().isa<VectorType>()) {
|
||||
Value asFloat = castToF32(in, loc, rewriter);
|
||||
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
|
||||
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
|
||||
/*existing=*/nullptr);
|
||||
Value result = rewriter.create<vector::ExtractElementOp>(
|
||||
loc, asF8s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
|
||||
return rewriter.replaceOp(op, result);
|
||||
}
|
||||
VectorType outType = op.getOut().getType().cast<VectorType>();
|
||||
int64_t numElements = outType.getNumElements();
|
||||
Value zero = rewriter.createOrFold<arith::ConstantOp>(
|
||||
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
|
||||
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
|
||||
if (outType.getShape().empty()) {
|
||||
Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
|
||||
// Recurse to send the 0-D vector case to the 1-D vector case
|
||||
Value scalarTrunc =
|
||||
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
|
||||
result = rewriter.create<vector::InsertElementOp>(loc, scalarTrunc, zero);
|
||||
return rewriter.replaceOp(op, result);
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < numElements; i += 4) {
|
||||
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
|
||||
Value thisResult = nullptr;
|
||||
for (int64_t j = 0; j < elemsThisOp; j += 2) {
|
||||
Value elemA = rewriter.create<vector::ExtractElementOp>(
|
||||
loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i + j));
|
||||
Value asFloatA = castToF32(elemA, loc, rewriter);
|
||||
Value asFloatB = nullptr;
|
||||
if (j + 1 < elemsThisOp) {
|
||||
Value elemB = rewriter.create<vector::ExtractElementOp>(
|
||||
loc, in,
|
||||
rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j + 1));
|
||||
asFloatB = castToF32(elemB, loc, rewriter);
|
||||
}
|
||||
thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
|
||||
loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
|
||||
}
|
||||
if (elemsThisOp < 4)
|
||||
thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
|
||||
loc, thisResult, 0, elemsThisOp, 1);
|
||||
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
|
||||
result, i, 1);
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
}
|
||||
|
||||
void mlir::arith::populateArithToAMDGPUConversionPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
void ArithToAMDGPUConversionPass::runOnOperation() {
|
||||
Operation *op = getOperation();
|
||||
RewritePatternSet patterns(op->getContext());
|
||||
arith::populateArithToAMDGPUConversionPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
19
mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
Normal file
19
mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
Normal file
@ -0,0 +1,19 @@
|
||||
add_mlir_conversion_library(MLIRArithToAMDGPU
|
||||
ArithToAMDGPU.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAMDGPU
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAMDGPUDialect
|
||||
MLIRArithDialect
|
||||
MLIRVectorDialect
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
)
|
@ -1,6 +1,7 @@
|
||||
add_subdirectory(AffineToStandard)
|
||||
add_subdirectory(AMDGPUToROCDL)
|
||||
add_subdirectory(ArithCommon)
|
||||
add_subdirectory(ArithToAMDGPU)
|
||||
add_subdirectory(ArithToLLVM)
|
||||
add_subdirectory(ArithToSPIRV)
|
||||
add_subdirectory(ArmNeon2dToIntr)
|
||||
|
@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
|
||||
>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// 8-bit float ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
LogicalResult PackedTrunc2xFp8Op::verify() {
|
||||
if (getExisting() && getExisting().getType() != getResult().getType())
|
||||
return emitOpError("existing values must have same type as result");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PackedStochRoundFp8Op::verify() {
|
||||
if (getExisting() && getExisting().getType() != getResult().getType())
|
||||
return emitOpError("existing values must have same type as result");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RawBuffer*Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
108
mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
Normal file
108
mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
Normal file
@ -0,0 +1,108 @@
|
||||
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @ext_scalar
|
||||
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2FNUZ to i8
|
||||
// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
|
||||
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
|
||||
// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
|
||||
// CHECK: return [[EXT]]
|
||||
func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
|
||||
func.return %ret : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ext_short_vec
|
||||
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FNUZ> to vector<2xi8>
|
||||
// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
|
||||
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
|
||||
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
|
||||
// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
|
||||
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
|
||||
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
|
||||
// CHECK: return [[EXT]]
|
||||
func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
|
||||
func.return %ret : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ext_full_vec(
|
||||
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
|
||||
// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
|
||||
// CHECK: return [[EXT]] : f32
|
||||
|
||||
func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32
|
||||
func.return %ret : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @packed_trunc
|
||||
// CHECK-SAME: ([[V:%.+]]: f32)
|
||||
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
|
||||
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
|
||||
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
|
||||
func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> {
|
||||
%ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
|
||||
func.return %ret : vector<4xf8E4M3FNUZ>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @packed_truncx2
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
|
||||
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
|
||||
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
|
||||
func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> {
|
||||
%ret = amdgpu.packed_trunc_2xfp8 %v, %w into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
|
||||
func.return %ret : vector<4xf8E4M3FNUZ>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @packed_truncx2_into
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
|
||||
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
|
||||
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
|
||||
// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
|
||||
func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
|
||||
%ret = amdgpu.packed_trunc_2xfp8 %v, %w into %existing[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
|
||||
func.return %ret : vector<4xf8E5M2FNUZ>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @packed_stoch_round
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
|
||||
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
|
||||
// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
|
||||
func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> {
|
||||
%ret = amdgpu.packed_stoch_round_fp8 %v + %s into undef[0] : f32 to vector<4xf8E4M3FNUZ>
|
||||
func.return %ret : vector<4xf8E4M3FNUZ>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @packed_stoch_round_into
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
|
||||
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
|
||||
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
|
||||
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
|
||||
func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
|
||||
%ret = amdgpu.packed_stoch_round_fp8 %v + %s into %existing[1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
|
||||
func.return %ret : vector<4xf8E5M2FNUZ>
|
||||
}
|
122
mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
Normal file
122
mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
Normal file
@ -0,0 +1,122 @@
|
||||
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @scalar_ext
|
||||
// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
|
||||
// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32
|
||||
// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
|
||||
// CHECK: return [[W]]
|
||||
func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
|
||||
%w = arith.extf %v : f8E5M2FNUZ to f16
|
||||
return %w : f16
|
||||
}
|
||||
|
||||
// No 0-D test because arith.extf hasn't been extended to support it.
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @vector_ext_short
|
||||
// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>)
|
||||
// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
|
||||
// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
|
||||
// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
|
||||
// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32
|
||||
// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
|
||||
// CHECK: [[W0:%.+]] = vector.insertelement [[EXT0]], [[ZEROES]]{{\[}}[[C0]]
|
||||
// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32
|
||||
// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
|
||||
// CHECK: [[W1:%.+]] = vector.insertelement [[EXT1]], [[W0]]{{\[}}[[C1]]
|
||||
// CHECK: return [[W1]] : vector<2xf64>
|
||||
|
||||
func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
|
||||
%w = arith.extf %v : vector<2xf8E5M2FNUZ> to vector<2xf64>
|
||||
return %w : vector<2xf64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @vector_ext_long
|
||||
// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>)
|
||||
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
|
||||
// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
|
||||
// CHECK: [[W0:%.+]] = vector.insertelement [[F0]]
|
||||
// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
|
||||
// CHECK: [[W1:%.+]] = vector.insertelement [[F1]], [[W0]]
|
||||
// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
|
||||
// CHECK: [[W2:%.+]] = vector.insertelement [[F2]], [[W1]]
|
||||
// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
|
||||
// CHECK: [[W3:%.+]] = vector.insertelement [[F3]], [[W2]]
|
||||
|
||||
// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
|
||||
// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
|
||||
// CHECK: [[W4:%.+]] = vector.insertelement [[F4]], [[W3]]
|
||||
// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
|
||||
// CHECK: [[W5:%.+]] = vector.insertelement [[F5]], [[W4]]
|
||||
// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
|
||||
// CHECK: [[W6:%.+]] = vector.insertelement [[F6]], [[W5]]
|
||||
// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
|
||||
// CHECK: [[W7:%.+]] = vector.insertelement [[F7]], [[W6]]
|
||||
|
||||
// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
|
||||
// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
|
||||
// CHECK: [[W8:%.+]] = vector.insertelement [[F8]], [[W7]]
|
||||
// CHECK: return [[W8]]
|
||||
func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
|
||||
%w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
|
||||
return %w : vector<9xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @scalar_trunc
|
||||
// CHECK-SAME: ([[V:%.+]]: f16)
|
||||
// CHECK: [[C0:%.+]] = arith.constant 0 : index
|
||||
// CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32
|
||||
// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
|
||||
// CHECK: [[W:%.+]] = vector.extractelement [[TRUNCV]]{{\[}}[[C0]] : index] : vector<4xf8E5M2FNUZ>
|
||||
// CHECK: return [[W]] : f8E5M2FNUZ
|
||||
func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
|
||||
%w = arith.truncf %v : f16 to f8E5M2FNUZ
|
||||
return %w : f8E5M2FNUZ
|
||||
}
|
||||
|
||||
// No 0-D test because arith.truncf hasn't been extended to support it.
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @vector_trunc_short
|
||||
// CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2FNUZ> {
|
||||
// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
|
||||
// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
|
||||
// CHECK: [[V0:%.+]] = vector.extractelement [[V]]{{\[}}[[C0]] : index]
|
||||
// CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32
|
||||
// CHECK: [[V1:%.+]] = vector.extractelement [[V]]{{\[}}[[C1]] : index]
|
||||
// CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32
|
||||
// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
|
||||
// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2FNUZ> to vector<2xf8E5M2FNUZ>
|
||||
// CHECK: return [[W]] : vector<2xf8E5M2FNUZ>
|
||||
func.func @vector_trunc_short(%v: vector<2xf64>) -> vector<2xf8E5M2FNUZ> {
|
||||
%w = arith.truncf %v : vector<2xf64> to vector<2xf8E5M2FNUZ>
|
||||
return %w : vector<2xf8E5M2FNUZ>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @vector_trunc_long
|
||||
// CHECK-SAME: ([[V:%.+]]: vector<9xf32>)
|
||||
// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
|
||||
// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
|
||||
// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
|
||||
// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}
|
||||
|
||||
// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
|
||||
// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
|
||||
// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}
|
||||
|
||||
// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
|
||||
// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
|
||||
// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
|
||||
// CHECK: return [[W]]
|
||||
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> {
|
||||
%w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ>
|
||||
return %w : vector<9xf8E4M3FNUZ>
|
||||
}
|
@ -1,5 +1,19 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
|
||||
|
||||
func.func @mixing_packed_trunc_types(%arg0: f32, %arg1: vector<4xf8E5M2FNUZ>) -> vector<4xf8E4M3FNUZ> {
|
||||
// expected-error@+1 {{'amdgpu.packed_trunc_2xfp8' op existing values must have same type as result}}
|
||||
%ret = amdgpu.packed_trunc_2xfp8 %arg0, undef into %arg1[word 0] : f32 to vector<4xf8E4M3FNUZ> into vector<4xf8E5M2FNUZ>
|
||||
func.return %ret : vector<4xf8E4M3FNUZ>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @mixing_packed_stoch_round_types(%arg0: f32, %arg1: i32, %arg2: vector<4xf8E5M2FNUZ>) -> vector<4xf8E4M3FNUZ> {
|
||||
// expected-error@+1 {{'amdgpu.packed_stoch_round_fp8' op existing values must have same type as result}}
|
||||
%ret = amdgpu.packed_stoch_round_fp8 %arg0 + %arg1 into %arg2[0] : f32 to vector<4xf8E4M3FNUZ> into vector<4xf8E5M2FNUZ>
|
||||
func.return %ret : vector<4xf8E4M3FNUZ>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @bad_source_types(%a: vector<2xf32>, %b: vector<4xf16>,
|
||||
|
@ -4,6 +4,27 @@
|
||||
// Verify the generic form can be parsed.
|
||||
// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @ext_packed_fp8
|
||||
// CHECK: amdgpu.ext_packed_fp8
|
||||
func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> f32 {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32
|
||||
func.return %ret : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @packed_trunc_2xfp8
|
||||
// CHECK: amdgpu.packed_trunc_2xfp8
|
||||
func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>, %stoch: i32) -> vector<4xf8E5M2FNUZ> {
|
||||
%ret = amdgpu.packed_trunc_2xfp8 %v1, %v2 into %others[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
|
||||
func.return %ret : vector<4xf8E5M2FNUZ>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @packed_stoch_round_fp8
|
||||
// CHECK: amdgpu.packed_stoch_round_fp8
|
||||
func.func @packed_stoch_round_fp8(%v1: f32, %stoch: i32, %others: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
|
||||
%ret = amdgpu.packed_stoch_round_fp8 %v1 + %stoch into %others[2] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
|
||||
func.return %ret : vector<4xf8E5M2FNUZ>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1
|
||||
func.func @raw_buffer_load_f32_from_rank_1(%src : memref<128xf32>, %offset : i32, %idx0 : i32) -> f32 {
|
||||
// CHECK: amdgpu.raw_buffer_load {indexOffset = 1 : i32} %{{.*}}[{{.*}}] sgprOffset %{{.*}} : memref<128xf32>, i32 -> f32
|
||||
|
@ -330,6 +330,27 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
|
||||
llvm.return
|
||||
}
|
||||
|
||||
llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
|
||||
// CHECK-LABEL: @rocdl_8bit_floats
|
||||
// CHECK: rocdl.cvt.f32.bf8
|
||||
// CHECK: rocdl.cvt.f32.fp8
|
||||
// CHECK: rocdl.cvt.pk.bf8.f32
|
||||
// CHECK: rocdl.cvt.pk.fp8.f32
|
||||
// CHECK: rocdl.cvt.sr.bf8.f32
|
||||
// CHECK: rocdl.cvt.sr.fp8.f32
|
||||
%c0 = llvm.mlir.constant(0 : i32) : i32
|
||||
%c2 = llvm.mlir.constant(2 : i32) : i32
|
||||
%c3 = llvm.mlir.constant(3 : i32) : i32
|
||||
%false = llvm.mlir.constant(false) : i1
|
||||
%v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
|
||||
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
|
||||
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
|
||||
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
|
||||
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
|
||||
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
|
||||
llvm.return %source5 : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@below {{attribute attached to unexpected op}}
|
||||
|
@ -468,6 +468,27 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
|
||||
llvm.return %val : i32
|
||||
}
|
||||
|
||||
llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
|
||||
// CHECK-LABEL: @rocdl_8bit_floats
|
||||
// CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)
|
||||
// CHECK: call float @llvm.amdgcn.cvt.f32.fp8(i32 %{{.+}}, i32 0)
|
||||
// CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
|
||||
// CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
|
||||
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
|
||||
// CHECK: call i32 @llvm.amdgcn.cvt.sr.fp8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
|
||||
%c0 = llvm.mlir.constant(0 : i32) : i32
|
||||
%c2 = llvm.mlir.constant(2 : i32) : i32
|
||||
%c3 = llvm.mlir.constant(3 : i32) : i32
|
||||
%false = llvm.mlir.constant(false) : i1
|
||||
%v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
|
||||
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
|
||||
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
|
||||
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
|
||||
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
|
||||
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
|
||||
llvm.return %source5 : i32
|
||||
}
|
||||
|
||||
// CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "amdgpu-implicitarg-num-bytes"="56" }
|
||||
// CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
|
||||
// CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"
|
||||
|
Loading…
Reference in New Issue
Block a user