From 638e8b862bb044a77d031fb14bb8f7c3b5933b17 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 19:59:52 +0000 Subject: [PATCH] SMEM-P: implement CUTLASS LLM fixes - dynamic frg_tile, local coordinate conversion --- dsv4/kernels/attention/fmha.py | 42 ++++++++++++++-------------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 5f4dbc47..23905d22 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -315,11 +315,16 @@ class FmhaKernel: old_row_max = row_max frg_cnt = 4 frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt - tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + # Compute fragment tile size dynamically (must match value division) + frg_tile_size = cute.size(tTMEM_LOADrS) // frg_cnt + frg_layout = cute.make_layout(frg_tile_size) + + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, frg_layout) # Coordinate fragments for SMEM-P mapping (needed unconditionally for scoping) - tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, cute.make_layout(frg_tile)) + tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, frg_layout) if self.use_smem_p: print(f"[SMEM-P CUTLASS] Created tTMEM_LOADcS_frg shape: {cute.shape(tTMEM_LOADcS_frg)}") + print(f"[SMEM-P CUTLASS] frg_tile_size: {frg_tile_size}, frg_layout: {frg_layout}") for j in range(frg_cnt): for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): @@ -354,32 +359,19 @@ class FmhaKernel: n = qk_coord[1] # Map to PV SMEM coordinate - # Try transposed mapping (maybe PV expects P^T?) - m0 = m % 16 - m1 = (m // 16) % 4 - m2 = m // 64 - pv_coord = ((n, m0), 0, (m1, m2), 0) + # Convert to local coordinates (0-127) as sanity check + m_local = m % 128 + n_local = n % 128 - # Original mapping (likely wrong): - # n0 = n % 16 - # n1 = (n // 16) % 4 - # n2 = n // 64 - # pv_coord = ((m, n0), 0, (n1, n2), 0) + # Original mapping formula (should be correct for local coords) + n0 = n_local % 16 + n1 = (n_local // 16) % 4 + n2 = n_local // 64 + pv_coord = ((m_local, n0), 0, (n1, n2), 0) - # DEBUG: Write linear index as value: m*128 + n - # This uniquely identifies each position - linear_idx = m * 128 + n - # Convert to Float32 (values 0-16383) - pattern_val = Float32(linear_idx) - p_val_bf16 = pattern_val.to(self.q_dtype) - # Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype) - - # Try both tensor indexing AND manual offset for debugging + # Write actual P value (not test pattern) + p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype) sP[pv_coord] = p_val_bf16 # Tensor indexing - - # Also compute manual offset to verify - # offset = cute.crd2idx(pv_coord, sP.layout) - # (sP.iterator + offset) = p_val_bf16 row_sum = row_sum + tTMEM_LOADrS_frg[k, j] s_vec = tTMEM_LOADrS_frg[None, j].load()