From 2ebd633f145615a42d7e8b1d07cbdad294c244aa Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Fri, 12 May 2023 15:40:29 +0000 Subject: [PATCH] [mlir][AMDGPU] Add packed 8-bit float conversion ops and lowering Define operations that wrap the gfx940's new operations for converting between f32 and registers containing packed sets of four 8-bit floats. Define rocdl operations for the intrinsics and an AMDGPU dialect wrapper around them (to account for the fact that MLIR distinguishes the two float formats at the type level but that the LLVM IR does not). Define an ArithToAMDGPU pass, meant to run before conversion to LLVM, that replaces relevant calls to arith.extf and arith.truncf with the packed operations in the AMDGPU dialect. Note that the conversion currently only handles scalars and vectors of rank <= 1, as we do not have a usecase for multi-dimensional vector support right now. Reviewed By: jsjodin Differential Revision: https://reviews.llvm.org/D152457 --- .../Conversion/ArithToAMDGPU/ArithToAMDGPU.h | 27 +++ mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 15 ++ mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 79 +++++++ mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 84 ++++++- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 160 ++++++++++++- .../ArithToAMDGPU/ArithToAMDGPU.cpp | 210 ++++++++++++++++++ .../Conversion/ArithToAMDGPU/CMakeLists.txt | 19 ++ mlir/lib/Conversion/CMakeLists.txt | 1 + mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 15 ++ .../AMDGPUToROCDL/8-bit-floats.mlir | 108 +++++++++ .../ArithToAMDGPU/8-bit-floats.mlir | 122 ++++++++++ mlir/test/Dialect/AMDGPU/invalid.mlir | 14 ++ mlir/test/Dialect/AMDGPU/ops.mlir | 21 ++ mlir/test/Dialect/LLVMIR/rocdl.mlir | 21 ++ mlir/test/Target/LLVMIR/rocdl.mlir | 21 ++ 16 files changed, 914 insertions(+), 4 deletions(-) create mode 100644 mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h create mode 100644 mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp create mode 100644 mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir create mode 100644 mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h new file mode 100644 index 000000000000..7f445fee5ba6 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h @@ -0,0 +1,27 @@ +//===- ArithToAMDGPU.h - Arith to AMDGPU dialect conversion ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H +#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H + +#include + +namespace mlir { + +class RewritePatternSet; +class Pass; + +#define GEN_PASS_DECL_ARITHTOAMDGPUCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" + +namespace arith { +void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns); +} // namespace arith +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 41806004fc1d..e714f5070f23 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -11,6 +11,7 @@ #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 9b7848d9288b..38b05c792d40 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -112,6 +112,21 @@ def ConvertAMDGPUToROCDL : Pass<"convert-amdgpu-to-rocdl"> { "Chipset that these operations will run on">]; } +//===----------------------------------------------------------------------===// +// ArithToAMDGPU +//===----------------------------------------------------------------------===// +def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> { + let summary = "Convert Arith operations to AMDGPU-specific implementations"; + let description = [{ + Convert `arith` operations (currently extf and truncf on 8-bit floats) + to operations in the `amdgpu` dialect. This pass is done in two steps + in order to avoid running a notional arith-to-rocdl and arith-to-llvm + simultaniously. + }]; + + let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"]; +} + //===----------------------------------------------------------------------===// // ArithToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 6d788e3a9701..ffb302fcedd7 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -38,6 +38,85 @@ def AMDGPU_Dialect : Dialect { class AMDGPU_Op traits = []> : Op {} +def AMDGPU_ExtPackedFp8Op : + AMDGPU_Op<"ext_packed_fp8", [Pure]>, + Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, + VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source, + ConfinedAttr]>:$index)>, + Results<(outs F32:$res)> { + let summary = "Extend one of a vector of packed fp8 values to a float"; + let description = [{ + Extend the value `source[index]` to a 32-bit float and return it. + + This rather unusual signature arises from the fact that AMD GPUs cannot + easily work with sub 32-bit quantities, so the compiler intrinsics for + extending 8-bit floats (which are, currently, the only way to work with + this operation) take packed vectors of 4 such floats. + + If the passed-in vector has fewer than four elements, or the input is scalar, + the remaining values in the <4 x i8> will be filled with with + undefined values as needed. + }]; + let assemblyFormat = [{ + attr-dict $source `[` $index `]` `:` type($source) `to` type($res) + }]; +} + +def AMDGPU_PackedTrunc2xFp8Op : + AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>, + Arguments<(ins F32:$sourceA, + Optional:$sourceB, + ConfinedAttr]>:$wordIndex, + Optional>:$existing)>, + Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> { + let summary = "Round two floats into a packed vector of 8-bit floats"; + let description = [{ + Round the inputs `sourceA` and `sourceB` (which is undefined if not + specified) into the low or high word (bottom two or top two) elements + of the returned vector, keeping the other two elements of `existing` + unchanged if present (or undefined if it was not passed in). + + The reason for this odd signature is that AMD GPUs cannot easily work with + sub-registers, and so the conversion intrinsics (which are currently the + only way to work with 8-bit float types) take packed vectors of 4 8-bit + values. + }]; + let assemblyFormat = [{ + attr-dict $sourceA `,` ($sourceB^):(`undef`)? + `into` ($existing^):(`undef`)? `[` `word` $wordIndex `]` + `:` type($sourceA) `to` type($res) (`into` type($existing)^)? + }]; + let hasVerifier = 1; +} + +def AMDGPU_PackedStochRoundFp8Op : + AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>, + Arguments<(ins F32:$source, + I32:$stochiasticParam, + ConfinedAttr]>:$storeIndex, + Optional>:$existing)>, + Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> { + let summary = "Round float stochiastically into a packed vector of 8-bit floats"; + let description = [{ + Round the input `source`, adding in `stochiasticParam`, and place it into + the `storeIndex`th element of `res`. + + If `existing` is passed in, elements of `res` other than the one at `storeIndex` + are copied from `existing`. + + The reason for this odd signature is that AMD GPUs cannot easily work with + sub-registers, and so the conversion intrinsics (which are currently the + only way to work with 8-bit float types) take packed vectors of 4 8-bit + values. + }]; + let assemblyFormat = [{ + attr-dict $source `+` $stochiasticParam + `into` ($existing^):(`undef`)? `[` $storeIndex `]` + `:` type($source) `to` type($res) (`into` type($existing)^)? + }]; + let hasVerifier = 1; +} + /// Raw buffer load def AMDGPU_RawBufferLoadOp : AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>, diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 08d36397dc31..6c6419bf238b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -116,7 +116,7 @@ class ROCDL_MbcntOp : def ROCDL_MbcntLoOp : ROCDL_MbcntOp<"lo">; def ROCDL_MbcntHiOp : ROCDL_MbcntOp<"hi">; -def ROCDL_DsSwizzleOp : +def ROCDL_DsSwizzleOp : ROCDL_Op<"ds_swizzle">, Results<(outs I32:$res)>, Arguments<(ins I32:$src, @@ -130,7 +130,7 @@ Arguments<(ins I32:$src, }]; } -def ROCDL_DsBpermuteOp : +def ROCDL_DsBpermuteOp : ROCDL_Op<"ds_bpermute">, Results<(outs I32:$res)>, Arguments<(ins I32:$index, @@ -525,6 +525,85 @@ def ROCDL_RawBufferAtomicUMinOp : let hasCustomAssemblyFormat = 1; } +//===---------------------------------------------------------------------===// +// 8-bit float intrinsics +//===---------------------------------------------------------------------===// +def ROCDL_CvtF32Bf8Op : + ROCDL_IntrOp<"cvt.f32.bf8", [], [], [Pure], 1>, + Arguments<(ins I32:$srcA, I32:$byteSel)> { + let summary = "Convert bf8 to f32"; + let description = [{ + Convert 8-bit bf8 value from the `byteSel`th bit of `srcA` to fp32. + }]; + let assemblyFormat = [{ + attr-dict $srcA `[` $byteSel `]` `:` type($res) + }]; +} + +def ROCDL_CvtF32Fp8Op : + ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>, + Arguments<(ins I32:$srcA, I32:$byteSel)> { + let summary = "Convert fp8 to f32"; + let description = [{ + Convert 8-bit fp8 value from the `byteSel`th bit of `srcA` to fp32. + }]; + let assemblyFormat = [{ + attr-dict $srcA `[` $byteSel `]` `:` type($res) + }]; +} + +def ROCDL_CvtPkBf8F32Op : + ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>, + Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> { + let summary = "Convert two f32's to bf8"; + let description = [{ + Convert `srcA` and `srcB` to bf8 and store into the low/high word of + `old`, preserving the other word. + }]; + let assemblyFormat = [{ + attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res) + }]; +} + +def ROCDL_CvtPkFp8F32Op : + ROCDL_IntrOp<"cvt.pk.fp8.f32", [], [], [Pure], 1>, + Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> { + let summary = "Convert two f32's to fp8"; + let description = [{ + Convert `srcA` and `srcB` to fp8 and store into the low/high word of + `old`, preserving the other word. + }]; + let assemblyFormat = [{ + attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res) + }]; +} + +def ROCDL_CvtSrBf8F32Op : + ROCDL_IntrOp<"cvt.sr.bf8.f32", [], [], [Pure], 1>, + Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> { + let summary = "Convert f32 to bf8, stochiastic rounding"; + let description = [{ + Convert `srcA` to bf8, adding the rounding factor from `srcB`, + and store into the `byteSel`th byte of `old`, preserving the others. + }]; + let assemblyFormat = [{ + attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res) + }]; +} + +def ROCDL_CvtSrFp8F32Op : + ROCDL_IntrOp<"cvt.sr.fp8.f32", [], [], [Pure], 1>, + Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> { + let summary = "Convert f32 to fp8, stochiastic rounding"; + let description = [{ + Convert `srcA` to fp8, adding the rounding factor from `srcB`, + and store into the `byteSel`th byte of `old`, preserving the others. + }]; + let assemblyFormat = [{ + attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res) + }]; +} + //===----------------------------------------------------------------------===// // ROCDL target attribute. //===----------------------------------------------------------------------===// @@ -612,5 +691,4 @@ def ROCDL_TargettAttr : } }]; } - #endif // ROCDLIR_OPS diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index eeed04049668..9ed312cef744 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -639,6 +640,161 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { } }; +namespace { +struct ExtPackedFp8OpLowering final + : public ConvertOpToLLVMPattern { + ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern(converter), + chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +struct PackedTrunc2xFp8OpLowering final + : public ConvertOpToLLVMPattern { + PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern(converter), + chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +struct PackedStochRoundFp8OpLowering final + : public ConvertOpToLLVMPattern { + PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern(converter), + chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(PackedStochRoundFp8Op op, + PackedStochRoundFp8OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // end namespace + +LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( + ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) + return rewriter.notifyMatchFailure( + loc, "Fp8 conversion instructions are not available on target " + "architecture and their emulation is not implemented"); + Type v4i8 = + getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type())); + Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); + Type f32 = getTypeConverter()->convertType(op.getResult().getType()); + + Value source = adaptor.getSource(); + auto sourceVecType = op.getSource().getType().dyn_cast(); + Type sourceElemType = getElementTypeOrSelf(op.getSource()); + // Extend to a v4i8 + if (!sourceVecType || sourceVecType.getNumElements() < 4) { + Value longVec = rewriter.create(loc, v4i8); + if (!sourceVecType) { + longVec = rewriter.create( + loc, longVec, source, createI32Constant(rewriter, loc, 0)); + } else { + for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { + Value idx = createI32Constant(rewriter, loc, i); + Value elem = rewriter.create(loc, source, idx); + longVec = + rewriter.create(loc, longVec, elem, idx); + } + } + source = longVec; + } + Value i32Source = rewriter.create(loc, i32, source); + Value wordSel = createI32Constant(rewriter, loc, op.getIndex()); + if (sourceElemType.isFloat8E5M2FNUZ()) { + rewriter.replaceOpWithNewOp(op, f32, i32Source, + wordSel); + } else if (sourceElemType.isFloat8E4M3FNUZ()) { + rewriter.replaceOpWithNewOp(op, f32, i32Source, + wordSel); + } + return success(); +} + +LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( + PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) + return rewriter.notifyMatchFailure( + loc, "Fp8 conversion instructions are not available on target " + "architecture and their emulation is not implemented"); + Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); + + Type resultType = op.getResult().getType(); + Type resultElemType = getElementTypeOrSelf(resultType); + + Value sourceA = adaptor.getSourceA(); + Value sourceB = adaptor.getSourceB(); + if (!sourceB) + sourceB = rewriter.create(loc, sourceA.getType()); + Value existing = adaptor.getExisting(); + if (existing) + existing = rewriter.create(loc, i32, existing); + else + existing = rewriter.create(loc, i32); + Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex()); + + Value result; + if (resultElemType.isFloat8E5M2FNUZ()) + result = rewriter.create(loc, i32, sourceA, sourceB, + existing, wordSel); + else if (resultElemType.isFloat8E4M3FNUZ()) + result = rewriter.create(loc, i32, sourceA, sourceB, + existing, wordSel); + + result = rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(resultType), result); + return success(); +} + +LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( + PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) + return rewriter.notifyMatchFailure( + loc, "Fp8 conversion instructions are not available on target " + "architecture and their emulation is not implemented"); + Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); + + Type resultType = op.getResult().getType(); + Type resultElemType = getElementTypeOrSelf(resultType); + + Value source = adaptor.getSource(); + Value stoch = adaptor.getStochiasticParam(); + Value existing = adaptor.getExisting(); + if (existing) + existing = rewriter.create(loc, i32, existing); + else + existing = rewriter.create(loc, i32); + Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex()); + + Value result; + if (resultElemType.isFloat8E5M2FNUZ()) + result = rewriter.create(loc, i32, source, stoch, + existing, byteSel); + else if (resultElemType.isFloat8E4M3FNUZ()) + result = rewriter.create(loc, i32, source, stoch, + existing, byteSel); + + result = rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(resultType), result); + return success(); +} + struct ConvertAMDGPUToROCDLPass : public impl::ConvertAMDGPUToROCDLBase { ConvertAMDGPUToROCDLPass() = default; @@ -691,7 +847,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, ROCDL::RawPtrBufferAtomicUminOp>, RawBufferOpLowering, - MFMAOpLowering, WMMAOpLowering>(converter, chipset); + MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering, + PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter, + chipset); } std::unique_ptr mlir::createConvertAMDGPUToROCDLPass() { diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp new file mode 100644 index 000000000000..7785405eae67 --- /dev/null +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -0,0 +1,210 @@ +//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct ArithToAMDGPUConversionPass final + : impl::ArithToAMDGPUConversionPassBase { + using impl::ArithToAMDGPUConversionPassBase< + ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase; + + void runOnOperation() override; +}; + +struct ExtfOnFloat8RewritePattern final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(arith::ExtFOp op) const override; + void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override; +}; + +struct TruncfToFloat8RewritePattern final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(arith::TruncFOp op) const override; + void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override; +}; +} // end namespace + +static Value castF32To(Type elementType, Value f32, Location loc, + PatternRewriter &rewriter) { + if (elementType.isF32()) + return f32; + if (elementType.getIntOrFloatBitWidth() < 32) + return rewriter.create(loc, elementType, f32); + if (elementType.getIntOrFloatBitWidth() > 32) + return rewriter.create(loc, elementType, f32); + llvm_unreachable("The only 32-bit float type is f32"); +} + +LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const { + Type inType = op.getIn().getType(); + if (auto inVecType = inType.dyn_cast()) { + if (inVecType.isScalable()) + return failure(); + if (inVecType.getShape().size() > 1) + // Multi-dimensional vectors are currently unsupported. + return failure(); + inType = inVecType.getElementType(); + } + return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ()); +} + +void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + Value in = op.getIn(); + Type outElemType = getElementTypeOrSelf(op.getOut().getType()); + if (!in.getType().isa()) { + Value asFloat = rewriter.create( + loc, rewriter.getF32Type(), in, 0); + Value result = castF32To(outElemType, asFloat, loc, rewriter); + return rewriter.replaceOp(op, result); + } + VectorType inType = in.getType().cast(); + int64_t numElements = inType.getNumElements(); + Value zero = rewriter.createOrFold( + loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); + Value result = + rewriter.createOrFold(loc, op.getOut().getType(), zero); + if (inType.getShape().empty()) { + Value scalarIn = rewriter.create(loc, in); + // Recurse to send the 0-D vector case to the 1-D vector case + Value scalarExt = + rewriter.create(loc, outElemType, scalarIn); + result = rewriter.create(loc, scalarExt, zero); + return rewriter.replaceOp(op, result); + } + for (int64_t i = 0; i < numElements; i += 4) { + int64_t elemsThisOp = std::min(numElements, i + 4) - i; + Value inSlice = rewriter.create( + loc, in, i, elemsThisOp, 1); + for (int64_t j = 0; j < elemsThisOp; ++j) { + Value asFloat = rewriter.create( + loc, rewriter.getF32Type(), inSlice, j); + Value asType = castF32To(outElemType, asFloat, loc, rewriter); + result = rewriter.create( + loc, asType, result, + rewriter.createOrFold(loc, i + j)); + } + } + rewriter.replaceOp(op, result); +} + +static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { + Type type = value.getType(); + if (type.isF32()) + return value; + if (type.getIntOrFloatBitWidth() < 32) + return rewriter.create(loc, rewriter.getF32Type(), value); + if (type.getIntOrFloatBitWidth() > 32) + return rewriter.create(loc, rewriter.getF32Type(), value); + llvm_unreachable("The only 32-bit float type is f32"); +} + +LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const { + Type outType = op.getOut().getType(); + if (auto outVecType = outType.dyn_cast()) { + if (outVecType.isScalable()) + return failure(); + if (outVecType.getShape().size() > 1) + // Multi-dimensional vectors are currently unsupported. + return failure(); + outType = outVecType.getElementType(); + } + return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()); +} + +void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + Value in = op.getIn(); + Type outElemType = getElementTypeOrSelf(op.getOut().getType()); + VectorType truncResType = VectorType::get(4, outElemType); + if (!in.getType().isa()) { + Value asFloat = castToF32(in, loc, rewriter); + Value asF8s = rewriter.create( + loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, + /*existing=*/nullptr); + Value result = rewriter.create( + loc, asF8s, rewriter.createOrFold(loc, 0)); + return rewriter.replaceOp(op, result); + } + VectorType outType = op.getOut().getType().cast(); + int64_t numElements = outType.getNumElements(); + Value zero = rewriter.createOrFold( + loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); + Value result = rewriter.createOrFold(loc, outType, zero); + if (outType.getShape().empty()) { + Value scalarIn = rewriter.create(loc, in); + // Recurse to send the 0-D vector case to the 1-D vector case + Value scalarTrunc = + rewriter.create(loc, outElemType, scalarIn); + result = rewriter.create(loc, scalarTrunc, zero); + return rewriter.replaceOp(op, result); + } + + for (int64_t i = 0; i < numElements; i += 4) { + int64_t elemsThisOp = std::min(numElements, i + 4) - i; + Value thisResult = nullptr; + for (int64_t j = 0; j < elemsThisOp; j += 2) { + Value elemA = rewriter.create( + loc, in, rewriter.create(loc, i + j)); + Value asFloatA = castToF32(elemA, loc, rewriter); + Value asFloatB = nullptr; + if (j + 1 < elemsThisOp) { + Value elemB = rewriter.create( + loc, in, + rewriter.createOrFold(loc, i + j + 1)); + asFloatB = castToF32(elemB, loc, rewriter); + } + thisResult = rewriter.create( + loc, truncResType, asFloatA, asFloatB, j / 2, thisResult); + } + if (elemsThisOp < 4) + thisResult = rewriter.create( + loc, thisResult, 0, elemsThisOp, 1); + result = rewriter.create(loc, thisResult, + result, i, 1); + } + rewriter.replaceOp(op, result); +} + +void mlir::arith::populateArithToAMDGPUConversionPatterns( + RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + +void ArithToAMDGPUConversionPass::runOnOperation() { + Operation *op = getOperation(); + RewritePatternSet patterns(op->getContext()); + arith::populateArithToAMDGPUConversionPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + return signalPassFailure(); +} diff --git a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt new file mode 100644 index 000000000000..359015b6f86a --- /dev/null +++ b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRArithToAMDGPU + ArithToAMDGPU.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAMDGPU + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRAMDGPUDialect + MLIRArithDialect + MLIRVectorDialect + MLIRPass + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 660e48768c4f..35790254be13 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(AffineToStandard) add_subdirectory(AMDGPUToROCDL) add_subdirectory(ArithCommon) +add_subdirectory(ArithToAMDGPU) add_subdirectory(ArithToLLVM) add_subdirectory(ArithToSPIRV) add_subdirectory(ArmNeon2dToIntr) diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index ac34acc83074..2575ad498481 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() { >(); } +//===----------------------------------------------------------------------===// +// 8-bit float ops +//===----------------------------------------------------------------------===// +LogicalResult PackedTrunc2xFp8Op::verify() { + if (getExisting() && getExisting().getType() != getResult().getType()) + return emitOpError("existing values must have same type as result"); + return success(); +} + +LogicalResult PackedStochRoundFp8Op::verify() { + if (getExisting() && getExisting().getType() != getResult().getType()) + return emitOpError("existing values must have same type as result"); + return success(); +} + //===----------------------------------------------------------------------===// // RawBuffer*Op //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir new file mode 100644 index 000000000000..7818a525d17b --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir @@ -0,0 +1,108 @@ +// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 | FileCheck %s + +// CHECK-LABEL: func @ext_scalar +// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2FNUZ to i8 +// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8> +// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8> +// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32 +// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32 +// CHECK: return [[EXT]] +func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 { + %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32 + func.return %ret : f32 +} + +// CHECK-LABEL: func @ext_short_vec +// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FNUZ> to vector<2xi8> +// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8> +// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8> +// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8> +// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> +// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> +// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 +// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32 +// CHECK: return [[EXT]] +func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 { + %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32 + func.return %ret : f32 +} + +// CHECK-LABEL: func @ext_full_vec( +// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8> +// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 +// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32 +// CHECK: return [[EXT]] : f32 + +func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 { + %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32 + func.return %ret : f32 +} + +// CHECK-LABEL: func @packed_trunc +// CHECK-SAME: ([[V:%.+]]: f32) +// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32 +// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32 +// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32 +// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> +// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ> +func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> { + %ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to vector<4xf8E4M3FNUZ> + func.return %ret : vector<4xf8E4M3FNUZ> +} + +// CHECK-LABEL: func @packed_truncx2 +// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32) +// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32 +// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32 +// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> +// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ> +func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> { + %ret = amdgpu.packed_trunc_2xfp8 %v, %w into undef[word 0] : f32 to vector<4xf8E4M3FNUZ> + func.return %ret : vector<4xf8E4M3FNUZ> +} + +// CHECK-LABEL: func @packed_truncx2_into +// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>) +// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8> +// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32 +// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32 +// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> +// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ> +func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> { + %ret = amdgpu.packed_trunc_2xfp8 %v, %w into %existing[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ> + func.return %ret : vector<4xf8E5M2FNUZ> +} + +// CHECK-LABEL: func @packed_stoch_round +// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32) +// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32 +// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32 +// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> +// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ> +func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> { + %ret = amdgpu.packed_stoch_round_fp8 %v + %s into undef[0] : f32 to vector<4xf8E4M3FNUZ> + func.return %ret : vector<4xf8E4M3FNUZ> +} + +// CHECK-LABEL: func @packed_stoch_round_into +// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>) +// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8> +// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32 +// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32 +// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> +// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ> +func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> { + %ret = amdgpu.packed_stoch_round_fp8 %v + %s into %existing[1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ> + func.return %ret : vector<4xf8E5M2FNUZ> +} diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir new file mode 100644 index 000000000000..a6c11d022e2c --- /dev/null +++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir @@ -0,0 +1,122 @@ +// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s + +// CHECK-LABEL: func.func @scalar_ext +// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ) +// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32 +// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16 +// CHECK: return [[W]] +func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 { + %w = arith.extf %v : f8E5M2FNUZ to f16 + return %w : f16 +} + +// No 0-D test because arith.extf hasn't been extended to support it. + +// ----- + +// CHECK-LABEL: func.func @vector_ext_short +// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>) +// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64> +// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index +// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32 +// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64 +// CHECK: [[W0:%.+]] = vector.insertelement [[EXT0]], [[ZEROES]]{{\[}}[[C0]] +// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32 +// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]] +// CHECK: [[W1:%.+]] = vector.insertelement [[EXT1]], [[W0]]{{\[}}[[C1]] +// CHECK: return [[W1]] : vector<2xf64> + +func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> { + %w = arith.extf %v : vector<2xf8E5M2FNUZ> to vector<2xf64> + return %w : vector<2xf64> +} + +// ----- + +// CHECK-LABEL: func.func @vector_ext_long +// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>) +// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} +// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] +// CHECK: [[W0:%.+]] = vector.insertelement [[F0]] +// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] +// CHECK: [[W1:%.+]] = vector.insertelement [[F1]], [[W0]] +// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2] +// CHECK: [[W2:%.+]] = vector.insertelement [[F2]], [[W1]] +// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3] +// CHECK: [[W3:%.+]] = vector.insertelement [[F3]], [[W2]] + +// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ> +// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] +// CHECK: [[W4:%.+]] = vector.insertelement [[F4]], [[W3]] +// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] +// CHECK: [[W5:%.+]] = vector.insertelement [[F5]], [[W4]] +// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2] +// CHECK: [[W6:%.+]] = vector.insertelement [[F6]], [[W5]] +// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3] +// CHECK: [[W7:%.+]] = vector.insertelement [[F7]], [[W6]] + +// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ> +// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] +// CHECK: [[W8:%.+]] = vector.insertelement [[F8]], [[W7]] +// CHECK: return [[W8]] +func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> { + %w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32> + return %w : vector<9xf32> +} + +// ----- + +// CHECK-LABEL: func.func @scalar_trunc +// CHECK-SAME: ([[V:%.+]]: f16) +// CHECK: [[C0:%.+]] = arith.constant 0 : index +// CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32 +// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ> +// CHECK: [[W:%.+]] = vector.extractelement [[TRUNCV]]{{\[}}[[C0]] : index] : vector<4xf8E5M2FNUZ> +// CHECK: return [[W]] : f8E5M2FNUZ +func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ { + %w = arith.truncf %v : f16 to f8E5M2FNUZ + return %w : f8E5M2FNUZ +} + +// No 0-D test because arith.truncf hasn't been extended to support it. + +// ----- + +// CHECK-LABEL: func.func @vector_trunc_short +// CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2FNUZ> { +// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index +// CHECK: [[V0:%.+]] = vector.extractelement [[V]]{{\[}}[[C0]] : index] +// CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32 +// CHECK: [[V1:%.+]] = vector.extractelement [[V]]{{\[}}[[C1]] : index] +// CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32 +// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2FNUZ> +// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2FNUZ> to vector<2xf8E5M2FNUZ> +// CHECK: return [[W]] : vector<2xf8E5M2FNUZ> +func.func @vector_trunc_short(%v: vector<2xf64>) -> vector<2xf8E5M2FNUZ> { + %w = arith.truncf %v : vector<2xf64> to vector<2xf8E5M2FNUZ> + return %w : vector<2xf8E5M2FNUZ> +} + +// ----- + +// CHECK-LABEL: func.func @vector_trunc_long +// CHECK-SAME: ([[V:%.+]]: vector<9xf32>) +// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ> +// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0] +// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1] +// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]} + +// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0] +// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1] +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]} + +// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0] +// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]} +// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]} +// CHECK: return [[W]] +func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> { + %w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ> + return %w : vector<9xf8E4M3FNUZ> +} diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 142224e59a95..5e1ab79962d2 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -1,5 +1,19 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics +func.func @mixing_packed_trunc_types(%arg0: f32, %arg1: vector<4xf8E5M2FNUZ>) -> vector<4xf8E4M3FNUZ> { + // expected-error@+1 {{'amdgpu.packed_trunc_2xfp8' op existing values must have same type as result}} + %ret = amdgpu.packed_trunc_2xfp8 %arg0, undef into %arg1[word 0] : f32 to vector<4xf8E4M3FNUZ> into vector<4xf8E5M2FNUZ> + func.return %ret : vector<4xf8E4M3FNUZ> +} + +// ----- + +func.func @mixing_packed_stoch_round_types(%arg0: f32, %arg1: i32, %arg2: vector<4xf8E5M2FNUZ>) -> vector<4xf8E4M3FNUZ> { + // expected-error@+1 {{'amdgpu.packed_stoch_round_fp8' op existing values must have same type as result}} + %ret = amdgpu.packed_stoch_round_fp8 %arg0 + %arg1 into %arg2[0] : f32 to vector<4xf8E4M3FNUZ> into vector<4xf8E5M2FNUZ> + func.return %ret : vector<4xf8E4M3FNUZ> +} + // ----- func.func @bad_source_types(%a: vector<2xf32>, %b: vector<4xf16>, diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index 4088c6750c91..744a096d757e 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -4,6 +4,27 @@ // Verify the generic form can be parsed. // RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s +// CHECK-LABEL: func @ext_packed_fp8 +// CHECK: amdgpu.ext_packed_fp8 +func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> f32 { + %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32 + func.return %ret : f32 +} + +// CHECK-LABEL: func @packed_trunc_2xfp8 +// CHECK: amdgpu.packed_trunc_2xfp8 +func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>, %stoch: i32) -> vector<4xf8E5M2FNUZ> { + %ret = amdgpu.packed_trunc_2xfp8 %v1, %v2 into %others[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ> + func.return %ret : vector<4xf8E5M2FNUZ> +} + +// CHECK-LABEL: func @packed_stoch_round_fp8 +// CHECK: amdgpu.packed_stoch_round_fp8 +func.func @packed_stoch_round_fp8(%v1: f32, %stoch: i32, %others: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> { + %ret = amdgpu.packed_stoch_round_fp8 %v1 + %stoch into %others[2] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ> + func.return %ret : vector<4xf8E5M2FNUZ> +} + // CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1 func.func @raw_buffer_load_f32_from_rank_1(%src : memref<128xf32>, %offset : i32, %idx0 : i32) -> f32 { // CHECK: amdgpu.raw_buffer_load {indexOffset = 1 : i32} %{{.*}}[{{.*}}] sgprOffset %{{.*}} : memref<128xf32>, i32 -> f32 diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 26de6a50fee3..5a14df9ef9f8 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -330,6 +330,27 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>, llvm.return } +llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 { +// CHECK-LABEL: @rocdl_8bit_floats +// CHECK: rocdl.cvt.f32.bf8 +// CHECK: rocdl.cvt.f32.fp8 +// CHECK: rocdl.cvt.pk.bf8.f32 +// CHECK: rocdl.cvt.pk.fp8.f32 +// CHECK: rocdl.cvt.sr.bf8.f32 +// CHECK: rocdl.cvt.sr.fp8.f32 + %c0 = llvm.mlir.constant(0 : i32) : i32 + %c2 = llvm.mlir.constant(2 : i32) : i32 + %c3 = llvm.mlir.constant(3 : i32) : i32 + %false = llvm.mlir.constant(false) : i1 + %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32 + %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32 + %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32 + %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32 + %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32 + %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32 + llvm.return %source5 : i32 +} + // ----- // expected-error@below {{attribute attached to unexpected op}} diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 777bef8fea58..8b37dfbe3c6e 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -468,6 +468,27 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>, llvm.return %val : i32 } +llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 { +// CHECK-LABEL: @rocdl_8bit_floats +// CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0) +// CHECK: call float @llvm.amdgcn.cvt.f32.fp8(i32 %{{.+}}, i32 0) +// CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false) +// CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false) +// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2) +// CHECK: call i32 @llvm.amdgcn.cvt.sr.fp8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3) + %c0 = llvm.mlir.constant(0 : i32) : i32 + %c2 = llvm.mlir.constant(2 : i32) : i32 + %c3 = llvm.mlir.constant(3 : i32) : i32 + %false = llvm.mlir.constant(false) : i1 + %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32 + %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32 + %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32 + %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32 + %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32 + %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32 + llvm.return %source5 : i32 +} + // CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "amdgpu-implicitarg-num-bytes"="56" } // CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024" // CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"