D1.3: SMEM-P via get_smem_store_op + make_tiled_copy_D

Uses the CUTLASS blackwell_helpers pattern:
- get_smem_store_op creates a SMEM store atom paired with the TMEM load
- make_tiled_copy_D uses the same thread partition as the TMEM load
- Softmax warps write P to sP using the same thread mapping they use for reading S
- MMA warp reads P from sP via pv_mma.make_fragment_A(sP)
- Replaces the zero-fill stub with a proper register→SMEM copy
This commit is contained in:
2026-05-23 22:26:09 +00:00
parent 7771e5a72b
commit 6aa519d5ec

View File

@@ -9,6 +9,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
from cutlass.utils.blackwell_helpers import get_smem_store_op
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
@@ -265,10 +266,30 @@ class FmhaKernel:
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# P SMEM copy atoms: SMEM-P (TBD)
# make_tiled_copy_C gives rank mismatch (QK C-fragment has 4 modes,
# PV A-operand SMEM has 3 modes). Need proper layout-aware copy.
# For now, SMEM-P path zero-fills sP. TMEM-P (hd<=64) works correctly.
# P SMEM copy atoms: SMEM-P
# Uses get_smem_store_op + make_tiled_copy_D from CUTLASS blackwell_helpers.
# This creates a SMEM store copy using the same thread partition as the TMEM load,
# so the same threads that compute P (softmax warps) can write P to sP directly.
# The MMA warp then reads P from sP via pv_mma.make_fragment_A(sP).
if self.use_smem_p:
_p_smem_store_atom = get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_tmem_load
)
_tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load)
_thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx)
# Partition sP for the SMEM store (destination)
_sP_2d = cute.group_modes(sP, 0, 3)
_tSMEM_STOREsP = _thr_smem_store_p.partition_D(_sP_2d)
# Create a source register tensor for the SMEM store
# The source layout comes from the TMEM load's source partition (tTMEM_LOADtS)
# But we need a BF16 view of P, not the FP32 S view.
# The softmax writes P values (BF16) via the register bridge rP_bf16,
# which shares layout with rS (tTMEM_LOADrS layout).
# make_tiled_copy_D uses the same thread partition, so the source partition
# should match the TMEM load source partition.
_tSMEM_STORErS = _thr_smem_store_p.partition_S(tTMEM_LOADtS)
# But rS is FP32. We need BF16. Create BF16 view of rP with same layout.
_rP_smem = cute.make_rmem_tensor(_tSMEM_STORErS.shape, self.q_dtype)
row_max = -Float32.inf
row_sum = Float32(0.0)
@@ -341,12 +362,13 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: zero-fill sP stub (proper layout-aware copy TBD)
# make_tiled_copy_C gives rank mismatch (4 vs 3).
# Need proper P register→SMEM copy that respects QK C-fragment layout
# and PV A-operand SMEM layout. For now, TMEM-P (hd<=64) works.
for j in cutlass.range(cute.size(sP), vectorize=True):
sP[j] = self.q_dtype(0)
# SMEM-P: store P to SMEM via make_tiled_copy_D
# The P values are in rP_bf16 (BF16 view of the FP32 register bridge).
# Copy the BF16 P values to the SMEM store source registers.
for j in cutlass.range(cute.size(rP_bf16), vectorize=True):
_rP_smem[j] = rP_bf16[j]
# Write P to sP (PV A-operand SMEM layout)
cute.copy(_tiled_smem_store_p, _rP_smem, _tSMEM_STOREsP)
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0:
tTMrO = cute.make_rmem_tensor(