[mlir][amdgpu] Shared memory access optimization pass (#75627)

It implements a transformation to optimize accesses to shared memory.

Reference: https://reviews.llvm.org/D127457

_This change adds a transformation and pass to the AMDGPU dialect that
attempts to optimize reads/writes from a memref representing GPU shared
memory in order to avoid bank conflicts. Given a value representing a
shared memory memref, it traverses all reads/writes within the parent op
and, subject to suitable conditions, rewrites all last-dimension index
values such that element locations in the final (col) dimension are
given by `newColIdx = col % vecSize + perm[row](col / vecSize, row)`,
where `perm` is a permutation function indexed by `row` and `vecSize`
is the vector access size in elements (this implementation currently
assumes 64-bit vectorized accesses, but the size can be made a
parameter). This transformation helps optimize the distributed and
vectorized accesses commonly used to load matrix multiplication
operands to/from shared memory._
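
As a rough standalone illustration (not part of the patch), the sketch below applies the swizzle for the configuration exercised by the lit test at the end of this commit: a memref<256x32xf16, 3> buffer accessed with 64-bit vectors, for which the pass masks the row with 7, shifts left by 2, and XORs the result into the column index. The helper name is hypothetical.

// Illustrative sketch only; mirrors the andi/shli/xori sequence the pass emits
// for memref<256x32xf16, 3> with 64-bit (4 x f16) vector accesses.
#include <cstdint>
#include <cstdio>

static int64_t permuteCol(int64_t row, int64_t col) {
  int64_t srcBits = row & 7; // arith.andi %row, %c7
  srcBits <<= 2;             // arith.shli %srcBits, %c2
  return col ^ srcBits;      // arith.xori %col, %srcBits
}

int main() {
  // Eight threads reading column 0 of consecutive rows now hit distinct
  // 64-bit lanes instead of the same shared memory bank.
  for (int64_t row = 0; row < 8; ++row)
    std::printf("row %lld: col 0 -> col %lld\n", (long long)row,
                (long long)permuteCol(row, /*col=*/0));
  return 0;
}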
erman-gurses 2024-01-19 18:44:45 -05:00 committed by GitHub
parent 30aa9fb4c1
commit b7360fbe8c
10 changed files with 466 additions and 1 deletion

@@ -29,6 +29,23 @@ def AMDGPU_Dialect : Dialect {
"gpu::GPUDialect"
];
let useDefaultAttributePrinterParser = 1;
let extraClassDeclaration = [{
/// Return true if the given MemRefType has an integer address
/// space that matches the ROCDL shared memory address space or
/// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
static bool hasSharedMemoryAddressSpace(MemRefType type);
/// Return true if the given Attribute has an integer address
/// space that matches the ROCDL shared memory address space or
/// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
static bool isSharedMemoryAddressSpace(Attribute type);
/// Defines the MemRef memory space attribute numeric value that indicates
/// a memref is located in shared memory. This should correspond to the
/// value used in ROCDL.
static constexpr unsigned kSharedMemoryAddressSpace = 3;
}];
}
//===----------------------------------------------------------------------===//
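
For reference, a minimal usage sketch of the new helpers (hypothetical function and setup, not from the patch), showing the two memory-space encodings they accept: the raw ROCDL address space 3 and a gpu::AddressSpaceAttr with value workgroup.

// Hypothetical usage sketch of the new helpers; assumes an MLIRContext with
// the AMDGPU and GPU dialects available.
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

static void checkSharedMemorySpaces(mlir::MLIRContext &ctx) {
  ctx.loadDialect<mlir::amdgpu::AMDGPUDialect, mlir::gpu::GPUDialect>();
  mlir::Builder b(&ctx);
  // memref<128x32xf16, 3>: integer address space matching ROCDL shared memory.
  auto intSpaceType = mlir::MemRefType::get(
      {128, 32}, b.getF16Type(), mlir::MemRefLayoutAttrInterface(),
      b.getI64IntegerAttr(3));
  // memref<128x32xf16, #gpu.address_space<workgroup>>.
  auto gpuSpaceType = mlir::MemRefType::get(
      {128, 32}, b.getF16Type(), mlir::MemRefLayoutAttrInterface(),
      mlir::gpu::AddressSpaceAttr::get(&ctx, mlir::gpu::AddressSpace::Workgroup));
  // Both are reported as shared memory by the new helper.
  (void)mlir::amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(intSpaceType);
  (void)mlir::amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(gpuSpaceType);
}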

@@ -20,7 +20,8 @@ namespace mlir {
class ConversionTarget;
namespace amdgpu {
#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"

@@ -30,4 +30,17 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
"Chipset that these operations will run on">];
}
def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
let description = [{
This pass attempts to optimize reads/writes from memrefs representing GPU
shared memory in order to avoid bank conflicts.
}];
let dependentDialects = [
"memref::MemRefDialect", "vector::VectorDialect"
];
}
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
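
A sketch of wiring the pass into a C++ pipeline, equivalent to the mlir-opt pipeline used by the lit test in this commit, builtin.module(func.func(amdgpu-optimize-shared-memory)). It assumes the createOptimizeSharedMemory() factory that the pass tablegen backend emits when a def declares no explicit constructor; the surrounding function is hypothetical.

// Assumption: createOptimizeSharedMemory() is the tablegen-generated factory
// for this pass def (no explicit constructor is declared above).
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

static void buildSharedMemoryPipeline(mlir::PassManager &pm) {
  // Run the optimization on every function in the module.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::amdgpu::createOptimizeSharedMemory());
}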

@@ -0,0 +1,54 @@
//===- Transforms.h - AMDGPU Dialect transformations -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares functions that assist transformations for the amdgpu
// dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
#define MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
class RewriterBase;
namespace amdgpu {
///
/// Passes
///
/// Optimizes vectorized accesses to a shared memory buffer specified by
/// memrefValue. This transformation assumes the following:
/// 1) All relevant accesses to `memrefValue` are contained within `parentOp`.
/// 2) The function will fail precondition checks if any subviews are
/// taken of `memrefValue`. All reads/writes to `memrefValue` should occur
/// through `memrefValue` directly.
///
/// Shared memory bank conflicts occur when multiple threads attempt to read or
/// write locations assigned to the same shared memory bank. For `2^N` byte
/// vectorized accesses, we need to be concerned with conflicts among threads
/// identified as `(tid) -> tid.floordiv(2^{7-N})`. As such, this transformation
/// changes any indexed memory access (vector.load, memref.load, etc)
/// such that the final dimension's index value is permuted such that
/// `newColIndex = oldColIndex % vectorSize +
/// perm[rowIndex](oldColIndex/vectorSize, rowIndex)` where `rowIndex` is the
/// index for the second-to-last dimension and `perm[rowIndex]` is a permutation
/// function that depends on the row index. The permutation function is chosen
/// to ensure that sequential distributed+vectorized reads/writes down a single
/// dimension of the memref have minimal conflicts.
mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
Value memrefValue);
} // namespace amdgpu
} // namespace mlir
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
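
A small numeric sketch (illustrative only, not part of the header) of the conflict-group formula above for the 64-bit default used by this pass: N = 3, so the threads that can conflict are those sharing tid.floordiv(2^(7-3)) = tid / 16.

// Illustrative only: conflict groups for 2^3-byte (64-bit) vector accesses.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t n = 3;                             // log2(8-byte accesses)
  const int64_t groupSize = int64_t(1) << (7 - n); // 16 threads per group
  const int64_t tids[] = {0, 15, 16, 31, 32, 63};
  for (int64_t tid : tids)
    std::printf("tid %2lld -> conflict group %lld\n", (long long)tid,
                (long long)(tid / groupSize));
  return 0;
}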

@@ -0,0 +1,24 @@
//===- Utils.h - Transform utilities ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Operation.h"
namespace mlir {
namespace amdgpu {
/// Get and set the indices that the given load/store operation is operating on.
/// Preconditions:
/// - The op must have memory effects
/// - Considers memref::LoadOp, vector::LoadOp, vector::TransferReadOp
/// - Considers memref::StoreOp, vector::StoreOp, vector::TransferWriteOp
/// - Excludes subview op
std::optional<Operation::operand_range> getIndices(Operation *op);
void setIndices(Operation *op, ArrayRef<Value> indices);
} // namespace amdgpu
} // namespace mlir
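
A minimal sketch of the intended round trip through these two functions (the helper is hypothetical, but this is essentially the pattern OptimizeSharedMemory.cpp follows): read the index operands of a supported load/store, rewrite the innermost index, and write the list back.

// Hypothetical helper showing the intended getIndices/setIndices round trip.
#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>

static void rewriteLastIndex(mlir::Operation *op, mlir::Value newLastIndex) {
  std::optional<mlir::Operation::operand_range> indices =
      mlir::amdgpu::getIndices(op);
  if (!indices || indices->size() < 2)
    return; // unsupported op or rank-1 access
  llvm::SmallVector<mlir::Value, 4> updated(indices->begin(), indices->end());
  updated.back() = newLastIndex;
  mlir::amdgpu::setIndices(op, updated);
}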

@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
>();
}
bool amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
if (!memorySpace)
return false;
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
return intAttr.getInt() == AMDGPUDialect::kSharedMemoryAddressSpace;
if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
return false;
}
bool amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
Attribute memorySpace = type.getMemorySpace();
return isSharedMemoryAddressSpace(memorySpace);
}
//===----------------------------------------------------------------------===//
// 8-bit float ops
//===----------------------------------------------------------------------===//

@@ -1,5 +1,7 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
EmulateAtomics.cpp
OptimizeSharedMemory.cpp
Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms

@@ -0,0 +1,243 @@
//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
// It is inspired by
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
namespace amdgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace amdgpu
} // namespace mlir
using namespace mlir;
using namespace mlir::amdgpu;
/// The size of a shared memory line according to AMD documentation.
/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
constexpr int64_t kSharedMemoryLineSizeBytes = 64;
/// We optimize for 64-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 64;
/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
/// floordiv(tgtIdxVal, vectorSize))
/// + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
ArrayRef<Value> indices, MemRefType memrefTy,
int64_t srcDim, int64_t tgtDim) {
// Adjust the src index to change how often the permutation changes
// if necessary.
Value src = indices[srcDim];
// We only want to permute every N iterations of the target dim where N is
// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
const int64_t permuteEveryN = std::max<int64_t>(
1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
memrefTy.getElementTypeBitWidth()) /
8));
// clang-format off
// Index bit representation (b0 = least significant bit) for dim(1)
// of a `memref<?x?xDT>` is as follows:
// N := log2(kDefaultVectorSizeBits / elementSizeBits)
// M := log2(dimSize(1))
// then
// bits[0:N] = sub-vector element offset
// bits[N:M] = vector index
// clang-format on
int64_t n =
llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
// Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
int64_t mask = (1LL << (m - n)) - 1;
if (permuteEveryN > 1)
mask = mask << llvm::Log2_64(permuteEveryN);
Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
// Use the src bits to permute the target bits b[N:M] containing the
// vector offset.
if (permuteEveryN > 1) {
int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
if (shlBits > 0) {
Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
} else if (shlBits < 0) {
Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
}
} else {
Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
}
Value permutedVectorIdx =
b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
return permutedVectorIdx;
}
static void transformIndices(OpBuilder &builder, Location loc,
SmallVector<Value, 4> &indices,
MemRefType memrefTy, int64_t srcDim,
int64_t tgtDim) {
indices[tgtDim] =
permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}
/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
SmallVector<Operation *, 16> &readOps,
SmallVector<Operation *, 16> &writeOps) {
parentOp->walk([&](Operation *op) {
MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
if (!iface)
return;
std::optional<MemoryEffects::EffectInstance> effect =
iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
if (effect) {
readOps.push_back(op);
return;
}
effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
if (effect)
writeOps.push_back(op);
});
// Restrict to a supported set of ops. We also require at least 2D access,
// although this could be relaxed.
if (llvm::any_of(readOps, [](Operation *op) {
return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
op) ||
amdgpu::getIndices(op)->size() < 2;
}))
return failure();
if (llvm::any_of(writeOps, [](Operation *op) {
return !isa<memref::StoreOp, vector::StoreOp, vector::TransferWriteOp>(
op) ||
amdgpu::getIndices(op)->size() < 2;
}))
return failure();
return success();
}
mlir::LogicalResult
mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
Value memrefValue) {
auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
if (!memRefType ||
!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
return failure();
// Abort if the given value has any sub-views; we do not do any alias
// analysis.
bool hasSubView = false;
parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
if (hasSubView)
return failure();
// Check if the transformation is necessary given the default vector size:
// bail out if dim[rank-1] is small enough that `threadGroupSize` or more
// rows fit in a single shared memory line.
const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
const int64_t rowsPerLine =
(8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
rowSize;
const int64_t threadGroupSize =
1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
if (rowsPerLine >= threadGroupSize)
return failure();
// Get sets of operations within the function that read/write to shared
// memory.
SmallVector<Operation *, 16> shmReadOps;
SmallVector<Operation *, 16> shmWriteOps;
if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
shmWriteOps)))
return failure();
if (shmReadOps.empty() || shmWriteOps.empty())
return failure();
OpBuilder builder(parentOp->getContext());
int64_t tgtDim = memRefType.getRank() - 1;
int64_t srcDim = memRefType.getRank() - 2;
// Transform indices for the ops writing to shared memory.
while (!shmWriteOps.empty()) {
Operation *shmWriteOp = shmWriteOps.pop_back_val();
builder.setInsertionPoint(shmWriteOp);
auto indices = amdgpu::getIndices(shmWriteOp);
SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
memRefType, srcDim, tgtDim);
amdgpu::setIndices(shmWriteOp, transformedIndices);
}
// Transform indices for the ops reading from shared memory.
while (!shmReadOps.empty()) {
Operation *shmReadOp = shmReadOps.pop_back_val();
builder.setInsertionPoint(shmReadOp);
auto indices = amdgpu::getIndices(shmReadOp);
SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
memRefType, srcDim, tgtDim);
amdgpu::setIndices(shmReadOp, transformedIndices);
}
return success();
}
struct OptimizeSharedMemoryPass
: public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
OptimizeSharedMemoryPass() = default;
void runOnOperation() override {
Operation *op = getOperation();
SmallVector<memref::AllocOp> shmAllocOps;
op->walk([&](memref::AllocOp allocOp) {
if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(
allocOp.getType()))
return;
shmAllocOps.push_back(allocOp);
});
for (auto allocOp : shmAllocOps) {
if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
allocOp.getMemref())))
return;
}
}
};
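
For the memref<256x32xf16, 3> buffer in the lit test below, the constants above work out as follows (worked numbers, illustrative only); this is why the test checks an arith.andi with 7, an arith.shli by 2, and an arith.xori.

// Worked numbers for memref<256x32xf16, 3> under the defaults above
// (kSharedMemoryLineSizeBytes = 64, kDefaultVectorSizeBits = 64).
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t lineBytes = 64;  // kSharedMemoryLineSizeBytes
  const int64_t elemBits = 16;   // f16
  const int64_t dimSize = 32;    // last dimension of the memref
  const int64_t permuteEveryN =
      std::max<int64_t>(1, lineBytes / (dimSize * elemBits / 8)); // 64 / 64 = 1
  const int64_t n = 2;           // log2(64-bit vector / 16-bit element) = log2(4)
  const int64_t m = 5;           // log2(dimSize) = log2(32)
  const int64_t mask = (int64_t(1) << (m - n)) - 1;               // 7
  // permuteEveryN == 1, so the mask is not pre-shifted and srcBits is shifted
  // left by n = 2 before the xor: andi(row, 7); shli(srcBits, 2); xori(col, srcBits).
  std::printf("permuteEveryN=%lld n=%lld m=%lld mask=%lld\n",
              (long long)permuteEveryN, (long long)n, (long long)m,
              (long long)mask);
  return 0;
}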

@@ -0,0 +1,39 @@
#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;
using namespace mlir::amdgpu;
std::optional<Operation::operand_range> amdgpu::getIndices(Operation *op) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getIndices();
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
return storeOp.getIndices();
if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
return vectorReadOp.getIndices();
if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
return vectorStoreOp.getIndices();
if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
return transferReadOp.getIndices();
if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
return transferWriteOp.getIndices();
return std::nullopt;
}
void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getIndicesMutable().assign(indices);
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
return storeOp.getIndicesMutable().assign(indices);
if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
return vectorReadOp.getIndicesMutable().assign(indices);
if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
return vectorStoreOp.getIndicesMutable().assign(indices);
if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
return transferReadOp.getIndicesMutable().assign(indices);
if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
return transferWriteOp.getIndicesMutable().assign(indices);
}

@@ -0,0 +1,57 @@
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
// CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
%readRow: index, %readCol: index,
%writeRow: index, %writeCol: index,
%fragRow: index, %fragCol: index,
%fragColPerm: index,
%stRow: index, %stCol: index) {
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
%cst = arith.constant 0.000000e+00 : f16
// CHECK: [[shmA:%.+]] = memref.alloc
// CHECK: [[shmB:%.+]] = memref.alloc
%shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
%shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
// CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
%0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
// CHECK: [[c7:%.+]] = arith.constant 7 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
// CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
gpu.barrier
gpu.barrier
// CHECK: [[c7:%.+]] = arith.constant 7 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
// CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
%1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
// CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
%2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
// CHECK: [[c7:%.+]] = arith.constant 7 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
// CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
gpu.barrier
gpu.barrier
// CHECK: [[c7:%.+]] = arith.constant 7 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
// CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
%3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
return
}