D1.3: SMEM-P via get_smem_store_op + make_tiled_copy_D
Uses the CUTLASS blackwell_helpers pattern:
- get_smem_store_op creates a SMEM store atom paired with the TMEM load
- make_tiled_copy_D uses the same thread partition as the TMEM load
- Softmax warps write P to sP using the same thread mapping they use for reading S
- MMA warp reads P from sP via pv_mma.make_fragment_A(sP)
- Replaces the zero-fill stub with a proper register→SMEM copy
This commit is contained in:
@@ -9,6 +9,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
from cutlass.utils.blackwell_helpers import get_smem_store_op
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass.torch as ct
|
||||
import math
|
||||
@@ -265,10 +266,30 @@ class FmhaKernel:
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# P SMEM copy atoms: SMEM-P (TBD)
|
||||
# make_tiled_copy_C gives rank mismatch (QK C-fragment has 4 modes,
|
||||
# PV A-operand SMEM has 3 modes). Need proper layout-aware copy.
|
||||
# For now, SMEM-P path zero-fills sP. TMEM-P (hd<=64) works correctly.
|
||||
# P SMEM copy atoms: SMEM-P
|
||||
# Uses get_smem_store_op + make_tiled_copy_D from CUTLASS blackwell_helpers.
|
||||
# This creates a SMEM store copy using the same thread partition as the TMEM load,
|
||||
# so the same threads that compute P (softmax warps) can write P to sP directly.
|
||||
# The MMA warp then reads P from sP via pv_mma.make_fragment_A(sP).
|
||||
if self.use_smem_p:
|
||||
_p_smem_store_atom = get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_tmem_load
|
||||
)
|
||||
_tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load)
|
||||
_thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx)
|
||||
# Partition sP for the SMEM store (destination)
|
||||
_sP_2d = cute.group_modes(sP, 0, 3)
|
||||
_tSMEM_STOREsP = _thr_smem_store_p.partition_D(_sP_2d)
|
||||
# Create a source register tensor for the SMEM store
|
||||
# The source layout comes from the TMEM load's source partition (tTMEM_LOADtS)
|
||||
# But we need a BF16 view of P, not the FP32 S view.
|
||||
# The softmax writes P values (BF16) via the register bridge rP_bf16,
|
||||
# which shares layout with rS (tTMEM_LOADrS layout).
|
||||
# make_tiled_copy_D uses the same thread partition, so the source partition
|
||||
# should match the TMEM load source partition.
|
||||
_tSMEM_STORErS = _thr_smem_store_p.partition_S(tTMEM_LOADtS)
|
||||
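# rS is FP32, but the SMEM store source must hold q_dtype (BF16) values, so allocate a separate register tensor (not a view) with the same shape as the store's source partition.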
|
||||
_rP_smem = cute.make_rmem_tensor(_tSMEM_STORErS.shape, self.q_dtype)
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
@@ -341,12 +362,13 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: zero-fill sP stub (proper layout-aware copy TBD)
|
||||
# make_tiled_copy_C gives rank mismatch (4 vs 3).
|
||||
# Need proper P register→SMEM copy that respects QK C-fragment layout
|
||||
# and PV A-operand SMEM layout. For now, TMEM-P (hd<=64) works.
|
||||
for j in cutlass.range(cute.size(sP), vectorize=True):
|
||||
sP[j] = self.q_dtype(0)
|
||||
# SMEM-P: store P to SMEM via make_tiled_copy_D
|
||||
# The P values are in rP_bf16 (BF16 view of the FP32 register bridge).
|
||||
# Copy the BF16 P values to the SMEM store source registers.
|
||||
for j in cutlass.range(cute.size(rP_bf16), vectorize=True):
|
||||
_rP_smem[j] = rP_bf16[j]
|
||||
# Write P to sP (PV A-operand SMEM layout)
|
||||
cute.copy(_tiled_smem_store_p, _rP_smem, _tSMEM_STOREsP)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
if kt > 0:
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
|
||||
Reference in New Issue
Block a user