mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-07-01 21:04:02 -04:00
[Bwd,Sm100] Shorten PipelineTmaUmma create
This commit is contained in:
@@ -957,7 +957,7 @@ class FlashAttentionBackwardSm100:
|
||||
consumer_group=pipeline_consumer_group_compute,
|
||||
tx_count=self.tma_copy_bytes["LSE"],
|
||||
# cta_layout_vmnk=cluster_layout_vmnk,
|
||||
# init_wait=False,
|
||||
defer_sync=True,
|
||||
)
|
||||
pipeline_dPsum = cutlass.pipeline.PipelineTmaAsync.create(
|
||||
barrier_storage=storage.dPsum_mbar_ptr.data_ptr(),
|
||||
@@ -966,7 +966,7 @@ class FlashAttentionBackwardSm100:
|
||||
consumer_group=pipeline_consumer_group_compute,
|
||||
tx_count=self.tma_copy_bytes["dPsum"],
|
||||
# cta_layout_vmnk=cluster_layout_vmnk,
|
||||
# init_wait=False,
|
||||
defer_sync=True,
|
||||
)
|
||||
pipeline_Q = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=storage.Q_mbar_ptr.data_ptr(),
|
||||
@@ -975,7 +975,7 @@ class FlashAttentionBackwardSm100:
|
||||
consumer_group=pipeline_consumer_group,
|
||||
tx_count=self.tma_copy_bytes["Q"],
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
init_wait=False,
|
||||
defer_sync=True,
|
||||
)
|
||||
pipeline_dO = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=storage.dO_mbar_ptr.data_ptr(),
|
||||
@@ -984,7 +984,7 @@ class FlashAttentionBackwardSm100:
|
||||
consumer_group=pipeline_consumer_group,
|
||||
tx_count=self.tma_copy_bytes["dO"],
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
init_wait=True,
|
||||
defer_sync=False,
|
||||
)
|
||||
|
||||
sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner, dtype=self.q_dtype)
|
||||
|
||||
@@ -1068,7 +1068,7 @@ class FlashAttentionBackwardSm90:
|
||||
)
|
||||
|
||||
# Smem copy atom tiling
|
||||
smem_copy_atom_PdS = utils.get_smem_store_atom(
|
||||
smem_copy_atom_PdS = copy_utils.get_smem_store_atom(
|
||||
self.arch, self.dtype, transpose=self.SdP_swapAB
|
||||
)
|
||||
smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice(
|
||||
|
||||
+12
-113
@@ -4,48 +4,14 @@
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass import Boolean, Int32, const_expr
|
||||
from cutlass.cutlass_dsl import if_generate
|
||||
from cutlass.pipeline import PipelineState, Agent, CooperativeGroup
|
||||
from cutlass.pipeline import PipelineUserType, PipelineOp
|
||||
from cutlass.pipeline import PipelineState
|
||||
from cutlass.pipeline import PipelineUserType
|
||||
from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
|
||||
from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
|
||||
|
||||
|
||||
# We deviate from cute-dsl implementation to use cute.arch.cluster_arrive_relaxed
|
||||
def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
    """
    Fence the mbarrier initialization, then synchronize the threadblock or the cluster.

    :param cta_layout_vmnk: Layout of the cluster shape. When it is ``None`` or has
        size 1, the threadblock is synchronized; otherwise the cluster is.
    :type cta_layout_vmnk: cute.Layout | None
    """
    cute.arch.mbarrier_init_fence()

    # A missing layout or a layout of size 1 means clusters are not in use.
    no_cluster = cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1
    sync_scope = Agent.ThreadBlock if no_cluster else Agent.ThreadBlockCluster
    _sync(sync_scope)
|
||||
|
||||
|
||||
def _sync(group: Agent):
    """
    Syncs all threads within an agent.

    :param group: The agent whose threads should be synchronized
    :type group: Agent
    :raises NotImplementedError: If ``group`` is ``Agent.Thread``
    :raises AssertionError: If ``group`` is any other agent that is not handled here
    """
    if group is Agent.Thread:
        raise NotImplementedError("Error: Not supported.")
    elif group is Agent.ThreadBlock:
        cute.arch.sync_threads()
    elif group is Agent.ThreadBlockCluster:
        # Cluster-wide sync: relaxed arrive followed by wait.
        cute.arch.cluster_arrive_relaxed()
        cute.arch.cluster_wait()
    else:
        # Raise explicitly instead of `assert False`: assert statements are removed
        # under `python -O`, which would turn this branch into a silent no-op
        # (returning without any synchronization).
        raise AssertionError(
            "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead."
        )
|
||||
|
||||
|
||||
class PipelineStateSimple:
|
||||
"""
|
||||
Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
|
||||
@@ -172,84 +138,17 @@ class PipelineTmaAsync(PipelineTmaAsyncOg):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineTmaUmma(PipelineTmaUmmaOg):
|
||||
"""
|
||||
Override producer_acquire to take in extra_tx_count parameter.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
*,
|
||||
num_stages: int,
|
||||
producer_group: CooperativeGroup,
|
||||
consumer_group: CooperativeGroup,
|
||||
tx_count: int,
|
||||
barrier_storage: cute.Pointer = None,
|
||||
cta_layout_vmnk: Optional[cute.Layout] = None,
|
||||
mcast_mode_mn: tuple[int, int] = (1, 1),
|
||||
init_wait: cutlass.Constexpr[bool] = True,
|
||||
):
|
||||
"""
|
||||
This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
|
||||
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: `CooperativeGroup` for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group: `CooperativeGroup` for the consumer agent
|
||||
:type consumer_group: CooperativeGroup
|
||||
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
|
||||
:type tx_count: int
|
||||
:param cta_layout_vmnk: Layout of the cluster shape
|
||||
:type cta_layout_vmnk: cute.Layout | None
|
||||
:param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1.
|
||||
:type mcast_mode_mn: tuple[int, int]
|
||||
"""
|
||||
if not isinstance(barrier_storage, cute.Pointer):
|
||||
raise ValueError(
|
||||
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
|
||||
)
|
||||
|
||||
producer_type = PipelineOp.TmaLoad
|
||||
consumer_type = PipelineOp.TCGen05Mma
|
||||
|
||||
producer = (producer_type, producer_group)
|
||||
consumer = (consumer_type, consumer_group)
|
||||
|
||||
sync_object_full = PipelineTmaUmmaOg._make_sync_object(
|
||||
barrier_storage.align(min_align=8), num_stages, producer, tx_count
|
||||
)
|
||||
sync_object_empty = PipelineTmaUmmaOg._make_sync_object(
|
||||
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
|
||||
)
|
||||
|
||||
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
|
||||
# No mcast mask if not using clusters
|
||||
producer_mask = None
|
||||
# All threadblocks are leaders if not using clusters
|
||||
is_leader_cta = True
|
||||
else:
|
||||
producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(
|
||||
cta_layout_vmnk, mcast_mode_mn
|
||||
)
|
||||
is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
|
||||
|
||||
cta_group = (
|
||||
cute.nvgpu.tcgen05.CtaGroup.ONE
|
||||
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
|
||||
else cute.nvgpu.tcgen05.CtaGroup.TWO
|
||||
)
|
||||
|
||||
consumer_mask = producer_mask
|
||||
|
||||
if const_expr(init_wait):
|
||||
pipeline_init_wait(cta_layout_vmnk)
|
||||
|
||||
return PipelineTmaUmma(
|
||||
sync_object_full,
|
||||
sync_object_empty,
|
||||
num_stages,
|
||||
producer_mask,
|
||||
consumer_mask,
|
||||
is_leader_cta,
|
||||
cta_group,
|
||||
)
|
||||
def create(*args, **kwargs):
|
||||
obj = PipelineTmaUmmaOg.create(*args, **kwargs)
|
||||
# Can't assign to __class__ directly since the dataclass is frozen
|
||||
# obj.__class__ = PipelineTmaUmma
|
||||
object.__setattr__(obj, "__class__", PipelineTmaUmma)
|
||||
return obj
|
||||
|
||||
def producer_acquire(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user