# nvfp4-megamoe-kernel/tests/unit/test_smem_p_write.py
"""
SMEM-P Write Diagnostic: Verify the coordinate-indexed write to sP.
This test:
1. Creates an S matrix (identity for simplicity)
2. Writes S values to sP using the coordinate-indexed approach
3. Reads sP back via the PV MMA's A-operand fragment
4. Verifies the round-trip
If the round-trip is correct, the coordinate extraction and sP indexing are right.
If not, we need to debug the coordinate mapping.
"""
import math

import torch
import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute
import cutlass.pipeline as pipeline
import cutlass.torch as ct
import cutlass.utils as utils
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.cute.nvgpu import tcgen05, cpasync
from cutlass.utils import LayoutEnum
@cute.jit
def smem_p_write_test(s_data, sP_out, qk_mma, pv_mma, qk_mma_tiler, pv_mma_tiler, p_smem_s):
    """Round-trip S through TMEM into the SMEM P buffer and dump sP to ``sP_out``.

    Warps 0-3 ("softmax" warps) gather S from ``s_data``, store it to TMEM and
    load it back with the 32x32b tcgen05 copy atoms (FP32 both ways), convert
    it to BF16 and scatter it into ``sP`` with the coordinate mapping under
    test::

        (m, k) -> sP[(m, k % 16), 0, ((k // 16) % 4, k // 64)]

    Warp 4 ("MMA" warp) then walks ``sP`` in the order the PV MMA's A operand
    sees it, recovers each element's (m, k) from an identity tensor partitioned
    with ``pv_mma``, and writes it to ``sP_out[m, k]``. Comparing ``sP_out``
    with ``s_data`` rounded to BF16 on the host checks the mapping above
    against the MMA's own view of ``sP``.

    Args:
        s_data: S tile read as ``s_data[m, k]``; presumably FP32 and at least
            128x128 (TODO confirm against the launcher).
        sP_out: Output written as ``sP_out[m, k]``; presumably BF16, 128x128.
        qk_mma: Tiled MMA whose C fragment describes the S tile in TMEM.
        pv_mma: Tiled MMA whose A operand is the P tile in SMEM.
        qk_mma_tiler: QK tiler; its first two modes give the S tile shape.
        pv_mma_tiler: PV tiler; modes 0 and 2 give the (M, K) shape of P.
        p_smem_s: Staged SMEM layout for P (``.outer`` layout, ``.inner`` swizzle).

    Must run inside a kernel with at least 5 warps. Warps 5 and above take
    neither role branch, and the named barriers below count exactly warps 0-4.
    """
    tidx, _, _ = cute.arch.thread_idx()
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

    # Shared resources are set up before the role split so both branches see
    # them. Idle warps are excluded by the role `if`s below rather than an
    # early `return`, which CuTe DSL does not allow inside dynamic control flow.
    @cute.struct
    class SS:
        dummy: cute.struct.MemRange[cutlass.Int64, 2]

    smem = utils.SmemAllocator()
    st = smem.allocate(SS)
    sP = smem.allocate_tensor(
        element_type=BFloat16,
        layout=p_smem_s.outer,
        byte_alignment=128,
        swizzle=p_smem_s.inner,
    )
    sP_nostage = sP[(None, None, None, 0)]

    num_softmax_threads = 32 * 4
    num_participating_threads = 32 * 5  # 4 softmax warps + 1 MMA warp
    # Every participating warp arrives here once TMEM has been allocated.
    tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=num_participating_threads)
    # Softmax warps arrive once sP is written and the MMA warp waits on it, so
    # the count covers all 4 + 1 warps.
    sP_ready_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=num_participating_threads)
    # Softmax warps only: keep TMEM allocated until all four have finished with it.
    tmem_done_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=num_softmax_threads)
    tmem = utils.TmemAllocator(
        st.dummy.ptr,
        barrier_for_retrieve=tmem_bar,
        allocator_warp_id=0,
        is_two_cta=False,
    )

    if warp_idx < 4:
        # ---- softmax warps: s_data -> TMEM -> registers -> sP ----
        if warp_idx == 0:
            tmem.allocate(128)  # 128 columns hold one 128x128 FP32 S tile
        tmem.wait_for_alloc()
        tmem_ptr = tmem.retrieve_ptr(Float32)

        # S in TMEM: the QK C-fragment layout, rebased onto the pointer the
        # allocator returned so the copies address the columns actually allocated.
        qk_thr = qk_mma.get_slice(0)
        tStS_fake = qk_thr.make_fragment_C(qk_thr.partition_shape_C(qk_mma_tiler[:2]))
        tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout)
        # (m, k) coordinate of every element of the C fragment.
        tScS = qk_thr.partition_C(cute.make_identity_tensor(qk_mma_tiler[:2]))

        sfw_idx = tidx % num_softmax_threads
        tiled_tmem_load = tcgen05.make_tmem_copy(
            cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32),
            tStS,
        )
        tiled_tmem_store = tcgen05.make_tmem_copy(
            cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32),
            tStS,
        )
        thr_load = tiled_tmem_load.get_slice(sfw_idx)
        thr_store = tiled_tmem_store.get_slice(sfw_idx)
        tTMEM_LOADtS = thr_load.partition_S(tStS)
        tTMEM_LOADcS = thr_load.partition_D(tScS)
        tTMEM_STOREtS = thr_store.partition_D(tStS)
        tTMEM_STOREcS = thr_store.partition_S(tScS)

        # 1) Global -> registers: gather this thread's S elements by coordinate,
        #    so the host controls the test values.
        rS_in = cute.make_rmem_tensor(tTMEM_STOREcS.shape, Float32)
        for i in cutlass.range_constexpr(cute.size(rS_in)):
            rS_in[i] = s_data[tTMEM_STOREcS[i]]

        # 2) Registers -> TMEM -> registers. Both directions move FP32 words,
        #    so the values loaded back are exactly the values stored.
        cute.copy(tiled_tmem_store, rS_in, tTMEM_STOREtS)
        cute.arch.fence_view_async_tmem_store()
        rS_out = cute.make_rmem_tensor(tTMEM_LOADcS.shape, Float32)
        cute.copy(tiled_tmem_load, tTMEM_LOADtS, rS_out)
        cute.arch.fence_view_async_tmem_load()

        # 3) FP32 -> BF16, then scatter into sP with the mapping under test.
        rP = cute.make_rmem_tensor(tTMEM_LOADcS.shape, BFloat16)
        rP.store(rS_out.load().to(BFloat16))
        for i in cutlass.range_constexpr(cute.size(rP)):
            coord = tTMEM_LOADcS[i]
            m_coord = coord[0]
            k_coord = coord[1]
            k0 = k_coord % 16
            k1 = (k_coord // 16) % 4
            k2 = k_coord // 64
            sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP[i]
        # Make the generic-proxy SMEM writes visible to the async proxy, as a
        # tcgen05 MMA reading sP would require.
        cute.arch.fence_proxy("async.shared", space="cta")
        sP_ready_bar.arrive()

        # Release TMEM once, from the allocator warp, after all four softmax
        # warps have finished their TMEM loads.
        tmem_done_bar.arrive_and_wait()
        if warp_idx == 0:
            tmem.relinquish_alloc_permit()
            tmem.free(tmem_ptr)

    if warp_idx == 4:
        # ---- MMA warp: sP -> sP_out through the PV MMA's A-operand view ----
        # This warp is counted in tmem_bar, so it must arrive even though it
        # never touches TMEM.
        tmem.wait_for_alloc()
        sP_ready_bar.arrive_and_wait()

        # (m, k) of each P element in the PV MMA's A partition. Flat-indexing
        # sP_nostage and tCcP in lockstep reads sP through the MMA's view, not
        # through the hand-written mapping used for the writes.
        # NOTE(review): assumes both have the same size and mode order, i.e.
        # ((MMA_M, MMA_K), REST_M, REST_K) for a single-CTA MMA — confirm.
        pv_thr = pv_mma.get_slice(0)
        tCcP = pv_thr.partition_A(cute.make_identity_tensor((pv_mma_tiler[0], pv_mma_tiler[2])))
        lane = cute.arch.lane_idx()
        for i in cutlass.range(lane, cute.size(tCcP), 32):
            sP_out[tCcP[i]] = sP_nostage[i]
def main():
    """Host entry point for the SMEM-P write diagnostic (not implemented yet).

    A host harness still needs to do the following:
    - construct the QK/PV tiled MMAs, their tilers and the P SMEM layout for
      the intended problem size (head_dim = 256, s_k = 128);
    - launch ``smem_p_write_test``;
    - check the round trip.

    A simpler fallback is to write sP from host code and read it in the kernel.

    Returns:
        None. The function is currently a no-op, so running this script
        verifies nothing.
    """
    # TODO: implement the host harness described in the docstring.
    return None
if __name__ == '__main__':
    # Run the diagnostic when executed as a script; main() returns None,
    # so the process exits with status 0.
    raise SystemExit(main())