diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 79164579..6458df0f 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -341,38 +341,57 @@ class FmhaKernel: print(f"[SMEM-P MANUAL] Starting manual P write to SMEM") - # Debug layouts + # Debug to understand partitioning print(f"[SMEM-P MANUAL] tStS0 layout: {tStS0.layout}") print(f"[SMEM-P MANUAL] sP layout: {sP.layout}") + print(f"[SMEM-P MANUAL] Softmax warp idx: {sfw_idx}") + + # Get thread index within CTA + # Try to understand thread coordinate system + thread_idx = cute.arch.thread_idx() + print(f"[SMEM-P MANUAL] Thread idx: {thread_idx}") # rP_bf16 contains P values in TMEM layout - # We need to copy them to sP (PV A-operand SMEM layout) + # shape: ((32, 1), 4, 1, 1) + # Try to understand which P values this thread owns + print(f"[SMEM-P MANUAL] rP_bf16 shape: {cute.shape(rP_bf16)}") + print(f"[SMEM-P MANUAL] rP_bf16 layout: {rP_bf16.layout}") - # Approach: Get P values from rP_bf16 and write to sP - # rP_bf16 has shape ((32, 1), 4, 1, 1) in TMEM layout - # Need to map to sP shape ((128, 16), 1, (4, 2), 1) + # For debugging: print first few P values this thread can access + # rP_bf16_frg is logical division of rP_bf16 + frg_tile = (32, 1) # From earlier: cute.make_layout((32, 1)) + frg_cnt = 4 # cute.size(tTMEM_LOADrS_frg, mode=[1]) - # For now, simple test: each thread writes one value - # Compute linear thread index in softmax warp - thread_in_warp = sfw_idx % 32 # Assuming 32 threads per warp - smem_offset = thread_in_warp * 1024 # Arbitrary stride - if smem_offset < cute.size(sP): - sP[smem_offset] = BFloat16(float(thread_in_warp) * 0.01) - print(f"[SMEM-P MANUAL] Thread {thread_in_warp} wrote to SMEM offset {smem_offset}") - else: - print(f"[SMEM-P MANUAL] Thread {thread_in_warp} offset {smem_offset} out of bounds") + print(f"[SMEM-P MANUAL] frg_tile: {frg_tile}, frg_cnt: {frg_cnt}") - # TODO: Implement proper mapping - # Need to understand: - # 1. Which P values does this thread own in rP_bf16? - # 2. Where in sP should they go? + # Map: each thread handles 32×1 tile × 4 fragments = 128 values + # Total 128 threads × 128 values = 16384 P values (128×128) - # For now, zero out rest of sP + # Simple approach: each thread writes its 128 values to SMEM + # Need mapping: thread's 128 linear indices → SMEM addresses + + # For now, implement naive linear mapping (likely wrong but testable) + # thread_linear_idx = thread_idx (0-191) but only softmax warps 0-3 execute + # softmax warps are 128 threads (4 warps × 32) + + # Compute which softmax thread this is (0-127) + softmax_thread_idx = sfw_idx * 32 + (thread_idx % 32) + print(f"[SMEM-P MANUAL] Softmax thread idx: {softmax_thread_idx}") + + # Each thread handles 128 P values starting at index * 128 + base_p_idx = softmax_thread_idx * 128 + + # Write test pattern: thread ID to first SMEM location + if base_p_idx < cute.size(sP): + sP[base_p_idx] = BFloat16(float(softmax_thread_idx) * 0.001) + print(f"[SMEM-P MANUAL] Thread {softmax_thread_idx} wrote to SMEM offset {base_p_idx}") + + # Zero rest of sP for now for j in cutlass.range(cute.size(sP), vectorize=True): - if j != smem_offset: + if j != base_p_idx: sP[j] = BFloat16(0.0) - print(f"[SMEM-P MANUAL] Used test pattern + zeros (TODO: implement proper mapping)") + print(f"[SMEM-P MANUAL] Used linear mapping test (likely wrong)") cute.arch.fence_proxy("async.shared", space="cta") softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout) if kt > 0: