diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
new file mode 100644
index 000000000000..7f445fee5ba6
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -0,0 +1,27 @@
+//===- ArithToAMDGPU.h - Arith to AMDGPU dialect conversion ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
+#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
+
+#include <memory>
+
+namespace mlir {
+
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_ARITHTOAMDGPUCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace arith {
+void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 41806004fc1d..e714f5070f23 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 9b7848d9288b..38b05c792d40 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -112,6 +112,21 @@ def ConvertAMDGPUToROCDL : Pass<"convert-amdgpu-to-rocdl"> {
                        "Chipset that these operations will run on">];
 }
 
+//===----------------------------------------------------------------------===//
+// ArithToAMDGPU
+//===----------------------------------------------------------------------===//
+def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
+  let summary = "Convert Arith operations to AMDGPU-specific implementations";
+  let description = [{
+    Convert `arith` operations (currently `extf` and `truncf` on 8-bit
+    floats) to operations in the `amdgpu` dialect. The lowering to LLVM is
+    deliberately split into two steps (this pass, then
+    `convert-amdgpu-to-rocdl`) in order to avoid running a notional
+    arith-to-rocdl and arith-to-llvm simultaneously.
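+
+    For example (an illustrative before/after drawn from this patch's
+    ArithToAMDGPU tests), the pass rewrites
+
+    ```mlir
+    %w = arith.extf %v : f8E5M2FNUZ to f16
+    ```
+
+    into
+
+    ```mlir
+    %float = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
+    %w = arith.truncf %float : f32 to f16
+    ```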
+  }];
+
+  let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ArithToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 6d788e3a9701..ffb302fcedd7 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -38,6 +38,85 @@ def AMDGPU_Dialect : Dialect {
 class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
   Op<AMDGPU_Dialect, mnemonic, traits> {}
 
+def AMDGPU_ExtPackedFp8Op :
+    AMDGPU_Op<"ext_packed_fp8", [Pure]>,
+    Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
+        VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
+      ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
+    Results<(outs F32:$res)> {
+  let summary = "Extend one of a vector of packed fp8 values to a float";
+  let description = [{
+    Extend the value `source[index]` to a 32-bit float and return it.
+
+    This rather unusual signature arises from the fact that AMD GPUs cannot
+    easily work with sub-32-bit quantities, so the compiler intrinsics for
+    extending 8-bit floats (which are, currently, the only way to work with
+    this operation) take packed vectors of 4 such floats.
+
+    If the passed-in vector has fewer than four elements, or the input is
+    scalar, the remaining values in the <4 x i8> will be filled with
+    undefined values as needed.
+  }];
+  let assemblyFormat = [{
+    attr-dict $source `[` $index `]` `:` type($source) `to` type($res)
+  }];
+}
+
+def AMDGPU_PackedTrunc2xFp8Op :
+    AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
+    Arguments<(ins F32:$sourceA,
+        Optional<F32>:$sourceB,
+        ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
+        Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+  let summary = "Round two floats into a packed vector of 8-bit floats";
+  let description = [{
+    Round the inputs `sourceA` and `sourceB` (which is undefined if not
+    specified) into the low or high word (bottom two or top two) elements
+    of the returned vector, keeping the other two elements of `existing`
+    unchanged if present (or undefined if it was not passed in).
+
+    The reason for this odd signature is that AMD GPUs cannot easily work with
+    sub-registers, and so the conversion intrinsics (which are currently the
+    only way to work with 8-bit float types) take packed vectors of 4 8-bit
+    values.
+  }];
+  let assemblyFormat = [{
+    attr-dict $sourceA `,` ($sourceB^):(`undef`)?
+    `into` ($existing^):(`undef`)? `[` `word` $wordIndex `]`
+    `:` type($sourceA) `to` type($res) (`into` type($existing)^)?
+  }];
+  let hasVerifier = 1;
+}
+
+def AMDGPU_PackedStochRoundFp8Op :
+    AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
+    Arguments<(ins F32:$source,
+        I32:$stochiasticParam,
+        ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
+        Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+  let summary = "Round a float stochastically into a packed vector of 8-bit floats";
+  let description = [{
+    Round the input `source`, adding in `stochiasticParam`, and place it into
+    the `storeIndex`th element of `res`.
+
+    If `existing` is passed in, elements of `res` other than the one at
+    `storeIndex` are copied from `existing`.
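+
+    For example (illustrative, mirroring this patch's tests; `%seed` is the
+    stochasticity parameter):
+
+    ```mlir
+    %ret = amdgpu.packed_stoch_round_fp8 %v + %seed into %old[1]
+      : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
+    ```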
+
+    The reason for this odd signature is that AMD GPUs cannot easily work with
+    sub-registers, and so the conversion intrinsics (which are currently the
+    only way to work with 8-bit float types) take packed vectors of 4 8-bit
+    values.
+  }];
+  let assemblyFormat = [{
+    attr-dict $source `+` $stochiasticParam
+    `into` ($existing^):(`undef`)? `[` $storeIndex `]`
+    `:` type($source) `to` type($res) (`into` type($existing)^)?
+  }];
+  let hasVerifier = 1;
+}
+
 /// Raw buffer load
 def AMDGPU_RawBufferLoadOp :
     AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 08d36397dc31..6c6419bf238b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -116,7 +116,7 @@ class ROCDL_MbcntOp<string mnemonic> :
 def ROCDL_MbcntLoOp : ROCDL_MbcntOp<"lo">;
 def ROCDL_MbcntHiOp : ROCDL_MbcntOp<"hi">;
 
-def ROCDL_DsSwizzleOp : 
+def ROCDL_DsSwizzleOp :
 ROCDL_Op<"ds_swizzle">,
 Results<(outs I32:$res)>,
 Arguments<(ins I32:$src,
@@ -130,7 +130,7 @@ Arguments<(ins I32:$src,
   }];
 }
 
-def ROCDL_DsBpermuteOp : 
+def ROCDL_DsBpermuteOp :
 ROCDL_Op<"ds_bpermute">,
 Results<(outs I32:$res)>,
 Arguments<(ins I32:$index,
@@ -525,6 +525,85 @@ def ROCDL_RawBufferAtomicUMinOp :
   let hasCustomAssemblyFormat = 1;
 }
 
+//===---------------------------------------------------------------------===//
+// 8-bit float intrinsics
+//===---------------------------------------------------------------------===//
+def ROCDL_CvtF32Bf8Op :
+    ROCDL_IntrOp<"cvt.f32.bf8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$srcA, I32:$byteSel)> {
+  let summary = "Convert bf8 to f32";
+  let description = [{
+    Convert the 8-bit bf8 value in the `byteSel`th byte of `srcA` to fp32.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtF32Fp8Op :
+    ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$srcA, I32:$byteSel)> {
+  let summary = "Convert fp8 to f32";
+  let description = [{
+    Convert the 8-bit fp8 value in the `byteSel`th byte of `srcA` to fp32.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtPkBf8F32Op :
+    ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
+    Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
+  let summary = "Convert two f32s to bf8";
+  let description = [{
+    Convert `srcA` and `srcB` to bf8 and store them into the low/high word of
+    `old`, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtPkFp8F32Op :
+    ROCDL_IntrOp<"cvt.pk.fp8.f32", [], [], [Pure], 1>,
+    Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
+  let summary = "Convert two f32s to fp8";
+  let description = [{
+    Convert `srcA` and `srcB` to fp8 and store them into the low/high word of
+    `old`, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtSrBf8F32Op :
+    ROCDL_IntrOp<"cvt.sr.bf8.f32", [], [], [Pure], 1>,
+    Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
+  let summary = "Convert f32 to bf8, stochastic rounding";
+  let description = [{
+    Convert `srcA` to bf8, adding in the rounding factor from `srcB`,
+    and store the result into the `byteSel`th byte of `old`, preserving the
+    others.
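+
+    For example (usage as exercised by the `rocdl.mlir` tests in this patch,
+    with `%byte` standing in for an i32 byte index):
+
+    ```mlir
+    %w = rocdl.cvt.sr.bf8.f32 %v, %seed -> %old[%byte] : i32
+    ```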
+ }]; + let assemblyFormat = [{ + attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res) + }]; +} + +def ROCDL_CvtSrFp8F32Op : + ROCDL_IntrOp<"cvt.sr.fp8.f32", [], [], [Pure], 1>, + Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> { + let summary = "Convert f32 to fp8, stochiastic rounding"; + let description = [{ + Convert `srcA` to fp8, adding the rounding factor from `srcB`, + and store into the `byteSel`th byte of `old`, preserving the others. + }]; + let assemblyFormat = [{ + attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res) + }]; +} + //===----------------------------------------------------------------------===// // ROCDL target attribute. //===----------------------------------------------------------------------===// @@ -612,5 +691,4 @@ def ROCDL_TargettAttr : } }]; } - #endif // ROCDLIR_OPS diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index eeed04049668..9ed312cef744 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -639,6 +640,161 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { } }; +namespace { +struct ExtPackedFp8OpLowering final + : public ConvertOpToLLVMPattern { + ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern(converter), + chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +struct PackedTrunc2xFp8OpLowering final + : public ConvertOpToLLVMPattern { + PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern(converter), + chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +struct PackedStochRoundFp8OpLowering final + : public ConvertOpToLLVMPattern { + PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern(converter), + chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(PackedStochRoundFp8Op op, + PackedStochRoundFp8OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // end namespace + +LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( + ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) + return rewriter.notifyMatchFailure( + loc, "Fp8 conversion instructions are not available on target " + "architecture and their emulation is not implemented"); + Type v4i8 = + getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type())); + Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); + Type f32 = getTypeConverter()->convertType(op.getResult().getType()); + + Value source = adaptor.getSource(); + auto sourceVecType = op.getSource().getType().dyn_cast(); + Type sourceElemType = getElementTypeOrSelf(op.getSource()); + // Extend to a v4i8 
+  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
+    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+    if (!sourceVecType) {
+      longVec = rewriter.create<LLVM::InsertElementOp>(
+          loc, longVec, source, createI32Constant(rewriter, loc, 0));
+    } else {
+      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
+        Value idx = createI32Constant(rewriter, loc, i);
+        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+        longVec =
+            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+      }
+    }
+    source = longVec;
+  }
+  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
+  if (sourceElemType.isFloat8E5M2FNUZ()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
+                                                    wordSel);
+  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
+                                                    wordSel);
+  }
+  return success();
+}
+
+LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
+    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
+    return rewriter.notifyMatchFailure(
+        loc, "Fp8 conversion instructions are not available on target "
+             "architecture and their emulation is not implemented");
+  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+
+  Type resultType = op.getResult().getType();
+  Type resultElemType = getElementTypeOrSelf(resultType);
+
+  Value sourceA = adaptor.getSourceA();
+  Value sourceB = adaptor.getSourceB();
+  if (!sourceB)
+    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
+  Value existing = adaptor.getExisting();
+  if (existing)
+    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+  else
+    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
+
+  Value result;
+  if (resultElemType.isFloat8E5M2FNUZ())
+    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
+                                                   existing, wordSel);
+  else if (resultElemType.isFloat8E4M3FNUZ())
+    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
+                                                   existing, wordSel);
+
+  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
+      op, getTypeConverter()->convertType(resultType), result);
+  return success();
+}
+
+LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
+    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
+    return rewriter.notifyMatchFailure(
+        loc, "Fp8 conversion instructions are not available on target "
+             "architecture and their emulation is not implemented");
+  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+
+  Type resultType = op.getResult().getType();
+  Type resultElemType = getElementTypeOrSelf(resultType);
+
+  Value source = adaptor.getSource();
+  Value stoch = adaptor.getStochiasticParam();
+  Value existing = adaptor.getExisting();
+  if (existing)
+    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+  else
+    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
+
+  Value result;
+  if (resultElemType.isFloat8E5M2FNUZ())
+    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
+                                                   existing, byteSel);
+  else if (resultElemType.isFloat8E4M3FNUZ())
+    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
+                                                   existing, byteSel);
+
+  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
+      op, getTypeConverter()->convertType(resultType), result);
+  return success();
+}
+
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
   ConvertAMDGPUToROCDLPass() = default;
@@ -691,7 +847,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                           ROCDL::RawPtrBufferAtomicUminOp>,
       RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                           ROCDL::RawPtrBufferAtomicCmpSwap>,
-      MFMAOpLowering, WMMAOpLowering>(converter, chipset);
+      MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
+      PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
+                                                                 chipset);
 }
 
 std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
new file mode 100644
index 000000000000..7785405eae67
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -0,0 +1,210 @@
+//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct ArithToAMDGPUConversionPass final
+    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
+  using impl::ArithToAMDGPUConversionPassBase<
+      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
+
+  void runOnOperation() override;
+};
+
+struct ExtfOnFloat8RewritePattern final
+    : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult match(arith::ExtFOp op) const override;
+  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
+};
+
+struct TruncfToFloat8RewritePattern final
+    : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult match(arith::TruncFOp op) const override;
+  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+};
+} // end namespace
+
+static Value castF32To(Type elementType, Value f32, Location loc,
+                       PatternRewriter &rewriter) {
+  if (elementType.isF32())
+    return f32;
+  if (elementType.getIntOrFloatBitWidth() < 32)
+    return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
+  if (elementType.getIntOrFloatBitWidth() > 32)
+    return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
+  llvm_unreachable("The only 32-bit float type is f32");
+}
+
+LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
+  Type inType = op.getIn().getType();
+  if (auto inVecType = inType.dyn_cast<VectorType>()) {
+    if (inVecType.isScalable())
+      return failure();
+    if (inVecType.getShape().size() > 1)
+      // Multi-dimensional vectors are currently unsupported.
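+      // (0-D and 1-D inputs are handled by the rewrite below; higher-rank
+      // vectors would have to be unrolled or flattened first.)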
+      return failure();
+    inType = inVecType.getElementType();
+  }
+  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+}
+
+void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
+                                         PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value in = op.getIn();
+  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  if (!in.getType().isa<VectorType>()) {
+    Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
+        loc, rewriter.getF32Type(), in, 0);
+    Value result = castF32To(outElemType, asFloat, loc, rewriter);
+    return rewriter.replaceOp(op, result);
+  }
+  VectorType inType = in.getType().cast<VectorType>();
+  int64_t numElements = inType.getNumElements();
+  Value zero = rewriter.createOrFold<arith::ConstantOp>(
+      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+  Value result =
+      rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
+  if (inType.getShape().empty()) {
+    Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+    // Recurse so that the 0-D vector case is handled by the scalar case
+    Value scalarExt =
+        rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
+    result = rewriter.create<vector::InsertElementOp>(loc, scalarExt, zero);
+    return rewriter.replaceOp(op, result);
+  }
+  for (int64_t i = 0; i < numElements; i += 4) {
+    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
+    Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, i, elemsThisOp, 1);
+    for (int64_t j = 0; j < elemsThisOp; ++j) {
+      Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
+          loc, rewriter.getF32Type(), inSlice, j);
+      Value asType = castF32To(outElemType, asFloat, loc, rewriter);
+      result = rewriter.create<vector::InsertElementOp>(
+          loc, asType, result,
+          rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
+    }
+  }
+  rewriter.replaceOp(op, result);
+}
+
+static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
+  Type type = value.getType();
+  if (type.isF32())
+    return value;
+  if (type.getIntOrFloatBitWidth() < 32)
+    return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
+  if (type.getIntOrFloatBitWidth() > 32)
+    return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
+  llvm_unreachable("The only 32-bit float type is f32");
+}
+
+LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
+  Type outType = op.getOut().getType();
+  if (auto outVecType = outType.dyn_cast<VectorType>()) {
+    if (outVecType.isScalable())
+      return failure();
+    if (outVecType.getShape().size() > 1)
+      // Multi-dimensional vectors are currently unsupported.
+      return failure();
+    outType = outVecType.getElementType();
+  }
+  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+}
+
+void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
+                                           PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value in = op.getIn();
+  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  VectorType truncResType = VectorType::get(4, outElemType);
+  if (!in.getType().isa<VectorType>()) {
+    Value asFloat = castToF32(in, loc, rewriter);
+    Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
+        loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
+        /*existing=*/nullptr);
+    Value result = rewriter.create<vector::ExtractElementOp>(
+        loc, asF8s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    return rewriter.replaceOp(op, result);
+  }
+  VectorType outType = op.getOut().getType().cast<VectorType>();
+  int64_t numElements = outType.getNumElements();
+  Value zero = rewriter.createOrFold<arith::ConstantOp>(
+      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+  if (outType.getShape().empty()) {
+    Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+    // Recurse so that the 0-D vector case is handled by the scalar case
+    Value scalarTrunc =
+        rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
+    result = rewriter.create<vector::InsertElementOp>(loc, scalarTrunc, zero);
+    return rewriter.replaceOp(op, result);
+  }
+
+  for (int64_t i = 0; i < numElements; i += 4) {
+    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
+    Value thisResult = nullptr;
+    for (int64_t j = 0; j < elemsThisOp; j += 2) {
+      Value elemA = rewriter.create<vector::ExtractElementOp>(
+          loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i + j));
+      Value asFloatA = castToF32(elemA, loc, rewriter);
+      Value asFloatB = nullptr;
+      if (j + 1 < elemsThisOp) {
+        Value elemB = rewriter.create<vector::ExtractElementOp>(
+            loc, in,
+            rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j + 1));
+        asFloatB = castToF32(elemB, loc, rewriter);
+      }
+      thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
+          loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
+    }
+    if (elemsThisOp < 4)
+      thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, thisResult, 0, elemsThisOp, 1);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
+                                                           result, i, 1);
+  }
+  rewriter.replaceOp(op, result);
+}
+
+void mlir::arith::populateArithToAMDGPUConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
+      patterns.getContext());
+}
+
+void ArithToAMDGPUConversionPass::runOnOperation() {
+  Operation *op = getOperation();
+  RewritePatternSet patterns(op->getContext());
+  arith::populateArithToAMDGPUConversionPatterns(patterns);
+  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
new file mode 100644
index 000000000000..359015b6f86a
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRArithToAMDGPU
+  ArithToAMDGPU.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAMDGPU
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRAMDGPUDialect
+  MLIRArithDialect
+  MLIRVectorDialect
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 660e48768c4f..35790254be13 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(AffineToStandard)
 add_subdirectory(AMDGPUToROCDL)
 add_subdirectory(ArithCommon)
+add_subdirectory(ArithToAMDGPU)
 add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index ac34acc83074..2575ad498481 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// 8-bit float ops
+//===----------------------------------------------------------------------===//
+LogicalResult PackedTrunc2xFp8Op::verify() {
+  if (getExisting() && getExisting().getType() != getResult().getType())
+    return emitOpError("existing values must have same type as result");
+  return success();
+}
+
+LogicalResult PackedStochRoundFp8Op::verify() {
+  if (getExisting() && getExisting().getType() != getResult().getType())
+    return emitOpError("existing values must have same type as result");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // RawBuffer*Op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
new file mode 100644
index 000000000000..7818a525d17b
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 | FileCheck %s
+
+// CHECK-LABEL: func @ext_scalar
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2FNUZ to i8
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_short_vec
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FNUZ> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_full_vec(
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
+// CHECK: return [[EXT]] : f32
+
+func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @packed_trunc
+// CHECK-SAME: ([[V:%.+]]: f32)
+// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
+func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> {
+  %ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
+  func.return %ret : vector<4xf8E4M3FNUZ>
+}
+
+// CHECK-LABEL: func @packed_truncx2
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
+func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> {
+  %ret = amdgpu.packed_trunc_2xfp8 %v, %w into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
+  func.return %ret : vector<4xf8E4M3FNUZ>
+}
+
+// CHECK-LABEL: func @packed_truncx2_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
+// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
+func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
+  %ret = amdgpu.packed_trunc_2xfp8 %v, %w into %existing[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E5M2FNUZ>
+}
+
+// CHECK-LABEL: func @packed_stoch_round
+// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
+func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> {
+  %ret = amdgpu.packed_stoch_round_fp8 %v + %s into undef[0] : f32 to vector<4xf8E4M3FNUZ>
+  func.return %ret : vector<4xf8E4M3FNUZ>
+}
+
+// CHECK-LABEL: func @packed_stoch_round_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
+// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
+func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
+  %ret = amdgpu.packed_stoch_round_fp8 %v + %s into %existing[1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E5M2FNUZ>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
new file mode 100644
index 000000000000..a6c11d022e2c
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -0,0 +1,122 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s
+
+// CHECK-LABEL: func.func @scalar_ext
+// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32
+// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
+// CHECK: return [[W]]
+func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
+  %w = arith.extf %v : f8E5M2FNUZ to f16
+  return %w : f16
+}
+
+// No 0-D test because arith.extf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_short
+// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>)
+// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32
+// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
+// CHECK: [[W0:%.+]] = vector.insertelement [[EXT0]], [[ZEROES]]{{\[}}[[C0]]
+// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32
+// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
+// CHECK: [[W1:%.+]] = vector.insertelement [[EXT1]], [[W0]]{{\[}}[[C1]]
+// CHECK: return [[W1]] : vector<2xf64>
+
+func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
+  %w = arith.extf %v : vector<2xf8E5M2FNUZ> to vector<2xf64>
+  return %w : vector<2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_long
+// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>)
+// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
+// CHECK: [[W0:%.+]] = vector.insertelement [[F0]]
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
+// CHECK: [[W1:%.+]] = vector.insertelement [[F1]], [[W0]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
+// CHECK: [[W2:%.+]] = vector.insertelement [[F2]], [[W1]]
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
+// CHECK: [[W3:%.+]] = vector.insertelement [[F3]], [[W2]]
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
+// CHECK: [[W4:%.+]] = vector.insertelement [[F4]], [[W3]]
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
+// CHECK: [[W5:%.+]] = vector.insertelement [[F5]], [[W4]]
+// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
+// CHECK: [[W6:%.+]] = vector.insertelement [[F6]], [[W5]]
+// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
+// CHECK: [[W7:%.+]] = vector.insertelement [[F7]], [[W6]]
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
+// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
+// CHECK: [[W8:%.+]] = vector.insertelement [[F8]], [[W7]]
+// CHECK: return [[W8]]
+func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
+  %w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
+  return %w : vector<9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[V:%.+]]: f16)
+// CHECK: [[C0:%.+]] = arith.constant 0 : index
+// CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32
+// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
+// CHECK: [[W:%.+]] = vector.extractelement [[TRUNCV]]{{\[}}[[C0]] : index] : vector<4xf8E5M2FNUZ>
+// CHECK: return [[W]] : f8E5M2FNUZ
+func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
+  %w = arith.truncf %v : f16 to f8E5M2FNUZ
+  return %w : f8E5M2FNUZ
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_short
+// CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2FNUZ> {
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK: [[V0:%.+]] = vector.extractelement [[V]]{{\[}}[[C0]] : index]
+// CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32
+// CHECK: [[V1:%.+]] = vector.extractelement [[V]]{{\[}}[[C1]] : index]
+// CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32
+// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
+// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2FNUZ> to vector<2xf8E5M2FNUZ>
+// CHECK: return [[W]] : vector<2xf8E5M2FNUZ>
+func.func @vector_trunc_short(%v: vector<2xf64>) -> vector<2xf8E5M2FNUZ> {
+  %w = arith.truncf %v : vector<2xf64> to vector<2xf8E5M2FNUZ>
+  return %w : vector<2xf8E5M2FNUZ>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_long
+// CHECK-SAME: ([[V:%.+]]: vector<9xf32>)
+// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
+// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}
+
+// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}
+
+// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
+// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
+// CHECK: return [[W]]
+func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> {
+  %w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ>
+  return %w : vector<9xf8E4M3FNUZ>
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 142224e59a95..5e1ab79962d2 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -1,5 +1,19 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
+func.func @mixing_packed_trunc_types(%arg0: f32, %arg1: vector<4xf8E5M2FNUZ>) -> vector<4xf8E4M3FNUZ> {
+  // expected-error@+1 {{'amdgpu.packed_trunc_2xfp8' op existing values must have same type as result}}
+  %ret = amdgpu.packed_trunc_2xfp8 %arg0, undef into %arg1[word 0] : f32 to vector<4xf8E4M3FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E4M3FNUZ>
+}
+
+// -----
+
+func.func @mixing_packed_stoch_round_types(%arg0: f32, %arg1: i32, %arg2: vector<4xf8E5M2FNUZ>) -> vector<4xf8E4M3FNUZ> {
+  // expected-error@+1 {{'amdgpu.packed_stoch_round_fp8' op existing values must have same type as result}}
+  %ret = amdgpu.packed_stoch_round_fp8 %arg0 + %arg1 into %arg2[0] : f32 to vector<4xf8E4M3FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E4M3FNUZ>
+}
+
 // -----
 
 func.func @bad_source_types(%a: vector<2xf32>, %b: vector<4xf16>,
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 4088c6750c91..744a096d757e 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -4,6 +4,27 @@
 // Verify the generic form can be parsed.
 // RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
 
+// CHECK-LABEL: func @ext_packed_fp8
+// CHECK: amdgpu.ext_packed_fp8
+func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @packed_trunc_2xfp8
+// CHECK: amdgpu.packed_trunc_2xfp8
+func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>, %stoch: i32) -> vector<4xf8E5M2FNUZ> {
+  %ret = amdgpu.packed_trunc_2xfp8 %v1, %v2 into %others[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E5M2FNUZ>
+}
+
+// CHECK-LABEL: func @packed_stoch_round_fp8
+// CHECK: amdgpu.packed_stoch_round_fp8
+func.func @packed_stoch_round_fp8(%v1: f32, %stoch: i32, %others: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
+  %ret = amdgpu.packed_stoch_round_fp8 %v1 + %stoch into %others[2] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E5M2FNUZ>
+}
+
 // CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1
 func.func @raw_buffer_load_f32_from_rank_1(%src : memref<128xf32>, %offset : i32, %idx0 : i32) -> f32 {
   // CHECK: amdgpu.raw_buffer_load {indexOffset = 1 : i32} %{{.*}}[{{.*}}] sgprOffset %{{.*}} : memref<128xf32>, i32 -> f32
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 26de6a50fee3..5a14df9ef9f8 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -330,6 +330,27 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
   llvm.return
 }
 
+llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
+// CHECK-LABEL: @rocdl_8bit_floats
+// CHECK: rocdl.cvt.f32.bf8
+// CHECK: rocdl.cvt.f32.fp8
+// CHECK: rocdl.cvt.pk.bf8.f32
+// CHECK: rocdl.cvt.pk.fp8.f32
+// CHECK: rocdl.cvt.sr.bf8.f32
+// CHECK: rocdl.cvt.sr.fp8.f32
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  %c2 = llvm.mlir.constant(2 : i32) : i32
+  %c3 = llvm.mlir.constant(3 : i32) : i32
+  %false = llvm.mlir.constant(false) : i1
+  %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
+  %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
+  %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
+  %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
+  %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
+  %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
+  llvm.return %source5 : i32
+}
+
 // -----
 
 // expected-error@below {{attribute attached to unexpected op}}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 777bef8fea58..8b37dfbe3c6e 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -468,6 +468,27 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
   llvm.return %val : i32
 }
 
+llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
+// CHECK-LABEL: @rocdl_8bit_floats
+// CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)
+// CHECK: call float @llvm.amdgcn.cvt.f32.fp8(i32 %{{.+}}, i32 0)
+// CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
+// CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
+// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
+// CHECK: call i32 @llvm.amdgcn.cvt.sr.fp8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  %c2 = llvm.mlir.constant(2 : i32) : i32
+  %c3 = llvm.mlir.constant(3 : i32) : i32
+  %false = llvm.mlir.constant(false) : i1
+  %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
+  %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
+  %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
+  %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
+  %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
+  %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
+  llvm.return %source5 : i32
+}
+
 // CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "amdgpu-implicitarg-num-bytes"="56" }
 // CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
 // CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"