[mlir][amdgpu] Shared memory access optimization pass (#75627)
It implements a transformation to optimize accesses to shared memory. Reference: https://reviews.llvm.org/D127457. This change adds a transformation and pass to the AMDGPU dialect (adapted from the analogous NVGPU pass) that attempts to optimize reads/writes from a memref representing GPU shared memory in order to avoid bank conflicts. Given a value representing a shared memory memref, it traverses all reads/writes within the parent op and, subject to suitable conditions, rewrites all last-dimension index values so that element locations in the final (col) dimension are given by newColIdx = col % vecSize + perm[row](col / vecSize, row), where perm is a permutation function indexed by row and vecSize is the vector access size in elements (this port currently assumes 64-bit vectorized accesses, but that can be made a parameter). This transformation helps optimize the typical distributed and vectorized accesses used to load matrix multiplication operands to/from shared memory.
This commit is contained in: parent 30aa9fb4c1, commit b7360fbe8c
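To make the rewritten index formula concrete, here is a simplified, illustration-only sketch (not code from this patch) of the permutation on plain integers. It uses the XOR-based perm that the implementation emits, ignores the line-size adjustment applied for short rows, and assumes the row length measured in vectors is a power of two; `swizzleColumn` and its parameters are hypothetical names.

#include <cstdint>

// newColIdx = col % vecSize + perm[row](col / vecSize, row), with perm[row]
// realized as an XOR of the vector index with the (wrapped) row index.
// vecSize is the vector access width in elements; rowLenInVectors must be a
// power of two so the XOR result stays within the row.
int64_t swizzleColumn(int64_t row, int64_t col, int64_t vecSize,
                      int64_t rowLenInVectors) {
  int64_t vecIdx = col / vecSize;                      // which vector in the row
  int64_t permuted = vecIdx ^ (row % rowLenInVectors); // perm[row](vecIdx, row)
  return col % vecSize + permuted * vecSize;           // back to an element index
}

For example, with vecSize = 4 (64-bit accesses of f16) and rowLenInVectors = 8, threads reading column 0 of consecutive rows are redirected to columns 0, 4, 8, ..., 28, spreading what would otherwise be same-bank accesses across the shared memory line.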
@@ -29,6 +29,23 @@ def AMDGPU_Dialect : Dialect {
    "gpu::GPUDialect"
  ];
  let useDefaultAttributePrinterParser = 1;

  let extraClassDeclaration = [{
    /// Return true if the given MemRefType has an integer address
    /// space that matches the ROCDL shared memory address space or
    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup'.
    static bool hasSharedMemoryAddressSpace(MemRefType type);

    /// Return true if the given Attribute has an integer address
    /// space that matches the ROCDL shared memory address space or
    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup'.
    static bool isSharedMemoryAddressSpace(Attribute type);

    /// Defines the MemRef memory space attribute numeric value that indicates
    /// a memref is located in shared memory. This should correspond to the
    /// value used in ROCDL.
    static constexpr unsigned kSharedMemoryAddressSpace = 3;
  }];
}

//===----------------------------------------------------------------------===//
@@ -20,7 +20,8 @@ namespace mlir {
class ConversionTarget;
namespace amdgpu {

#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
#define GEN_PASS_DECL

#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
@@ -30,4 +30,17 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
                           "Chipset that these operations will run on">];
}

def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
  let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
  let description = [{
    This pass adds a transformation and pass to the AMDGPU dialect that
    attempts to optimize reads/writes from a memref representing GPU shared
    memory in order to avoid bank conflicts.
  }];

  let dependentDialects = [
    "memref::MemRefDialect", "vector::VectorDialect"
  ];
}

#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h (new file, 54 lines)
@@ -0,0 +1,54 @@
//===- Transforms.h - AMDGPU Dialect transformations --------------*-
// C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares functions that assist transformations for the amdgpu
// dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
#define MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_

#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir {
class RewriterBase;

namespace amdgpu {

///
/// Passes
///

/// Optimizes vectorized accesses to a shared memory buffer specified by
/// memrefValue. This transformation assumes the following:
/// 1) All relevant accesses to `memrefValue` are contained within `parentOp`.
/// 2) The function will fail precondition checks if any subviews are
///    taken of `memrefValue`. All reads/writes to `memrefValue` should occur
///    through `memrefValue` directly.
///
/// Shared memory bank conflicts occur when multiple threads attempt to read or
/// write locations assigned to the same shared memory bank. For `2^N` byte
/// vectorized accesses, we need to be concerned with conflicts among threads
/// identified as `(tid) -> tid.floordiv(2^{7-N})`. As such, this transformation
/// changes any indexed memory access (vector.load, memref.load, etc.)
/// such that the final dimension's index value is permuted such that
/// `newColIndex = oldColIndex % vectorSize +
/// perm[rowIndex](oldColIndex/vectorSize, rowIndex)` where `rowIndex` is the
/// index for the second-to-last dimension and `perm[rowIndex]` is a permutation
/// function that depends on the row index. The permutation function is chosen
/// to ensure that sequential distributed+vectorized reads/writes down a single
/// dimension of the memref have minimal conflicts.
mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                       Value memrefValue);

} // namespace amdgpu
} // namespace mlir

#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
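Not part of this diff, but as a usage sketch: a caller (for example a custom pipeline) can apply the transformation to every workgroup-memory allocation in a function roughly as follows; the pass added later in this commit performs essentially the same walk. `optimizeAllSharedMemoryBuffers` is a hypothetical helper name.

// Sketch only: assumes the AMDGPU, func, and memref dialect headers are
// included and `fn` is the func.func whose shared memory accesses should be
// optimized.
static void optimizeAllSharedMemoryBuffers(func::FuncOp fn) {
  fn.walk([&](memref::AllocOp allocOp) {
    // Only rewrite allocations that live in LDS / workgroup memory.
    if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
      return;
    // Failure means a precondition was not met (subviews present, accesses
    // not at least 2D, or the heuristics deem a rewrite unnecessary); in that
    // case the IR is left untouched.
    (void)amdgpu::optimizeSharedMemoryReadsAndWrites(fn.getOperation(),
                                                     allocOp.getMemref());
  });
}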
mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h (new file, 24 lines)
@@ -0,0 +1,24 @@
//===- Utils.h - Transform utilities -----------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Operation.h"

namespace mlir {
namespace amdgpu {

/// Get and set the indices that the given load/store operation is operating on.
/// Preconditions:
///  - The op must have memory effects.
///  - Considers memref::LoadOp, vector::LoadOp, vector::TransferReadOp.
///  - Considers memref::StoreOp, vector::StoreOp, vector::TransferWriteOp.
///  - Excludes subview ops.
std::optional<Operation::operand_range> getIndices(Operation *op);
void setIndices(Operation *op, ArrayRef<Value> indices);

} // namespace amdgpu
} // namespace mlir
@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
      >();
}

bool amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
    return intAttr.getInt() == AMDGPUDialect::kSharedMemoryAddressSpace;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

bool amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isSharedMemoryAddressSpace(memorySpace);
}

//===----------------------------------------------------------------------===//
// 8-bit float ops
//===----------------------------------------------------------------------===//
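For illustration (not part of the diff), both encodings of the workgroup address space satisfy the new helpers. The snippet below is a sketch; it assumes an MLIRContext with the AMDGPU and GPU dialects loaded, and the usual AMDGPUDialect.h, GPUDialect.h, Builders.h, and <cassert> includes.

// Sketch: both spellings of LDS / workgroup memory are recognized.
static void checkSharedMemorySpaces(MLIRContext &ctx) {
  Builder b(&ctx);
  // memref<128x32xf16, 3> -- integer address space matching ROCDL's LDS space.
  auto intSpace = MemRefType::get({128, 32}, b.getF16Type(),
                                  MemRefLayoutAttrInterface{},
                                  b.getI64IntegerAttr(3));
  // memref<128x32xf16, #gpu.address_space<workgroup>>.
  auto gpuSpace = MemRefType::get(
      {128, 32}, b.getF16Type(), MemRefLayoutAttrInterface{},
      gpu::AddressSpaceAttr::get(&ctx, gpu::AddressSpace::Workgroup));
  assert(amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(intSpace));
  assert(amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(gpuSpace));
}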
@@ -1,5 +1,7 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
  EmulateAtomics.cpp
  OptimizeSharedMemory.cpp
  Utils.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp (new file, 243 lines)
@@ -0,0 +1,243 @@
//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
// It is inspired by
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace amdgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace amdgpu
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

/// The size of a shared memory line according to AMD documentation.
/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
constexpr int64_t kSharedMemoryLineSizeBytes = 64;
/// We optimize for 64-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 64;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
///   `result = xor(floordiv(srcIdxVal,permuteEveryN),
///                 floordiv(tgtIdxVal,vectorSize)))
///            + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  //   N := log2(128/elementSizeBits)
  //   M := log2(dimSize(1))
  //   then
  //   bits[0:N] = sub-vector element offset
  //   bits[N:M] = vector index
  // clang-format on
  int64_t n =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));

  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}

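For the memref<256x32xf16, 3> buffer exercised by the test later in this commit, the constants above work out as follows (a worked example of my own, not part of the patch); the resulting IR is exactly the andi/shli/xori sequence the test's CHECK lines expect.

#include <cstdint>

// Host-side sketch of the arithmetic permuteVectorOffset emits for
// memref<256x32xf16, 3> with kSharedMemoryLineSizeBytes = 64 and
// kDefaultVectorSizeBits = 64:
//   permuteEveryN = max(1, 64 / (32 * 16 / 8)) = 1
//   n = log2(64 / 16) = 2,  m = log2(32) = 5,  mask = (1 << (5 - 2)) - 1 = 7
int64_t permuteColumn(int64_t row, int64_t col) {
  int64_t srcBits = (row & 7) << 2; // arith.andi %row, %c7 ; arith.shli ..., %c2
  return col ^ srcBits;             // arith.xori %col, %srcBits
}

// Threads loading column 0 of rows 0, 1, 2, 3 now read columns 0, 4, 8, 12,
// so their accesses land on different shared memory lines.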
static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
                   op) ||
               amdgpu::getIndices(op)->size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, vector::TransferWriteOp>(
                   op) ||
               amdgpu::getIndices(op)->size() < 2;
      }))
    return failure();

  return success();
}

mlir::LogicalResult
mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                 Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType ||
      !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check if this is necessary given the assumption of 128b accesses:
  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
  if (rowsPerLine >= threadGroupSize)
    return failure();

  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.pop_back_val();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = amdgpu::getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    amdgpu::setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.pop_back_val();
    builder.setInsertionPoint(shmReadOp);

    auto indices = amdgpu::getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    amdgpu::setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

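As a sanity check on the early-exit heuristic above, here are the values (my own worked example, not part of the patch) for the memref<256x32xf16, 3> buffer from the test; since rowsPerLine is well below threadGroupSize, the rewrite proceeds.

// Worked numbers (illustration) for memref<256x32xf16, 3> with
// kSharedMemoryLineSizeBytes = 64 and kDefaultVectorSizeBits = 64:
//   rowSize         = 32
//   rowsPerLine     = (8 * 64 / 16) / 32 = 1
//   threadGroupSize = 1 << (7 - log2(64 / 8)) = 16
// Since rowsPerLine (1) < threadGroupSize (16), the swizzle is applied.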
struct OptimizeSharedMemoryPass
    : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(
              allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp (new file, 39 lines)
@@ -0,0 +1,39 @@
#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::amdgpu;

std::optional<Operation::operand_range> amdgpu::getIndices(Operation *op) {
  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
    return loadOp.getIndices();
  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
    return storeOp.getIndices();
  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
    return vectorReadOp.getIndices();
  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
    return vectorStoreOp.getIndices();
  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
    return transferReadOp.getIndices();
  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteOp.getIndices();
  return std::nullopt;
}

void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
    return loadOp.getIndicesMutable().assign(indices);
  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
    return storeOp.getIndicesMutable().assign(indices);
  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
    return vectorReadOp.getIndicesMutable().assign(indices);
  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
    return vectorStoreOp.getIndicesMutable().assign(indices);
  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
    return transferReadOp.getIndicesMutable().assign(indices);
  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteOp.getIndicesMutable().assign(indices);
}
mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir (new file, 57 lines)
@@ -0,0 +1,57 @@
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s

// CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
                          %readRow: index, %readCol: index,
                          %writeRow: index, %writeCol: index,
                          %fragRow: index, %fragCol: index,
                          %fragColPerm: index,
                          %stRow: index, %stCol: index) {
  // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
  %cst = arith.constant 0.000000e+00 : f16

  // CHECK: [[shmA:%.+]] = memref.alloc
  // CHECK: [[shmB:%.+]] = memref.alloc
  %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
  %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>

  // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
  %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
  // CHECK: [[c7:%.+]] = arith.constant 7 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
  // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
  vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
  gpu.barrier
  gpu.barrier
  // CHECK: [[c7:%.+]] = arith.constant 7 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
  // CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
  %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>

  // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
  %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
  // CHECK: [[c7:%.+]] = arith.constant 7 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
  // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
  vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
  gpu.barrier
  gpu.barrier
  // CHECK: [[c7:%.+]] = arith.constant 7 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
  // CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
  %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
  return
}