test: SMEM-P coordinate verification test

This commit is contained in:
2026-05-24 01:58:32 +00:00
parent 394f08601a
commit f2d95da4aa
3 changed files with 399 additions and 19 deletions

View File

@@ -282,28 +282,21 @@ class FmhaKernel:
tTMEM_STOREcP = thr_store.partition_S(tScP)
# P SMEM copy atoms: SMEM-P
# Per CUTLASS LLM guidance: use make_cotiled_copy with TV layout
# from TMEM load partition, remapped to sP's codomain.
# atom_layout_tv: (tid, vid) -> sP address
# data_layout: sP coord -> sP address (includes swizzle)
# Strategy: Use make_cotiled_copy with atom_layout_tv built from
# the TMEM-load coordinate partition + sP address mapping.
#
# Build the TV layout from the TMEM load, remapped to sP's codomain.
# The TMEM load's TV layout maps (tid, vid) -> tStS_addr.
# tStS layout: ((128,128),1,1):((65536,1),0,0) => addr = m*65536 + k
# sP_stage layout: ((128,16),1,(4,2)):((64,1),0,(16,8192)) + swizzle S<3,4,3>
# The TMEM-load partition gives each thread (m, k) coordinates via tTMEM_LOADcS.
# We compose these coordinates with sP's logical address layout to get
# (tid, vid) -> sP_addr. Then make_cotiled_copy creates a proper TiledCopy.
#
# We need: (tid, vid) -> sP_addr.
# Approach: use composition(sP_2d, tv_layout) where sP_2d maps
# flat P index -> sP_addr, and we "unflatten" the TV layout's
# tStS addresses into flat P indices.
# Key: sP's outer layout maps (m, k0, k1, k2) -> sP_addr with strides (64, 1, 16, 8192).
# We need to build atom_layout_tv in sP's flat address space, not tStS's.
#
# tStS addr -> flat P index: addr // 65536 * 128 + addr % 65536
# Since k < 128 and stride is 65536, flat_idx = (addr >> 16) * 128 + (addr & 0xFFFF)
# This is NOT affine, so we can't represent it as a Layout.
#
# FALLBACK: Use the coordinate-indexed approach (scalar SMEM writes).
# This works but gives ~0.04 cosine loss vs TMEM-P at hd=64.
# The make_cotiled_copy approach is tracked for future optimization.
# Step 1: Build sP address mapping in the same coordinate system as tStS.
# sP is indexed as ((m, k%16), 0, ((k//16)%4, k//64)) with strides ((64,1),0,(16,8192)).
# In the P matrix's (m, k) coordinate space:
# sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
# This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192))
_sP_nostage = sP[(None, None, None, 0)] # remove stage dim
row_max = -Float32.inf
@@ -379,6 +372,8 @@ class FmhaKernel:
else:
# SMEM-P: write P to sP using coordinate-indexed store.
# Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates.
# DEBUG NOTE: to verify the coordinate mapping, a known pattern such as
# sP[m, k] = (m + k) % 256 (as BF16) can be written here; the store below still writes the real P value.
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
@@ -387,6 +382,7 @@ class FmhaKernel:
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
# Stores the actual P value; substitute (m + k) mod 256 here when debugging the coordinate mapping.
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0:

View File

@@ -0,0 +1,185 @@
"""
SMEM-P Coordinate Verification Test.
Writes a known pattern to sP using the coordinate-indexed approach
(identical to FmhaKernel's SMEM-P path), then reads sP back
via a simple SMEM→GMEM copy and verifies on the host.
Pattern: sP[m, k] = BF16(m*128 + k). BF16 keeps 8 significant bits, so values above 256 are rounded and are not unique per position.
Expected: output[m, k] = (m*128 + k) rounded to BF16 after round-trip
If coordinates are correct, all values match.
If coordinates are wrong, values will be at different positions.
"""
import torch, math
import cutlass, cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
import cutlass.torch as ct
import cuda.bindings.driver as cuda
import cutlass.pipeline as pipeline
@cute.jit
def smem_p_coord_test(
    mOut, qk_mma, pv_mma, qk_mma_tiler, pv_mma_tiler, p_smem_s
):
    """Write known pattern to sP using coordinate-indexed approach, read back.

    Warps 0-3 take the (m, k) coordinates from the TMEM-load partition of a
    128x128 identity tensor and store ``BFloat16(m * 128 + k)`` into sP at
    ``[(m, k % 16), 0, ((k // 16) % 4, k // 64)]``. After a barrier, thread 0
    reads every sP element back with the same indexing and writes it to
    ``mOut[m, k]``.

    Args:
        mOut: Output tensor, written as ``mOut[m, k]`` for m, k in [0, 128).
        qk_mma: Tiled MMA whose C partition provides the S fragment layout.
        pv_mma: Not referenced in this function.
        qk_mma_tiler: Tiler; only its first two modes are used.
        pv_mma_tiler: Not referenced in this function.
        p_smem_s: sP shared-memory layout; ``.outer`` is passed as the layout
            and ``.inner`` as the swizzle of the allocation.
    """
    tidx, _, _ = cute.arch.thread_idx()
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
    # 5 warps: 4 softmax + 1 for TMEM alloc
    if warp_idx >= 5:
        return
    # SMEM allocation
    smem = utils.SmemAllocator()
    tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * 5)
    sP = smem.allocate_tensor(element_type=BFloat16, layout=p_smem_s.outer, byte_alignment=128, swizzle=p_smem_s.inner)
    # Drop the stage mode by selecting stage 0.
    sP_nostage = sP[(None, None, None, 0)]
    # TMEM allocation
    # NOTE(review): the first argument is None here, whereas the sibling write
    # diagnostic passes an SMEM pointer — confirm None is accepted.
    tmem = utils.TmemAllocator(None, barrier_for_retrieve=tmem_bar, allocator_warp_id=4, is_two_cta=False)
    if warp_idx == 4:
        tmem.allocate(128)
    tmem.wait_for_alloc()
    tmem_ptr = tmem.retrieve_ptr(Float32)
    # QK C-fragment (for TMEM layout)
    qk_thr = qk_mma.get_slice(0)
    qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_as)
    tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
    # TMEM-load copy (same as FmhaKernel)
    tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
    # Softmax warps: write known pattern to sP
    if warp_idx < 4:
        sfw_idx = tidx % 128  # 4 softmax warps
        thr_load = tiled_tmem_load.get_slice(sfw_idx)
        # Coordinate identity tensor
        cS = cute.make_identity_tensor((128, 128))
        tScS = qk_thr.partition_C(cS)
        tTMEM_LOADcS = thr_load.partition_D(tScS)
        # Write known pattern to sP using coordinate-indexed approach
        for j0 in range(32):
            for j1 in range(4):
                coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
                m_coord = coord[0]
                k_coord = coord[1]
                # Split k into the three sub-indices used to address sP.
                k0 = k_coord % 16
                k1 = (k_coord // 16) % 4
                k2 = k_coord // 64
                # Pattern: value = (m * 128 + k) as BF16
                # (BF16 rounds integers above 256, so the stored value is
                # not always the exact integer).
                val = BFloat16(m_coord * 128 + k_coord)
                sP_nostage[(m_coord, k0), 0, (k1, k2)] = val
        cute.arch.fence_proxy("async.shared", space="cta")
    # Barrier between softmax writes and read-back
    sync_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 * 5)
    sync_bar.arrive_and_wait()
    # Read sP back to global memory
    # Thread 0 of warp 0 does the read (simple sequential access)
    if tidx == 0:
        gOut = mOut  # (128, 128) output
        for m in range(128):
            for k in range(128):
                # Same k decomposition as the write side above.
                k0 = k % 16
                k1 = (k // 16) % 4
                k2 = k // 64
                val = sP_nostage[(m, k0), 0, (k1, k2)]
                gOut[m, k] = val
    # NOTE(review): the allocator warp is 4 (allocator_warp_id=4) but this
    # guard selects warps 0-3 — confirm which warps must release TMEM.
    if warp_idx < 4:
        tmem.relinquish_alloc_permit()
        tmem.free(tmem_ptr)
def main():
    """Compile and run the sP coordinate test, then verify the read-back.

    Builds the QK and PV tiled MMAs and the sP shared-memory layout for
    head_dim=256, runs ``smem_p_coord_test`` on a 128x128 BF16 output tensor,
    and prints how many positions hold the expected pattern together with a
    few sample values.

    The kernel stores ``BFloat16(m * 128 + k)``. BF16 has an 8-bit
    significand, so integers above 256 are rounded; the host check therefore
    compares against the BF16-rounded expectation rather than the exact
    integer. Because of that rounding the pattern is not unique per position
    at large values, so swaps between positions that round to the same BF16
    value cannot be detected by this check.
    """
    head_dim = 256
    s_k = 128
    m = 128
    pv_n_tile = min(head_dim, 256)

    # Create tensors for layout derivation; they are only used to obtain the
    # operand major modes below.
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda')
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    v_tile = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
    a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
    b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
    v_major = LayoutEnum.from_tensor(mV).mma_major_mode()

    qk_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, b_major, Float32,
        tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
    )
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, v_major, Float32,
        tcgen05.CtaGroup.ONE, (128, pv_n_tile), tcgen05.OperandSource.SMEM
    )
    qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
    qk_mma_tiler = (128, 128, qk_ik * 4)
    pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
    pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
    p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)

    # Output tensor
    out = torch.zeros(128, 128, dtype=torch.bfloat16, device='cuda')
    mOut = ct.from_dlpack(out).mark_layout_dynamic(leading_dim=ct.get_leading_dim(out))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    print("Compiling...")
    compiled = cute.compile(smem_p_coord_test, mOut, qk_mma, pv_mma, qk_mma_tiler, pv_mma_tiler, p_smem_s)
    print("Running...")
    # NOTE(review): `stream` is passed as an extra trailing argument that
    # smem_p_coord_test does not declare — confirm the launch convention.
    compiled(mOut, qk_mma, pv_mma, qk_mma_tiler, pv_mma_tiler, p_smem_s, stream)
    torch.cuda.synchronize()

    # Verify: out[m, k] should be (m*128 + k) rounded to BF16. Rows index m
    # and columns index k, matching the kernel's gOut[m, k] writes.
    rows = torch.arange(128, dtype=torch.float32, device='cuda').unsqueeze(1)
    cols = torch.arange(128, dtype=torch.float32, device='cuda').unsqueeze(0)
    expected = (rows * 128 + cols).to(torch.bfloat16).float().cpu()
    # One device-to-host transfer instead of one sync per element.
    got = out.float().cpu()
    matches = got == expected

    print("\n=== Verification ===")
    n_total = matches.numel()
    n_correct = int(matches.sum().item())
    # Report the first few mismatching positions, if any.
    for row, col in (~matches).nonzero()[:5].tolist():
        print(f" MISMATCH at ({row},{col}): expected {expected[row, col].item()}, got {got[row, col].item()}")
    accuracy = n_correct / n_total * 100
    print(f" Accuracy: {n_correct}/{n_total} = {accuracy:.1f}%")

    # Print a few sample values
    print("\n Sample values:")
    for row in [0, 1, 2, 64, 127]:
        for col in [0, 1, 16, 32, 64, 127]:
            exp = expected[row, col].item()
            val = got[row, col].item()
            status = "OK" if val == exp else "MISMATCH"
            print(f" [{row},{col}] expected={exp:.0f} got={val:.1f} {status}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,199 @@
"""
SMEM-P Write Diagnostic: Verify the coordinate-indexed write to sP.
This test:
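1. Fills S in TMEM with a known pattern (S[m, k] = m + 128*k)
2. Writes S values to sP using the coordinate-indexed approach
3. Reads sP back in the MMA warp (the PV MMA A-operand fragment read is still a stub)
4. Verifies the round-trip (not implemented yet: nothing is copied to the output and main() is a placeholder)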
If the round-trip is correct, the coordinate extraction and sP indexing are right.
If not, we need to debug the coordinate mapping.
"""
import torch, math
import cutlass, cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05, cpasync
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
import cutlass.torch as ct
import cuda.bindings.driver as cuda
import cutlass.pipeline as pipeline
@cute.jit
def smem_p_write_test(s_data, sP_out, qk_mma, pv_mma, qk_mma_tiler, pv_mma_tiler, p_smem_s):
    """Write S values to sP and read them back.

    Work-in-progress diagnostic. Warps 0-3 fill TMEM with the pattern
    S[m, k] = m + 128*k, load it back, convert it to BF16 and store it into
    sP with coordinate-indexed writes. Warp 4 then reads sP element by
    element, but the values read are never written anywhere, so nothing is
    exported through ``sP_out``.

    Args:
        s_data: Input tensor; only tiled and partitioned here, never copied.
        sP_out: Output tensor; only tiled and partitioned here, never written.
        qk_mma: Tiled MMA whose C partition provides the S fragment layout.
        pv_mma: Tiled MMA used to build the A-operand fragment over sP.
        qk_mma_tiler: Tiler; only its first two modes are used.
        pv_mma_tiler: Not referenced in this function.
        p_smem_s: sP shared-memory layout; ``.outer`` is passed as the layout
            and ``.inner`` as the swizzle of the allocation.
    """
    tidx, _, _ = cute.arch.thread_idx()
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
    # Only use 4 softmax warps + 1 MMA warp for this test
    if warp_idx >= 5:
        return
    # SMEM allocation
    @cute.struct
    class SS:
        dummy: cute.struct.MemRange[cutlass.Int64, 2]
    smem = utils.SmemAllocator()
    st = smem.allocate(SS)
    tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * 5)  # 5 warps
    # Allocate sP in SMEM
    sP = smem.allocate_tensor(element_type=BFloat16, layout=p_smem_s.outer, byte_alignment=128, swizzle=p_smem_s.inner)
    # Drop the stage mode by selecting stage 0.
    sP_nostage = sP[(None, None, None, 0)]
    # Allocate TMEM for S
    tmem = utils.TmemAllocator(st.dummy.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=0, is_two_cta=False)
    if warp_idx < 4:
        tmem.allocate(128)  # enough for 128×128 FP32 S
    tmem.wait_for_alloc()
    tmem_ptr = tmem.retrieve_ptr(Float32)
    # QK C-fragment (S in TMEM)
    qk_thr = qk_mma.get_slice(0)
    qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_as)
    tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
    # TMEM-load copy
    tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
    sfw_idx = tidx % 128  # 4 softmax warps
    thr_load = tiled_tmem_load.get_slice(sfw_idx)
    tTMEM_LOADtS = thr_load.partition_S(tStS0)
    # Coordinate identity tensor
    cS = cute.make_identity_tensor((128, 128))
    tScS = qk_thr.partition_C(cS)
    tTMEM_LOADcS = thr_load.partition_D(tScS)
    # Softmax warps: read S from TMEM, write to sP
    if warp_idx < 4:
        # First: copy input data to TMEM (S data)
        # Use the TMEM store to write s_data to TMEM
        tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32)
        tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS0)
        thr_store = tiled_tmem_store.get_slice(sfw_idx)
        tTMEM_STOREtS = thr_store.partition_D(tStS0)
        # Copy s_data to register, then to TMEM
        tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, Float32)
        # Load from global → register → TMEM
        # (gS and tCgS are not used again below.)
        gS = cute.local_tile(s_data, (128, 128), (0, 0))
        tCgS = qk_thr.partition_C(gS)
        # Simple: load from global to register using universal copy
        # NOTE(review): tTMEM_LOADrS has not been written at this point, so
        # this first store looks like a leftover from an earlier attempt —
        # confirm it can be dropped.
        cute.copy(tiled_tmem_store, tTMEM_LOADrS, tTMEM_STOREtS)
        cute.arch.fence_view_async_tmem_store()
        # Actually, we need to first put the data into TMEM.
        # The simplest approach: just write identity values to TMEM directly.
        # For testing, write the coordinate (m + 128*k) as the S value.
        # Then verify sP has the same values.
        # Fill TMEM with test values: S[m, k] = m + 128*k (as BF16)
        # Use the register bridge pattern
        rS_words = cute.make_rmem_tensor(tTMEM_STOREtS.shape, Float32)
        # BF16 view over the same registers as rS_words.
        rS_bf16 = cute.make_tensor(cute.recast_ptr(rS_words.iterator, dtype=BFloat16), tTMEM_LOADrS.layout)
        frg_cnt = 4
        frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
        tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
        rS_bf16_frg = cute.logical_divide(rS_bf16, cute.make_layout(frg_tile))
        for j in range(frg_cnt):
            for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
                # Get (m, k) coordinate
                coord = tTMEM_LOADcS[(k, 0), j, 0, 0]
                m_val = coord[0]
                k_val = coord[1]
                # S = m + 128*k (unique value per position)
                val = Float32(1) * m_val + Float32(128) * k_val
                tTMEM_LOADrS_frg[k, j] = val
            # Convert the fragment that was just filled to BF16.
            s_vec = tTMEM_LOADrS_frg[None, j].load()
            rS_bf16_frg[None, j].store(s_vec.to(BFloat16))
        # Store to TMEM
        cute.copy(tiled_tmem_store, rS_words, tTMEM_STOREtS)
        cute.arch.fence_view_async_tmem_store()
        # Now read S from TMEM and write to sP
        cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
        cute.arch.fence_view_async_tmem_load()
        # Write to sP using coordinate-indexed store
        # rP_bf16_test is another BF16 view over the rS_words registers.
        rP_bf16_test = cute.make_tensor(cute.recast_ptr(rS_words.iterator, dtype=BFloat16), tTMEM_LOADrS.layout)
        # Copy the BF16 values from S load to P buffer
        for j in range(frg_cnt):
            s_vec = tTMEM_LOADrS_frg[None, j].load()
            rP_bf16_test_frg = cute.logical_divide(rP_bf16_test, cute.make_layout(frg_tile))
            rP_bf16_test_frg[None, j].store(s_vec.to(BFloat16))
        for j0 in range(32):
            for j1 in range(4):
                coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
                m_coord = coord[0]
                k_coord = coord[1]
                # Split k into the three sub-indices used to address sP.
                k0 = k_coord % 16
                k1 = (k_coord // 16) % 4
                k2 = k_coord // 64
                sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16_test[(j0, 0), j1, 0, 0]
        cute.arch.fence_proxy("async.shared", space="cta")
        # Signal MMA warp that sP is ready
        # NOTE(review): the original indentation was lost; this assumes the
        # barrier is created inside the softmax branch. It is also used by
        # the warp-4 branch below, and its num_threads (64) is smaller than
        # the 128 threads arriving here — confirm scoping and thread count.
        sync_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32)
        sync_bar.arrive()
    # MMA warp: read from sP and write to output
    if warp_idx == 4:
        sync_bar.arrive_and_wait()
        # Read sP using PV MMA's A-operand fragment
        pv_thr = pv_mma.get_slice(0)
        tCrP = pv_mma.make_fragment_A(sP)
        # Read sP values and write to output
        # (tCgOut is not used again below.)
        gOut = cute.local_tile(sP_out, (128, 128), (0, 0))
        tCgOut = pv_thr.partition_C(gOut)
        # Simple: read from sP using the fragment and store to output
        # We need to iterate over the fragment's K dimension
        for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
            for i in cutlass.range(cute.size(tCrP, mode=[0]), vectorize=True):
                pass  # Can't easily extract individual values from SMEM fragment
        # Alternative: read sP directly using the sP tensor (not fragment)
        # This tests if the sP writes are correct
        for m in range(128):
            for k0 in range(16):
                for k1 in range(4):
                    for k2 in range(2):
                        # The value is read but never stored or compared.
                        val = sP_nostage[(m, k0), 0, (k1, k2)]
                        # Expected: m + 128*(k0 + 16*k1 + 64*k2) = m + 128*k
                        # But we can't write to global memory from here easily
    # NOTE(review): assumes both calls belong to this guard (original
    # indentation was lost) — confirm which warps must release TMEM.
    if warp_idx < 4:
        tmem.relinquish_alloc_permit()
        tmem.free(tmem_ptr)
def main():
    """Placeholder entry point: the host-side driver is not implemented.

    The kernel in this file is not launched from here. The intended next
    step, per the original notes, is a simpler test that writes known values
    to sP using the coordinate approach, reads them back, and checks
    correctness on the host.

    Prints a notice instead of returning silently, so a successful exit is
    not mistaken for a passing diagnostic.
    """
    # TODO: build the layouts and launch the kernel once it can export sP.
    print("smem_p_write_test: host-side driver not implemented; nothing was verified.")


if __name__ == '__main__':
    main()