"""
SMEM-P Write Diagnostic: verify the coordinate-indexed write to sP.

This test:
1. Creates an S matrix with a coordinate-derived value per element
   (S[m, k] = m + 128*k)
2. Writes the S values to sP using the coordinate-indexed approach
3. Reads sP back from the MMA warp (the warp that would consume sP as the
   PV MMA's A operand)
4. Verifies the round-trip

If the round-trip is correct, the coordinate extraction and sP indexing are right.
If not, we need to debug the coordinate mapping.

NOTE: the host-side driver (``main``) is not implemented yet, so step 4 is not
performed by this file on its own.
"""

import math

import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.pipeline as pipeline
import cutlass.torch as ct
import cutlass.utils as utils
import torch
from cutlass import BFloat16, Float32, Int32, const_expr
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.utils import LayoutEnum


@cute.jit
def smem_p_write_test(s_data, sP_out, qk_mma, pv_mma, qk_mma_tiler, pv_mma_tiler, p_smem_s):
    """Round-trip a known S tile through TMEM and sP, then dump sP to ``sP_out``.

    Warps 0-3 ("softmax" warps) fill TMEM with ``S[m, k] = m + 128*k`` as FP32,
    load it back, convert it to BF16 and scatter it into the shared-memory P
    buffer (sP) using coordinates taken from an identity tensor.  Warp 4 (the
    "MMA" warp) waits for that write and copies sP element by element into
    tile (0, 0) of ``sP_out`` so the host can compare it with the expected
    values.

    Args:
        s_data: Currently unused; the test pattern is generated in-kernel.
        sP_out: 2-D tensor whose (128, 128) tile at (0, 0) receives the sP
            contents.  Assumes the element type accepts BF16 values and the
            memory is writable from the device -- TODO confirm against caller.
        qk_mma: Tiled MMA whose C partitioning defines the S fragment layout.
        pv_mma: Currently unused (sP is read back through its own layout).
        qk_mma_tiler: MMA tiler; its first two modes give the S tile shape.
        pv_mma_tiler: Currently unused.
        p_smem_s: Staged sP layout object exposing ``.outer`` (layout) and
            ``.inner`` (swizzle).

    Fixes relative to the previous revision:
        * ``sync_bar`` is created before the warp-specialised branches and is
          sized for every participant (4 arriving warps + 1 waiting warp); it
          used to be declared for 64 threads inside the softmax branch only.
        * TMEM is filled with the FP32 test pattern.  Previously a register
          buffer holding packed BF16 (and partly uninitialised) was stored and
          then re-read as FP32, so sP never received the intended values.
        * The S fragment is based on the pointer returned by the TMEM
          allocator instead of the placeholder iterator of the fragment.
        * The MMA warp writes sP to ``sP_out``; it used to discard what it read.
        * No early ``return`` inside a runtime branch; each section is guarded
          instead.  NOTE(review): the CuTe DSL control-flow docs list early
          exits from dynamic control flow as unsupported -- confirm for the
          DSL version in use.
    """
    num_softmax_warps = 4
    mma_warp_id = 4
    num_test_warps = num_softmax_warps + 1
    threads_per_warp = 32

    tidx, _, _ = cute.arch.thread_idx()
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

    # Shared storage: a small scratch range that receives the TMEM base address.
    # NOTE(review): CUTLASS examples use an Int32 field for this holding
    # buffer -- confirm that an Int64 MemRange pointer is accepted.
    @cute.struct
    class SS:
        dummy: cute.struct.MemRange[cutlass.Int64, 2]

    smem = utils.SmemAllocator()
    st = smem.allocate(SS)

    # Both barriers involve all five test warps.
    tmem_bar = pipeline.NamedBarrier(
        barrier_id=2, num_threads=threads_per_warp * num_test_warps
    )
    # sP hand-off: 4 softmax warps arrive(), the MMA warp arrive_and_wait()s.
    # The thread count must cover every participant, otherwise the waiter can
    # be released before all softmax warps have finished writing sP.
    sync_bar = pipeline.NamedBarrier(
        barrier_id=3, num_threads=threads_per_warp * num_test_warps
    )

    # sP in shared memory; stage 0 is the only stage used here.
    sP = smem.allocate_tensor(
        element_type=BFloat16,
        layout=p_smem_s.outer,
        byte_alignment=128,
        swizzle=p_smem_s.inner,
    )
    sP_nostage = sP[(None, None, None, 0)]

    # TMEM for S: 128 columns, enough for a 128x128 FP32 tile.
    tmem = utils.TmemAllocator(
        st.dummy.ptr,
        barrier_for_retrieve=tmem_bar,
        allocator_warp_id=0,
        is_two_cta=False,
    )
    if warp_idx < num_softmax_warps:
        tmem.allocate(128)
    if warp_idx < num_test_warps:
        tmem.wait_for_alloc()
    # Warps outside the test also execute this read but never dereference it.
    tmem_ptr = tmem.retrieve_ptr(Float32)

    # QK C-fragment (S in TMEM), rebased onto the allocated TMEM address.
    qk_thr = qk_mma.get_slice(0)
    s_acc_shape = qk_thr.partition_shape_C(qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(s_acc_shape)
    tStS0 = cute.make_tensor(tmem_ptr, tStS.layout)

    # TMEM -> register copy, partitioned over the 128 softmax threads.
    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
    )
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
    sfw_idx = tidx % (threads_per_warp * num_softmax_warps)
    thr_load = tiled_tmem_load.get_slice(sfw_idx)
    tTMEM_LOADtS = thr_load.partition_S(tStS0)

    # Identity tensor: maps each register slot of this thread to its (m, k).
    cS = cute.make_identity_tensor((128, 128))
    tScS = qk_thr.partition_C(cS)
    tTMEM_LOADcS = thr_load.partition_D(tScS)

    # ---- Softmax warps: TMEM <- pattern, TMEM -> registers -> sP ----
    if warp_idx < num_softmax_warps:
        tmem_store_atom = cute.make_copy_atom(
            tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32
        )
        tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS0)
        thr_store = tiled_tmem_store.get_slice(sfw_idx)
        tTMEM_STOREtS = thr_store.partition_D(tStS0)

        tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, Float32)
        rP_bf16 = cute.make_rmem_tensor(tTMEM_LOADcS.shape, BFloat16)

        frg_cnt = 4
        frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
        tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
        rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))

        # Test pattern S[m, k] = m + 128*k, built from this thread's coordinates.
        # NOTE(review): values reach 16383, but BF16 keeps only 8 significant
        # bits, so the host check must compare against the BF16-rounded
        # expectation (rounded values are not unique per position).
        for j in range(frg_cnt):
            for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
                coord = tTMEM_LOADcS[(k, 0), j, 0, 0]
                tTMEM_LOADrS_frg[k, j] = Float32(1) * coord[0] + Float32(128) * coord[1]

        # Registers -> TMEM (FP32), then wait for the store to complete.
        cute.copy(tiled_tmem_store, tTMEM_LOADrS, tTMEM_STOREtS)
        cute.arch.fence_view_async_tmem_store()

        # TMEM -> registers: this is the S read path under test.
        cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
        cute.arch.fence_view_async_tmem_load()

        # FP32 S -> BF16 P, one fragment at a time.
        for j in range(frg_cnt):
            s_vec = tTMEM_LOADrS_frg[None, j].load()
            rP_bf16_frg[None, j].store(s_vec.to(BFloat16))

        # Coordinate-indexed scatter into sP.  The K coordinate is split as
        # k = k0 + 16*k1 + 64*k2, which assumes sP_nostage has the shape
        # ((128, 16), 1, (4, 2)) -- TODO confirm against p_smem_s.
        for j0 in range(frg_tile):
            for j1 in range(frg_cnt):
                coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
                m_coord = coord[0]
                k_coord = coord[1]
                k0 = k_coord % 16
                k1 = (k_coord // 16) % 4
                k2 = k_coord // 64
                sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]

        cute.arch.fence_proxy("async.shared", space="cta")

        # Tell the MMA warp that sP is ready.
        sync_bar.arrive()

    # ---- MMA warp: dump sP to global memory for host-side checking ----
    if warp_idx == mma_warp_id:
        sync_bar.arrive_and_wait()

        gOut = cute.local_tile(sP_out, (128, 128), (0, 0))
        lane_idx = tidx % threads_per_warp
        # Each lane copies rows lane, lane + 32, lane + 64 and lane + 96.
        for r in range(128 // threads_per_warp):
            m = lane_idx + threads_per_warp * r
            for k2 in range(2):
                for k1 in range(4):
                    for k0 in range(16):
                        # Expected (before BF16 rounding): m + 128*k with
                        # k = k0 + 16*k1 + 64*k2.
                        gOut[m, k0 + 16 * k1 + 64 * k2] = sP_nostage[(m, k0), 0, (k1, k2)]

    # ---- TMEM teardown (the allocator is constructed with warp 0) ----
    if warp_idx < num_softmax_warps:
        tmem.relinquish_alloc_permit()
        tmem.free(tmem_ptr)


def main():
    """Host entry point for the sP write diagnostic (not implemented yet).

    Currently a no-op that returns ``None``: no tensors are created, no kernel
    is launched, and nothing is verified.

    Notes carried over from the original plan:
        * Driving the full kernel was judged too complex for a first test.
        * First idea: check that the (j0, j1) -> (m, k) coordinate mapping is
          correct by printing coordinates from the identity tensor.  The
          author noted that printing from inside the kernel was not practical.
          NOTE(review): ``cute.printf`` may be worth evaluating for this.
        * Second idea: write known values to sP with the coordinate approach,
          read them back, and check correctness.
        * Simplest variant considered: write sP contents from host code and
          only read them in the kernel.
    """
    # TODO: build the inputs, launch the diagnostic kernel, and compare its
    # output with the expected pattern.
    return None


# Run the diagnostic only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()