[mlir][AMDGPU] Add packed 8-bit float conversion ops and lowering

Define operations that wrap gfx940's new instructions for converting
between f32 and registers containing packed sets of four 8-bit floats.

Define rocdl operations for the intrinsics and an AMDGPU dialect
wrapper around them (to account for the fact that MLIR distinguishes
the two float formats at the type level while LLVM IR does not).

Define an ArithToAMDGPU pass, meant to run before conversion to LLVM,
that replaces relevant calls to arith.extf and arith.truncf with the
packed operations in the AMDGPU dialect. Note that the conversion
currently only handles scalars and vectors of rank <= 1, as we do not
have a use case for multi-dimensional vector support right now.
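
For example, after this pass a scalar truncation such as

    %w = arith.truncf %v : f16 to f8E5M2FNUZ

becomes (mirroring the lowering exercised in the tests below)

    %c0 = arith.constant 0 : index
    %f = arith.extf %v : f16 to f32
    %p = amdgpu.packed_trunc_2xfp8 %f, undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
    %w = vector.extractelement %p[%c0 : index] : vector<4xf8E5M2FNUZ>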

Reviewed By: jsjodin

Differential Revision: https://reviews.llvm.org/D152457
Krzysztof Drewniak  2023-05-12 15:40:29 +00:00
parent 0eed8ae7d2
commit 2ebd633f14
16 changed files with 914 additions and 4 deletions


@ -0,0 +1,27 @@
//===- ArithToAMDGPU.h - Arith to AMDGPU dialect conversion ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
#include <memory>
namespace mlir {
class RewritePatternSet;
class Pass;
#define GEN_PASS_DECL_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
namespace arith {
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
} // namespace arith
} // namespace mlir
#endif // MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
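
For illustration, a minimal sketch of driving these patterns directly; this mirrors what the pass's runOnOperation later in this commit does, and the helper name here is hypothetical:

    #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Hypothetical helper: populate the fp8 ext/trunc rewrites and apply
    // them greedily, exactly as ArithToAMDGPUConversionPass does.
    static mlir::LogicalResult runArithToAMDGPU(mlir::Operation *op) {
      mlir::RewritePatternSet patterns(op->getContext());
      mlir::arith::populateArithToAMDGPUConversionPatterns(patterns);
      return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
    }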


@ -11,6 +11,7 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"


@ -112,6 +112,21 @@ def ConvertAMDGPUToROCDL : Pass<"convert-amdgpu-to-rocdl"> {
"Chipset that these operations will run on">];
}
//===----------------------------------------------------------------------===//
// ArithToAMDGPU
//===----------------------------------------------------------------------===//
def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
let summary = "Convert Arith operations to AMDGPU-specific implementations";
let description = [{
Convert `arith` operations (currently `extf` and `truncf` on 8-bit floats)
to operations in the `amdgpu` dialect. This conversion is done as a
separate pass, before lowering to LLVM, in order to avoid running a
notional arith-to-rocdl and arith-to-llvm simultaneously.
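The pass is exercised via `mlir-opt --convert-arith-to-amdgpu` (see the
accompanying tests).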
}];
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
}
//===----------------------------------------------------------------------===//
// ArithToLLVM
//===----------------------------------------------------------------------===//


@ -38,6 +38,85 @@ def AMDGPU_Dialect : Dialect {
class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
Op<AMDGPU_Dialect, mnemonic, traits> {}
def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
Results<(outs F32:$res)> {
let summary = "Extend one of a vector of packed fp8 values to a float";
let description = [{
Extend the value `source[index]` to a 32-bit float and return it.
This rather unusual signature arises from the fact that AMD GPUs cannot
easily work with sub-32-bit quantities, so the compiler intrinsics for
extending 8-bit floats (which are, currently, the only way to work with
these formats) take packed vectors of 4 such floats.
If the passed-in vector has fewer than four elements, or the input is scalar,
the remaining values in the <4 x i8> will be filled with
undefined values as needed.
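Example (taken from the accompanying tests):

```mlir
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32
```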
}];
let assemblyFormat = [{
attr-dict $source `[` $index `]` `:` type($source) `to` type($res)
}];
}
def AMDGPU_PackedTrunc2xFp8Op :
AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
Arguments<(ins F32:$sourceA,
Optional<F32>:$sourceB,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
let summary = "Round two floats into a packed vector of 8-bit floats";
let description = [{
Round the inputs `sourceA` and `sourceB` (which is undefined if not
specified) into the low word (bottom two elements) or high word (top two
elements) of the returned vector, as selected by `wordIndex`, keeping the
other two elements of `existing` unchanged if present (or undefined if it
was not passed in).
The reason for this odd signature is that AMD GPUs cannot easily work with
sub-registers, and so the conversion intrinsics (which are currently the
only way to work with 8-bit float types) take packed vectors of 4 8-bit
values.
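Example (taken from the accompanying tests):

```mlir
%ret = amdgpu.packed_trunc_2xfp8 %v1, %v2 into %others[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
```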
}];
let assemblyFormat = [{
attr-dict $sourceA `,` ($sourceB^):(`undef`)?
`into` ($existing^):(`undef`)? `[` `word` $wordIndex `]`
`:` type($sourceA) `to` type($res) (`into` type($existing)^)?
}];
let hasVerifier = 1;
}
def AMDGPU_PackedStochRoundFp8Op :
AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
Arguments<(ins F32:$source,
I32:$stochasticParam,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
let description = [{
Round the input `source`, adding in the rounding factor `stochasticParam`,
and place it into the `storeIndex`th element of `res`.
If `existing` is passed in, elements of `res` other than the one at `storeIndex`
are copied from `existing`.
The reason for this odd signature is that AMD GPUs cannot easily work with
sub-registers, and so the conversion intrinsics (which are currently the
only way to work with 8-bit float types) take packed vectors of 4 8-bit
values.
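Example (taken from the accompanying tests):

```mlir
%ret = amdgpu.packed_stoch_round_fp8 %v1 + %stoch into %others[2] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
```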
}];
let assemblyFormat = [{
attr-dict $source `+` $stochasticParam
`into` ($existing^):(`undef`)? `[` $storeIndex `]`
`:` type($source) `to` type($res) (`into` type($existing)^)?
}];
let hasVerifier = 1;
}
/// Raw buffer load
def AMDGPU_RawBufferLoadOp :
AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,


@ -116,7 +116,7 @@ class ROCDL_MbcntOp<string mnemonic> :
def ROCDL_MbcntLoOp : ROCDL_MbcntOp<"lo">;
def ROCDL_MbcntHiOp : ROCDL_MbcntOp<"hi">;
def ROCDL_DsSwizzleOp :
ROCDL_Op<"ds_swizzle">,
Results<(outs I32:$res)>,
Arguments<(ins I32:$src,
@ -130,7 +130,7 @@ Arguments<(ins I32:$src,
}];
}
def ROCDL_DsBpermuteOp :
ROCDL_Op<"ds_bpermute">,
Results<(outs I32:$res)>,
Arguments<(ins I32:$index,
@ -525,6 +525,85 @@ def ROCDL_RawBufferAtomicUMinOp :
let hasCustomAssemblyFormat = 1;
}
//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
def ROCDL_CvtF32Bf8Op :
ROCDL_IntrOp<"cvt.f32.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$srcA, I32:$byteSel)> {
let summary = "Convert bf8 to f32";
let description = [{
Convert 8-bit bf8 value from the `byteSel`th byte of `srcA` to fp32.
}];
let assemblyFormat = [{
attr-dict $srcA `[` $byteSel `]` `:` type($res)
}];
}
def ROCDL_CvtF32Fp8Op :
ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$srcA, I32:$byteSel)> {
let summary = "Convert fp8 to f32";
let description = [{
Convert 8-bit fp8 value from the `byteSel`th byte of `srcA` to fp32.
}];
let assemblyFormat = [{
attr-dict $srcA `[` $byteSel `]` `:` type($res)
}];
}
def ROCDL_CvtPkBf8F32Op :
ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
let summary = "Convert two f32's to bf8";
let description = [{
Convert `srcA` and `srcB` to bf8 and store into the low or high word of
`old`, as selected by `wordSel`, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res)
}];
}
def ROCDL_CvtPkFp8F32Op :
ROCDL_IntrOp<"cvt.pk.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
let summary = "Convert two f32's to fp8";
let description = [{
Convert `srcA` and `srcB` to fp8 and store into the low or high word of
`old`, as selected by `wordSel`, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res)
}];
}
def ROCDL_CvtSrBf8F32Op :
ROCDL_IntrOp<"cvt.sr.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
let summary = "Convert f32 to bf8, stochiastic rounding";
let description = [{
Convert `srcA` to bf8, adding the rounding factor from `srcB`,
and store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res)
}];
}
def ROCDL_CvtSrFp8F32Op :
ROCDL_IntrOp<"cvt.sr.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
let summary = "Convert f32 to fp8, stochiastic rounding";
let description = [{
Convert `srcA` to fp8, adding the rounding factor from `srcB`,
and store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res)
}];
}
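// Example usage of the fp8 conversion intrinsics, a sketch drawn from the
// rocdl.mlir test in this commit (%source, %old, and %stoch are arbitrary
// i32 values; %v1 and %v2 are f32):
//   %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
//   %p  = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %old[%false] : i32
//   %r  = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %p[%c3] : i32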
//===----------------------------------------------------------------------===//
// ROCDL target attribute.
//===----------------------------------------------------------------------===//
@ -612,5 +691,4 @@ def ROCDL_TargettAttr :
}
}];
}
#endif // ROCDLIR_OPS


@ -10,6 +10,7 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@ -639,6 +640,161 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
namespace {
struct ExtPackedFp8OpLowering final
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
struct PackedTrunc2xFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
struct PackedStochRoundFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(PackedStochRoundFp8Op op,
PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // end namespace
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type v4i8 =
getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
Type f32 = getTypeConverter()->convertType(op.getResult().getType());
Value source = adaptor.getSource();
auto sourceVecType = op.getSource().getType().dyn_cast<VectorType>();
Type sourceElemType = getElementTypeOrSelf(op.getSource());
// Extend to a v4i8
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
if (!sourceVecType) {
longVec = rewriter.create<LLVM::InsertElementOp>(
loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
longVec =
rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
}
}
source = longVec;
}
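// Bitcast the packed <4 x i8> to the i32 that the ROCDL conversion
// intrinsics take as input.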
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
if (sourceElemType.isFloat8E5M2FNUZ()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
} else if (sourceElemType.isFloat8E4M3FNUZ()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
return success();
}
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
Type resultType = op.getResult().getType();
Type resultElemType = getElementTypeOrSelf(resultType);
Value sourceA = adaptor.getSourceA();
Value sourceB = adaptor.getSourceB();
if (!sourceB)
sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
Value existing = adaptor.getExisting();
if (existing)
existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
else
existing = rewriter.create<LLVM::UndefOp>(loc, i32);
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
Value result;
if (resultElemType.isFloat8E5M2FNUZ())
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
else if (resultElemType.isFloat8E4M3FNUZ())
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
return success();
}
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
Type resultType = op.getResult().getType();
Type resultElemType = getElementTypeOrSelf(resultType);
Value source = adaptor.getSource();
Value stoch = adaptor.getStochasticParam();
Value existing = adaptor.getExisting();
if (existing)
existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
else
existing = rewriter.create<LLVM::UndefOp>(loc, i32);
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
Value result;
if (resultElemType.isFloat8E5M2FNUZ())
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
else if (resultElemType.isFloat8E4M3FNUZ())
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
return success();
}
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
ConvertAMDGPUToROCDLPass() = default;
@ -691,7 +847,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicUminOp>,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
MFMAOpLowering, WMMAOpLowering>(converter, chipset);
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
chipset);
}
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {


@ -0,0 +1,210 @@
//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
struct ArithToAMDGPUConversionPass final
: impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
using impl::ArithToAMDGPUConversionPassBase<
ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
void runOnOperation() override;
};
struct ExtfOnFloat8RewritePattern final
: public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;
LogicalResult match(arith::ExtFOp op) const override;
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};
struct TruncfToFloat8RewritePattern final
: public OpRewritePattern<arith::TruncFOp> {
using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
LogicalResult match(arith::TruncFOp op) const override;
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
};
} // end namespace
static Value castF32To(Type elementType, Value f32, Location loc,
PatternRewriter &rewriter) {
if (elementType.isF32())
return f32;
if (elementType.getIntOrFloatBitWidth() < 32)
return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
if (elementType.getIntOrFloatBitWidth() > 32)
return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
llvm_unreachable("The only 32-bit float type is f32");
}
LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
Type inType = op.getIn().getType();
if (auto inVecType = inType.dyn_cast<VectorType>()) {
if (inVecType.isScalable())
return failure();
if (inVecType.getShape().size() > 1)
// Multi-dimensional vectors are currently unsupported.
return failure();
inType = inVecType.getElementType();
}
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
}
void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
if (!in.getType().isa<VectorType>()) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
return rewriter.replaceOp(op, result);
}
VectorType inType = in.getType().cast<VectorType>();
int64_t numElements = inType.getNumElements();
Value zero = rewriter.createOrFold<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
Value result =
rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
if (inType.getShape().empty()) {
Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
// Recurse to send the 0-D vector case through the scalar case
Value scalarExt =
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
result = rewriter.create<vector::InsertElementOp>(loc, scalarExt, result);
return rewriter.replaceOp(op, result);
}
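// Process the source in slices of at most four elements, extending each
// lane of a packed slice to f32 individually.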
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
loc, in, i, elemsThisOp, 1);
for (int64_t j = 0; j < elemsThisOp; ++j) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), inSlice, j);
Value asType = castF32To(outElemType, asFloat, loc, rewriter);
result = rewriter.create<vector::InsertElementOp>(
loc, asType, result,
rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
}
}
rewriter.replaceOp(op, result);
}
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
Type type = value.getType();
if (type.isF32())
return value;
if (type.getIntOrFloatBitWidth() < 32)
return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
if (type.getIntOrFloatBitWidth() > 32)
return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
llvm_unreachable("The only 32-bit float type is f32");
}
LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
Type outType = op.getOut().getType();
if (auto outVecType = outType.dyn_cast<VectorType>()) {
if (outVecType.isScalable())
return failure();
if (outVecType.getShape().size() > 1)
// Multi-dimensional vectors are currently unsupported.
return failure();
outType = outVecType.getElementType();
}
return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
}
void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
VectorType truncResType = VectorType::get(4, outElemType);
if (!in.getType().isa<VectorType>()) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
/*existing=*/nullptr);
Value result = rewriter.create<vector::ExtractElementOp>(
loc, asF8s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
return rewriter.replaceOp(op, result);
}
VectorType outType = op.getOut().getType().cast<VectorType>();
int64_t numElements = outType.getNumElements();
Value zero = rewriter.createOrFold<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
if (outType.getShape().empty()) {
Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
// Recurse to send the 0-D vector case through the scalar case
Value scalarTrunc =
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
result = rewriter.create<vector::InsertElementOp>(loc, scalarTrunc, result);
return rewriter.replaceOp(op, result);
}
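// Walk the output four elements at a time; each packed_trunc_2xfp8
// truncates two f32 values into one word (low, then high) of the packed
// result.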
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value thisResult = nullptr;
for (int64_t j = 0; j < elemsThisOp; j += 2) {
Value elemA = rewriter.create<vector::ExtractElementOp>(
loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i + j));
Value asFloatA = castToF32(elemA, loc, rewriter);
Value asFloatB = nullptr;
if (j + 1 < elemsThisOp) {
Value elemB = rewriter.create<vector::ExtractElementOp>(
loc, in,
rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j + 1));
asFloatB = castToF32(elemB, loc, rewriter);
}
thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
}
if (elemsThisOp < 4)
thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
loc, thisResult, 0, elemsThisOp, 1);
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
result, i, 1);
}
rewriter.replaceOp(op, result);
}
void mlir::arith::populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
patterns.getContext());
}
void ArithToAMDGPUConversionPass::runOnOperation() {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
arith::populateArithToAMDGPUConversionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return signalPassFailure();
}


@ -0,0 +1,19 @@
add_mlir_conversion_library(MLIRArithToAMDGPU
ArithToAMDGPU.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAMDGPU
DEPENDS
MLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRAMDGPUDialect
MLIRArithDialect
MLIRVectorDialect
MLIRPass
MLIRTransforms
)


@ -1,6 +1,7 @@
add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)


@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
>();
}
//===----------------------------------------------------------------------===//
// 8-bit float ops
//===----------------------------------------------------------------------===//
LogicalResult PackedTrunc2xFp8Op::verify() {
if (getExisting() && getExisting().getType() != getResult().getType())
return emitOpError("existing values must have same type as result");
return success();
}
LogicalResult PackedStochRoundFp8Op::verify() {
if (getExisting() && getExisting().getType() != getResult().getType())
return emitOpError("existing values must have same type as result");
return success();
}
//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//


@ -0,0 +1,108 @@
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 | FileCheck %s
// CHECK-LABEL: func @ext_scalar
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2FNUZ to i8
// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
// CHECK: return [[EXT]]
func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
func.return %ret : f32
}
// CHECK-LABEL: func @ext_short_vec
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FNUZ> to vector<2xi8>
// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
// CHECK: return [[EXT]]
func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
func.return %ret : f32
}
// CHECK-LABEL: func @ext_full_vec(
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32
func.return %ret : f32
}
// CHECK-LABEL: func @packed_trunc
// CHECK-SAME: ([[V:%.+]]: f32)
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> {
%ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
func.return %ret : vector<4xf8E4M3FNUZ>
}
// CHECK-LABEL: func @packed_truncx2
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> {
%ret = amdgpu.packed_trunc_2xfp8 %v, %w into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
func.return %ret : vector<4xf8E4M3FNUZ>
}
// CHECK-LABEL: func @packed_truncx2_into
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
%ret = amdgpu.packed_trunc_2xfp8 %v, %w into %existing[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
func.return %ret : vector<4xf8E5M2FNUZ>
}
// CHECK-LABEL: func @packed_stoch_round
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> {
%ret = amdgpu.packed_stoch_round_fp8 %v + %s into undef[0] : f32 to vector<4xf8E4M3FNUZ>
func.return %ret : vector<4xf8E4M3FNUZ>
}
// CHECK-LABEL: func @packed_stoch_round_into
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
%ret = amdgpu.packed_stoch_round_fp8 %v + %s into %existing[1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
func.return %ret : vector<4xf8E5M2FNUZ>
}


@ -0,0 +1,122 @@
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s
// CHECK-LABEL: func.func @scalar_ext
// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32
// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
// CHECK: return [[W]]
func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
%w = arith.extf %v : f8E5M2FNUZ to f16
return %w : f16
}
// No 0-D test because arith.extf hasn't been extended to support it.
// -----
// CHECK-LABEL: func.func @vector_ext_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>)
// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32
// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
// CHECK: [[W0:%.+]] = vector.insertelement [[EXT0]], [[ZEROES]]{{\[}}[[C0]]
// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32
// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
// CHECK: [[W1:%.+]] = vector.insertelement [[EXT1]], [[W0]]{{\[}}[[C1]]
// CHECK: return [[W1]] : vector<2xf64>
func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
%w = arith.extf %v : vector<2xf8E5M2FNUZ> to vector<2xf64>
return %w : vector<2xf64>
}
// -----
// CHECK-LABEL: func.func @vector_ext_long
// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>)
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
// CHECK: [[W0:%.+]] = vector.insertelement [[F0]]
// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
// CHECK: [[W1:%.+]] = vector.insertelement [[F1]], [[W0]]
// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
// CHECK: [[W2:%.+]] = vector.insertelement [[F2]], [[W1]]
// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
// CHECK: [[W3:%.+]] = vector.insertelement [[F3]], [[W2]]
// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
// CHECK: [[W4:%.+]] = vector.insertelement [[F4]], [[W3]]
// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
// CHECK: [[W5:%.+]] = vector.insertelement [[F5]], [[W4]]
// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
// CHECK: [[W6:%.+]] = vector.insertelement [[F6]], [[W5]]
// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
// CHECK: [[W7:%.+]] = vector.insertelement [[F7]], [[W6]]
// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
// CHECK: [[W8:%.+]] = vector.insertelement [[F8]], [[W7]]
// CHECK: return [[W8]]
func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
%w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
return %w : vector<9xf32>
}
// -----
// CHECK-LABEL: func.func @scalar_trunc
// CHECK-SAME: ([[V:%.+]]: f16)
// CHECK: [[C0:%.+]] = arith.constant 0 : index
// CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32
// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
// CHECK: [[W:%.+]] = vector.extractelement [[TRUNCV]]{{\[}}[[C0]] : index] : vector<4xf8E5M2FNUZ>
// CHECK: return [[W]] : f8E5M2FNUZ
func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
%w = arith.truncf %v : f16 to f8E5M2FNUZ
return %w : f8E5M2FNUZ
}
// No 0-D test because arith.truncf hasn't been extended to support it.
// -----
// CHECK-LABEL: func.func @vector_trunc_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2FNUZ> {
// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
// CHECK: [[V0:%.+]] = vector.extractelement [[V]]{{\[}}[[C0]] : index]
// CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32
// CHECK: [[V1:%.+]] = vector.extractelement [[V]]{{\[}}[[C1]] : index]
// CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32
// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2FNUZ> to vector<2xf8E5M2FNUZ>
// CHECK: return [[W]] : vector<2xf8E5M2FNUZ>
func.func @vector_trunc_short(%v: vector<2xf64>) -> vector<2xf8E5M2FNUZ> {
%w = arith.truncf %v : vector<2xf64> to vector<2xf8E5M2FNUZ>
return %w : vector<2xf8E5M2FNUZ>
}
// -----
// CHECK-LABEL: func.func @vector_trunc_long
// CHECK-SAME: ([[V:%.+]]: vector<9xf32>)
// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}
// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}
// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
// CHECK: return [[W]]
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> {
%w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ>
return %w : vector<9xf8E4M3FNUZ>
}


@ -1,5 +1,19 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
func.func @mixing_packed_trunc_types(%arg0: f32, %arg1: vector<4xf8E5M2FNUZ>) -> vector<4xf8E4M3FNUZ> {
// expected-error@+1 {{'amdgpu.packed_trunc_2xfp8' op existing values must have same type as result}}
%ret = amdgpu.packed_trunc_2xfp8 %arg0, undef into %arg1[word 0] : f32 to vector<4xf8E4M3FNUZ> into vector<4xf8E5M2FNUZ>
func.return %ret : vector<4xf8E4M3FNUZ>
}
// -----
func.func @mixing_packed_stoch_round_types(%arg0: f32, %arg1: i32, %arg2: vector<4xf8E5M2FNUZ>) -> vector<4xf8E4M3FNUZ> {
// expected-error@+1 {{'amdgpu.packed_stoch_round_fp8' op existing values must have same type as result}}
%ret = amdgpu.packed_stoch_round_fp8 %arg0 + %arg1 into %arg2[0] : f32 to vector<4xf8E4M3FNUZ> into vector<4xf8E5M2FNUZ>
func.return %ret : vector<4xf8E4M3FNUZ>
}
// -----
func.func @bad_source_types(%a: vector<2xf32>, %b: vector<4xf16>,


@ -4,6 +4,27 @@
// Verify the generic form can be parsed.
// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: func @ext_packed_fp8
// CHECK: amdgpu.ext_packed_fp8
func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32
func.return %ret : f32
}
// CHECK-LABEL: func @packed_trunc_2xfp8
// CHECK: amdgpu.packed_trunc_2xfp8
func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>, %stoch: i32) -> vector<4xf8E5M2FNUZ> {
%ret = amdgpu.packed_trunc_2xfp8 %v1, %v2 into %others[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
func.return %ret : vector<4xf8E5M2FNUZ>
}
// CHECK-LABEL: func @packed_stoch_round_fp8
// CHECK: amdgpu.packed_stoch_round_fp8
func.func @packed_stoch_round_fp8(%v1: f32, %stoch: i32, %others: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
%ret = amdgpu.packed_stoch_round_fp8 %v1 + %stoch into %others[2] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
func.return %ret : vector<4xf8E5M2FNUZ>
}
// CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1
func.func @raw_buffer_load_f32_from_rank_1(%src : memref<128xf32>, %offset : i32, %idx0 : i32) -> f32 {
// CHECK: amdgpu.raw_buffer_load {indexOffset = 1 : i32} %{{.*}}[{{.*}}] sgprOffset %{{.*}} : memref<128xf32>, i32 -> f32


@ -330,6 +330,27 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
llvm.return
}
llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
// CHECK-LABEL: @rocdl_8bit_floats
// CHECK: rocdl.cvt.f32.bf8
// CHECK: rocdl.cvt.f32.fp8
// CHECK: rocdl.cvt.pk.bf8.f32
// CHECK: rocdl.cvt.pk.fp8.f32
// CHECK: rocdl.cvt.sr.bf8.f32
// CHECK: rocdl.cvt.sr.fp8.f32
%c0 = llvm.mlir.constant(0 : i32) : i32
%c2 = llvm.mlir.constant(2 : i32) : i32
%c3 = llvm.mlir.constant(3 : i32) : i32
%false = llvm.mlir.constant(false) : i1
%v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
llvm.return %source5 : i32
}
// -----
// expected-error@below {{attribute attached to unexpected op}}


@ -468,6 +468,27 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
llvm.return %val : i32
}
llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
// CHECK-LABEL: @rocdl_8bit_floats
// CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)
// CHECK: call float @llvm.amdgcn.cvt.f32.fp8(i32 %{{.+}}, i32 0)
// CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
// CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
// CHECK: call i32 @llvm.amdgcn.cvt.sr.fp8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
%c0 = llvm.mlir.constant(0 : i32) : i32
%c2 = llvm.mlir.constant(2 : i32) : i32
%c3 = llvm.mlir.constant(3 : i32) : i32
%false = llvm.mlir.constant(false) : i1
%v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
llvm.return %source5 : i32
}
// CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "amdgpu-implicitarg-num-bytes"="56" }
// CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
// CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"