[Bwd,Sm100] Shorten PipelineTmaUmma create

This commit is contained in:
Tri Dao
2026-02-08 21:11:17 +07:00
parent c912a37d52
commit deb183092b
3 changed files with 17 additions and 118 deletions
+4 -4
View File
@@ -957,7 +957,7 @@ class FlashAttentionBackwardSm100:
consumer_group=pipeline_consumer_group_compute,
tx_count=self.tma_copy_bytes["LSE"],
# cta_layout_vmnk=cluster_layout_vmnk,
# init_wait=False,
defer_sync=True,
)
pipeline_dPsum = cutlass.pipeline.PipelineTmaAsync.create(
barrier_storage=storage.dPsum_mbar_ptr.data_ptr(),
@@ -966,7 +966,7 @@ class FlashAttentionBackwardSm100:
consumer_group=pipeline_consumer_group_compute,
tx_count=self.tma_copy_bytes["dPsum"],
# cta_layout_vmnk=cluster_layout_vmnk,
# init_wait=False,
defer_sync=True,
)
pipeline_Q = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.Q_mbar_ptr.data_ptr(),
@@ -975,7 +975,7 @@ class FlashAttentionBackwardSm100:
consumer_group=pipeline_consumer_group,
tx_count=self.tma_copy_bytes["Q"],
cta_layout_vmnk=cluster_layout_vmnk,
init_wait=False,
defer_sync=True,
)
pipeline_dO = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.dO_mbar_ptr.data_ptr(),
@@ -984,7 +984,7 @@ class FlashAttentionBackwardSm100:
consumer_group=pipeline_consumer_group,
tx_count=self.tma_copy_bytes["dO"],
cta_layout_vmnk=cluster_layout_vmnk,
init_wait=True,
defer_sync=False,
)
sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner, dtype=self.q_dtype)
+1 -1
View File
@@ -1068,7 +1068,7 @@ class FlashAttentionBackwardSm90:
)
# Smem copy atom tiling
smem_copy_atom_PdS = utils.get_smem_store_atom(
smem_copy_atom_PdS = copy_utils.get_smem_store_atom(
self.arch, self.dtype, transpose=self.SdP_swapAB
)
smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice(
+12 -113
View File
@@ -4,48 +4,14 @@
from typing import Optional
from dataclasses import dataclass
import cutlass
import cutlass.cute as cute
from cutlass import Boolean, Int32, const_expr
from cutlass.cutlass_dsl import if_generate
from cutlass.pipeline import PipelineState, Agent, CooperativeGroup
from cutlass.pipeline import PipelineUserType, PipelineOp
from cutlass.pipeline import PipelineState
from cutlass.pipeline import PipelineUserType
from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
# We deviate from cute-dsl implementation to use cute.arch.cluster_arrive_relaxed
def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
    """
    Fence the mbarrier init, then sync either the threadblock or the cluster.

    :param cta_layout_vmnk: Layout of the cluster shape. ``None`` or a layout of
        size 1 selects a threadblock sync; anything else selects a cluster sync.
    :type cta_layout_vmnk: cute.Layout | None
    """
    cute.arch.mbarrier_init_fence()
    single_cta = cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1
    # Not using clusters -> sync the threadblock; using clusters -> sync the cluster.
    _sync(Agent.ThreadBlock if single_cta else Agent.ThreadBlockCluster)
def _sync(group: Agent):
    """
    Syncs all threads within an agent.

    :param group: Agent whose threads should be synchronized
    :type group: Agent
    :raises NotImplementedError: if ``group`` is ``Agent.Thread``
    :raises AssertionError: if ``group`` is any agent other than ``Agent.Thread``,
        ``Agent.ThreadBlock`` or ``Agent.ThreadBlockCluster``
    """
    if group is Agent.Thread:
        raise NotImplementedError("Error: Not supported.")
    elif group is Agent.ThreadBlock:
        cute.arch.sync_threads()
    elif group is Agent.ThreadBlockCluster:
        # Relaxed cluster arrive followed by cluster wait.
        cute.arch.cluster_arrive_relaxed()
        cute.arch.cluster_wait()
    else:
        # Raise explicitly instead of `assert False`: assert statements are
        # stripped under `python -O`, which would turn this branch into a
        # silent no-op. The exception type and message are unchanged.
        raise AssertionError(
            "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead."
        )
class PipelineStateSimple:
"""
Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
@@ -172,84 +138,17 @@ class PipelineTmaAsync(PipelineTmaAsyncOg):
@dataclass(frozen=True)
class PipelineTmaUmma(PipelineTmaUmmaOg):
"""
Override producer_acquire to take in extra_tx_count parameter.
"""
@staticmethod
def create(
    *,
    num_stages: int,
    producer_group: CooperativeGroup,
    consumer_group: CooperativeGroup,
    tx_count: int,
    barrier_storage: cute.Pointer = None,
    cta_layout_vmnk: Optional[cute.Layout] = None,
    mcast_mode_mn: tuple[int, int] = (1, 1),
    init_wait: cutlass.Constexpr[bool] = True,
):
    """
    Compute the attributes a PipelineTmaUmma needs and return a new instance.

    :param num_stages: Number of buffer stages for this pipeline
    :type num_stages: Int32
    :param producer_group: `CooperativeGroup` for the producer agent
    :type producer_group: CooperativeGroup
    :param consumer_group: `CooperativeGroup` for the consumer agent
    :type consumer_group: CooperativeGroup
    :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
    :type tx_count: int
    :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
    :type barrier_storage: cute.Pointer
    :param cta_layout_vmnk: Layout of the cluster shape
    :type cta_layout_vmnk: cute.Layout | None
    :param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1.
    :type mcast_mode_mn: tuple[int, int]
    :param init_wait: When true, call ``pipeline_init_wait(cta_layout_vmnk)`` before returning
    :type init_wait: cutlass.Constexpr[bool]
    :raises ValueError: if ``barrier_storage`` is not a ``cute.Pointer``
    """
    if not isinstance(barrier_storage, cute.Pointer):
        raise ValueError(
            f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
        )

    # Producer is a TMA load, consumer is a tcgen05 MMA.
    tma_producer = (PipelineOp.TmaLoad, producer_group)
    umma_consumer = (PipelineOp.TCGen05Mma, consumer_group)

    # The "full" barriers start at the aligned base pointer; the "empty"
    # barriers start num_stages entries after it.
    full_barriers = PipelineTmaUmmaOg._make_sync_object(
        barrier_storage.align(min_align=8), num_stages, tma_producer, tx_count
    )
    empty_barriers = PipelineTmaUmmaOg._make_sync_object(
        barrier_storage.align(min_align=8) + num_stages, num_stages, umma_consumer
    )

    no_cluster = cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1
    if no_cluster:
        # No mcast mask if not using clusters, and every threadblock is a leader.
        mcast_mask = None
        leader_cta = True
    else:
        mcast_mask = PipelineTmaUmma._compute_mcast_arrival_mask(
            cta_layout_vmnk, mcast_mode_mn
        )
        leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)

    # CtaGroup.ONE when mode 0 of the cluster layout has size 1 (or there is
    # no layout), CtaGroup.TWO otherwise.
    if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1:
        cta_group = cute.nvgpu.tcgen05.CtaGroup.ONE
    else:
        cta_group = cute.nvgpu.tcgen05.CtaGroup.TWO

    if const_expr(init_wait):
        pipeline_init_wait(cta_layout_vmnk)

    # The consumer mask is the same object as the producer mask.
    return PipelineTmaUmma(
        full_barriers,
        empty_barriers,
        num_stages,
        mcast_mask,
        mcast_mask,
        leader_cta,
        cta_group,
    )
def create(*args, **kwargs):
    """
    Build the pipeline via ``PipelineTmaUmmaOg.create`` and retype it as PipelineTmaUmma.

    All positional and keyword arguments are forwarded unchanged to
    ``PipelineTmaUmmaOg.create``.

    :return: The object produced by ``PipelineTmaUmmaOg.create``, with its
        ``__class__`` replaced by ``PipelineTmaUmma``
    """
    pipeline_obj = PipelineTmaUmmaOg.create(*args, **kwargs)
    # The dataclass is frozen, so a plain `pipeline_obj.__class__ = PipelineTmaUmma`
    # is not possible; go through object.__setattr__ instead.
    object.__setattr__(pipeline_obj, "__class__", PipelineTmaUmma)
    return pipeline_obj
def producer_acquire(
self,