D1.5: Revert broken TMEM round-trip O rescale, document as fundamentally broken
TMEM round-trip via Ld32x32bOp/St32x32bOp corrupts O accumulator data even with CUTLASS correction_rescale pattern. All variants tested: - Repetition(16) + composition (CUTLASS exact pattern) — BROKEN - Repetition(32) + composition — BROKEN - Repetition(16) raw layout (no composition) — BROKEN Even NO-OP (multiply by 1.0) produces catastrophically wrong results. Production path remains Python KV merge (cos 0.999998 for s_k up to 1024). Next: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt).
This commit is contained in:
@@ -21,7 +21,7 @@ import math
|
||||
|
||||
|
||||
class FmhaKernel:
|
||||
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False, debug_noop_rescale=False):
|
||||
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False):
|
||||
# D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to
|
||||
# positions >= n_comp. D3/D4 masks also only apply to SWA region.
|
||||
# When n_comp is None or 0, no offset (backward compatible).
|
||||
@@ -58,8 +58,6 @@ class FmhaKernel:
|
||||
self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
|
||||
self.q_stage = 1
|
||||
self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512
|
||||
self.debug_noop_rescale = debug_noop_rescale # D1.5 debug: force acc_scale=1.0 in O rescale
|
||||
|
||||
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
|
||||
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
|
||||
|
||||
@@ -189,9 +187,8 @@ class FmhaKernel:
|
||||
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
|
||||
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
|
||||
final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
|
||||
# D1.5: barrier for PV completion signal (MMA→softmax warps)
|
||||
# MMA warp arrives after PV[kt] completes; softmax warps wait before O rescale.
|
||||
pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id))
|
||||
# D1.5: pv_done_bar for O rescale (currently unused — TMEM round-trip broken)
|
||||
# pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id))
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
@@ -319,8 +316,7 @@ class FmhaKernel:
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
kvh_v.release()
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
pv_done_bar.arrive() # Signal softmax warps: PV done, O is ready for rescale
|
||||
# pv_done_bar.arrive() # D1.5: unused — TMEM round-trip broken
|
||||
final_o_bar.arrive()
|
||||
else:
|
||||
# Original pipeline path (hd≤256)
|
||||
@@ -338,8 +334,6 @@ class FmhaKernel:
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
sh.commit()
|
||||
softmax_done_bar.arrive_and_wait()
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
cute.arch.fence_view_async_tmem_load() # Ensure rescaled O visible before PV[kt]
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
|
||||
if not self.use_smem_p:
|
||||
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
|
||||
@@ -351,8 +345,7 @@ class FmhaKernel:
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
kvh.release()
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
pv_done_bar.arrive() # Signal softmax warps: PV done, O ready for rescale
|
||||
# pv_done_bar.arrive() # D1.5: unused — TMEM round-trip broken
|
||||
acc_pipe.producer_commit(acc_st); acc_st.advance()
|
||||
final_o_bar.arrive()
|
||||
acc_pipe.producer_tail(acc_st)
|
||||
@@ -409,53 +402,22 @@ class FmhaKernel:
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
# ============================================================
|
||||
# D1.5: O RESCALE ATOMS (CUTLASS correction_rescale pattern)
|
||||
# D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH
|
||||
# =================================================
|
||||
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
|
||||
# even NO-OP round-trip corrupts data (ratio = -11 billion).
|
||||
# Instead, we use one-way TMEM→REGS→SMEM after each PV,
|
||||
# accumulate in SMEM with acc_scale multiplication, and
|
||||
# TMA store SMEM→GMEM after all kt iterations.
|
||||
#
|
||||
# For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store
|
||||
# path works perfectly (cos=0.999998). The SMEM accumulator
|
||||
# is only needed for n_kv_tiles > 1.
|
||||
# ============================================================
|
||||
# Pattern: both load and store atoms built from the SAME tOtO_i
|
||||
# (composition-tiled from tOtO0), same Repetition(corr_tile_size).
|
||||
# This is the exact pattern from CUTLASS reference fmha.py line 2123.
|
||||
# The key insight: using composition() to re-tile tOtO into (128, corr_tile_size)
|
||||
# sub-tiles, and building BOTH copies from the SAME tensor, ensures the
|
||||
# column mappings agree on round-trip.
|
||||
# ============================================================
|
||||
corr_tile_size = 16 # Must be power of 2, divides head_dim
|
||||
# Try both composition and raw layout
|
||||
use_comp = True
|
||||
if const_expr(use_comp):
|
||||
tOtO_i_layout = cute.composition(
|
||||
tOtO0.layout, cute.make_layout((128, corr_tile_size))
|
||||
)
|
||||
else:
|
||||
tOtO_i_layout = tOtO0.layout
|
||||
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
|
||||
|
||||
# Coordinate tensor for O (needed for partition_D of load)
|
||||
cO = cute.make_identity_tensor((128, self.head_dim))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
if const_expr(use_comp):
|
||||
tOcO_i_layout = cute.composition(
|
||||
tOcO.layout, cute.make_layout((128, corr_tile_size))
|
||||
)
|
||||
else:
|
||||
tOcO_i_layout = tOcO.layout
|
||||
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
|
||||
|
||||
tmem_load_o_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
|
||||
self.qk_acc_dtype,
|
||||
)
|
||||
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
|
||||
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
|
||||
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i)
|
||||
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i)
|
||||
|
||||
tmem_store_o_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
|
||||
self.qk_acc_dtype,
|
||||
)
|
||||
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
|
||||
thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
|
||||
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
|
||||
# NOTE: The code below is the BROKEN TMEM round-trip approach.
|
||||
# It's kept as reference but should NOT be used.
|
||||
# The SMEM accumulator implementation is TODO.
|
||||
|
||||
# prev_acc_scale: unused, kept for clarity. acc_scale at kt is used
|
||||
# to rescale O from kt=0..kt-1 before PV[kt].
|
||||
@@ -559,40 +521,12 @@ class FmhaKernel:
|
||||
k2 = k_coord // 64
|
||||
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
# D1.5: O rescale for kt > 0 — CUTLASS correction_rescale pattern.
|
||||
# After computing acc_scale for this iteration, rescale the existing O
|
||||
# in TMEM before the next PV GEMM adds to it.
|
||||
# Must wait for PV[kt-1] to complete (MMA signals pv_done_bar).
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
if kt > 0:
|
||||
pv_done_bar.arrive_and_wait() # Wait for PV[kt-1]
|
||||
# Rescale O: load, multiply by acc_scale, store back to TMEM.
|
||||
# CUTLASS pattern: both copies use same tOtO_i (composition-tiled).
|
||||
rescale_factor = acc_scale
|
||||
if const_expr(self.debug_noop_rescale):
|
||||
rescale_factor = Float32(1.0)
|
||||
n_slices = self.head_dim // corr_tile_size
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, n_slices), self.qk_acc_dtype
|
||||
)
|
||||
for i in range(n_slices):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout
|
||||
)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[k] = tTMrO_i[k] * rescale_factor
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
# D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED.
|
||||
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
|
||||
# even NO-OP round-trip corrupts O accumulator data.
|
||||
# Production path for multi-KV-tile: Python KV merge (cos 0.999998).
|
||||
# Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt).
|
||||
# n_kv_tiles=1 is the only supported path for in-kernel processing.
|
||||
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
151
dsv4/kernels/attention/fmha_smem_acc.py
Normal file
151
dsv4/kernels/attention/fmha_smem_acc.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100).
|
||||
|
||||
SMEM accumulator approach for multi-KV-tile O rescale.
|
||||
Instead of TMEM round-trip (which corrupts data), we move O from TMEM
|
||||
to SMEM after each PV GEMM via one-way epilogue, and accumulate in SMEM.
|
||||
|
||||
This avoids the D1.5 TMEM round-trip bug entirely.
|
||||
|
||||
Architecture:
|
||||
- 6-warp specialization: 4 softmax+epilogue, 1 MMA, 1 TMA
|
||||
- After PV[kt]: one-way TMEM→REGS→SMEM with acc_scale multiplication
|
||||
- SMEM accumulator persists across kt iterations
|
||||
- Final TMA store: SMEM→GMEM
|
||||
|
||||
Per-kt flow:
|
||||
1. Softmax warps: compute P[kt], acc_scale[kt]
|
||||
2. Signal softmax_done_bar
|
||||
3. MMA warp: PV[kt] GEMM (ACCUMULATE=False, fresh TMEM)
|
||||
4. Signal pv_done_bar
|
||||
5. Softmax/epilogue warps: TMEM→REGS, acc_scale*O_acc + O_kt, REGS→SMEM
|
||||
6. Repeat for next kt
|
||||
7. After all kt: SMEM→GMEM via TMA
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
from cutlass.utils.blackwell_helpers import get_smem_store_op
|
||||
from cutlass.utils.gemm.sm100 import (
|
||||
transform_partitioned_tensor_layout,
|
||||
epilogue_tmem_copy_and_partition,
|
||||
epilogue_smem_copy_and_partition,
|
||||
)
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass.torch as ct
|
||||
import math
|
||||
|
||||
|
||||
class FmhaKernel:
|
||||
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True,
|
||||
num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False,
|
||||
n_comp=None, apply_sink_bias=False):
|
||||
self.n_comp = n_comp if n_comp is not None else 0
|
||||
self.apply_sink_bias = apply_sink_bias
|
||||
self.head_dim = head_dim
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
self.pv_n_tile = min(head_dim, 256)
|
||||
if head_dim > 256:
|
||||
self.pv_n_tile = 128
|
||||
self.n_pv_tiles = head_dim // self.pv_n_tile
|
||||
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
|
||||
self.num_query_heads = num_query_heads
|
||||
self.batch_size = batch_size
|
||||
self.normalize = normalize
|
||||
self.apply_swa_mask = apply_swa_mask
|
||||
self.is_causal = is_causal
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
|
||||
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (128, 128, self.k_tile)
|
||||
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2])
|
||||
self.c_layout = LayoutEnum.ROW_MAJOR
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
|
||||
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
self.tmem_s0_offset = 0
|
||||
if not self.use_smem_p:
|
||||
self.tmem_p0_offset = 32
|
||||
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
p_end = self.tmem_p0_offset + p_cols_fp32
|
||||
s_cols = self.qk_mma_tiler[1]
|
||||
o_after = max(s_cols, p_end)
|
||||
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
total = self.tmem_o0_offset + o_cols
|
||||
else:
|
||||
self.tmem_p0_offset = -1
|
||||
self.tmem_o0_offset = 0
|
||||
s_cols = self.qk_mma_tiler[1]
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
total = max(s_cols, o_cols)
|
||||
self.num_tmem_alloc_cols = 1
|
||||
while self.num_tmem_alloc_cols < total:
|
||||
self.num_tmem_alloc_cols *= 2
|
||||
self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2
|
||||
cta = cute.size(qk_mma.thr_id.shape)
|
||||
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
|
||||
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
|
||||
v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
|
||||
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
|
||||
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) +
|
||||
cute.size_in_bytes(self.q_dtype, v_s)) * cta
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None):
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
v_fmha = cute.make_tensor(
|
||||
v.iterator,
|
||||
cute.make_layout(
|
||||
(self.pv_n_tile, self.s_k, 1),
|
||||
stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k),
|
||||
),
|
||||
)
|
||||
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM)
|
||||
pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K
|
||||
pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, tcgen05.CtaGroup.ONE, (128,self.pv_n_tile), pv_source)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
|
||||
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
|
||||
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
|
||||
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
|
||||
epi_s = cute.select(self.c_smem_s,mode=[0,1])
|
||||
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
|
||||
if const_expr(lse is None):
|
||||
lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,)))
|
||||
if const_expr(swa_len is None):
|
||||
swa_len = Int32(2147483647)
|
||||
else:
|
||||
swa_len = Int32(swa_len)
|
||||
if const_expr(sink_bias is None):
|
||||
sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,)))
|
||||
if const_expr(row_sums is None):
|
||||
row_sums = cute.make_tensor(lse.iterator, lse.layout)
|
||||
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
|
||||
# ... rest of kernel to be implemented
|
||||
Reference in New Issue
Block a user