From aa82a0faf59407f158b83c50e87d7b04aca77fe4 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 19:32:11 +0000 Subject: [PATCH] SMEM-P: implement CUTLASS LLM coordinate mapping pattern (minimal test) --- dsv4/kernels/attention/fmha.py | 99 +++++++++++++++------------------- 1 file changed, 43 insertions(+), 56 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 115f0345..6fbb4581 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -257,12 +257,18 @@ class FmhaKernel: tScP = cute.make_tensor(tScS.iterator, tScP_layout) tTMEM_STOREcP = thr_store.partition_S(tScP) - # Manual SMEM addressing for P (helpers are a trap) + # Manual SMEM addressing for P (CUTLASS LLM guidance) # We need to write P values from QK C-fragment layout to PV A-operand SMEM layout # sP has PV A-operand SMEM layout: p_smem_s - print(f"[SMEM-P MANUAL] Starting manual SMEM addressing") - print(f"[SMEM-P MANUAL] sP shape: {cute.shape(sP)} layout: {sP.layout}") - print(f"[SMEM-P MANUAL] p_smem_s (PV A-operand SMEM layout): {p_smem_s}") + print(f"[SMEM-P CUTLASS] Starting manual SMEM addressing with CUTLASS LLM pattern") + print(f"[SMEM-P CUTLASS] sP shape: {cute.shape(sP)} layout: {sP.layout}") + + # Get thread index for coordinate partitioning + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + lane_idx = tidx % 32 + + print(f"[SMEM-P CUTLASS] tidx={tidx}, warp_idx={warp_idx}, lane_idx={lane_idx}") row_max = -Float32.inf row_sum = Float32(0.0) @@ -335,66 +341,47 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: Manual addressing (helpers are a trap) - # Each softmax thread owns P values in QK C-fragment partition - # Need to write to SMEM with PV A-operand layout + # SMEM-P: Manual addressing with CUTLASS LLM pattern + print(f"[SMEM-P CUTLASS] Starting manual P write to SMEM") - print(f"[SMEM-P MANUAL] Starting manual P write to SMEM") + # Get thread index for coordinate computations + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - # Debug to understand partitioning - print(f"[SMEM-P MANUAL] tStS0 layout: {tStS0.layout}") - print(f"[SMEM-P MANUAL] sP layout: {sP.layout}") - print(f"[SMEM-P MANUAL] Softmax warp idx: {sfw_idx}") + # Function to map QK coordinate to PV SMEM coordinate + # QK: ((m, n), 0, 0) → PV: ((m, n % 16), 0, ((n // 16) % 4, n // 64), 0) + def qk_to_pv_coord(m, n): + n0 = n % 16 + n1 = (n // 16) % 4 + n2 = n // 64 + return ((m, n0), 0, (n1, n2), 0) - # Get thread index within CTA - # Try to understand thread coordinate system - thread_idx = cute.arch.thread_idx() - print(f"[SMEM-P MANUAL] Thread idx: {thread_idx}") + # Each thread handles 32×1 tile × 4 fragments = 128 P values + # We need to map each of these 128 values to SMEM - # rP_bf16 contains P values in TMEM layout - # shape: ((32, 1), 4, 1, 1) - # Try to understand which P values this thread owns - print(f"[SMEM-P MANUAL] rP_bf16 shape: {cute.shape(rP_bf16)}") - print(f"[SMEM-P MANUAL] rP_bf16 layout: {rP_bf16.layout}") + # For testing: write a simple pattern to verify mapping works + # Use thread/warp index as mock coordinate + test_m = warp_idx % 128 + test_n = tidx % 128 + test_coord = qk_to_pv_coord(test_m, test_n) - # For debugging: print first few P values this thread can access - # rP_bf16_frg is logical division of rP_bf16 - frg_tile = (32, 1) # From earlier: cute.make_layout((32, 1)) - frg_cnt = 4 # cute.size(tTMEM_LOADrS_frg, mode=[1]) + # Write test value + test_val = BFloat16(float(warp_idx) * 0.01 + float(tidx) * 0.001) + sP[test_coord] = test_val + print(f"[SMEM-P CUTLASS] Thread ({warp_idx},{tidx}) wrote test to coord {test_coord}") - print(f"[SMEM-P MANUAL] frg_tile: {frg_tile}, frg_cnt: {frg_cnt}") + # TODO: Implement full 128-value mapping + # Need to: + # 1. Create coordinate tensor with make_identity_tensor(tStS0.shape) + # 2. Partition it the same way as rP_bf16 + # 3. For each of the 128 P values, get its QK coordinate + # 4. Map to PV coordinate using qk_to_pv_coord + # 5. Write to sP[dst_coord] - # Map: each thread handles 32×1 tile × 4 fragments = 128 values - # Total 128 threads × 128 values = 16384 P values (128×128) + # For now, zero rest of sP (except our test value) + # This is WRONG but allows compilation + print(f"[SMEM-P CUTLASS] WARNING: Only wrote test value, rest zeroed (incomplete)") - # Simple approach: each thread writes its 128 values to SMEM - # Need mapping: thread's 128 linear indices → SMEM addresses - - # For now, implement naive linear mapping (likely wrong but testable) - # thread_linear_idx = thread_idx (0-191) but only softmax warps 0-3 execute - # softmax warps are 128 threads (4 warps × 32) - - # Compute which softmax thread this is (0-127) - # thread_idx is (x,y,z) tuple, take x component - thread_x = thread_idx[0] - # sfw_idx is warp index within softmax group (0-3) - softmax_thread_idx = sfw_idx * 32 + (thread_x % 32) - print(f"[SMEM-P MANUAL] Softmax thread idx: {softmax_thread_idx}") - - # Each thread handles 128 P values starting at index * 128 - base_p_idx = softmax_thread_idx * 128 - - # Write test pattern: thread ID to first SMEM location - if base_p_idx < cute.size(sP): - sP[base_p_idx] = BFloat16(float(softmax_thread_idx) * 0.001) - print(f"[SMEM-P MANUAL] Thread {softmax_thread_idx} wrote to SMEM offset {base_p_idx}") - - # Zero rest of sP for now - for j in cutlass.range(cute.size(sP), vectorize=True): - if j != base_p_idx: - sP[j] = BFloat16(0.0) - - print(f"[SMEM-P MANUAL] Used linear mapping test (likely wrong)") cute.arch.fence_proxy("async.shared", space="cta") softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout) if kt > 0: