From b77ad244a26e27f0e415d875d891d540fe9e77a3 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 21:35:40 +0000 Subject: [PATCH] D5b: Use normalized O + LSE for merge (correct formula), always output LSE --- dsv4/kernels/attention/fmha.py | 12 +++--- tests/unit/test_fmha_v3_stage_d5b.py | 56 ++++++++++++++-------------- 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 5a6e816d..7f9d8f01 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -444,12 +444,12 @@ class FmhaKernel: ) c_pipe.producer_tail() - # D5a: Write LSE (log-softmax) when normalize=False - # lse = ln(row_sum) + attn_max - # row_max is in the scale_log2 domain: max(S * scale * log2(e)) - # attn_max = row_max * ln(2) (converting log2 domain to natural log domain) - # So lse = ln(row_sum) + row_max * ln(2) - if const_expr(not self.normalize): + # D5a: Write LSE (log-sum-exp) — always when mLSE is provided + # lse = ln(row_sum) + row_max * ln(2) + # This is needed for the SWA+sink merge formula: + # numerator = exp(lse1) * O1_norm + exp(sink) * exp(lse2) * O2_norm + # denominator = exp(lse1) + exp(sink) * exp(lse2) + if mLSE is not None: _row_max_safe = row_max if row_max == -cutlass.Float32.inf: _row_max_safe = Float32(0.0) diff --git a/tests/unit/test_fmha_v3_stage_d5b.py b/tests/unit/test_fmha_v3_stage_d5b.py index 81dbbb2f..e84fc7c9 100644 --- a/tests/unit/test_fmha_v3_stage_d5b.py +++ b/tests/unit/test_fmha_v3_stage_d5b.py @@ -19,14 +19,14 @@ import cuda.bindings.driver as cuda from dsv4.kernels.attention.fmha import FmhaKernel -def run_fmha_unnorm(q, k, v, kernel_obj, compiled_kernel, stream): - """Run FMHA with normalize=False, return un-normalized O and LSE.""" +def run_fmha(q, k, v, kernel_obj, compiled_kernel, stream): + """Run FMHA (normalize=True) with LSE output, return normalized O and LSE.""" m = 128 # M tile hd = v.shape[1] pv_n_tile = kernel_obj.pv_n_tile 
n_pv_tiles = kernel_obj.n_pv_tiles - c_unnorm = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') + c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') for nt in range(n_pv_tiles): @@ -45,13 +45,13 @@ def run_fmha_unnorm(q, k, v, kernel_obj, compiled_kernel, stream): compiled_kernel(mQ, mK, mV, mC, stream, mLSE) torch.cuda.synchronize() - c_unnorm[:, v_start:v_end, :] = c_tile + c_out[:, v_start:v_end, :] = c_tile if nt == 0: lse_tensor = lse_tile - o_unnorm = c_unnorm[:, :, 0] # (m, hd) - lse = lse_tensor[0, 0, 0].item() # scalar (M=1 decode) - return o_unnorm, lse + o_norm = c_out[:, :, 0] # (m, hd) — normalized + lse = lse_tensor[0, 0, 0].item() # scalar (row 0) + return o_norm, lse def test(): @@ -131,9 +131,9 @@ def test(): ).item() print(f" Row {i}: norm_vs_unnorm cos = {row_cos:.6f}") - # === Kernel: Run FMHA twice (normalize=False) and merge === + # === Kernel: Run FMHA (normalize=True) with LSE and merge === stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - kernel = FmhaKernel(head_dim=hd, s_k=n_comp, normalize=False) + kernel = FmhaKernel(head_dim=hd, s_k=n_comp) # normalize=True (default) # Compile print('Compiling kernel...', flush=True) @@ -149,26 +149,24 @@ def test(): # Run compressed KV print('Running compressed KV...', flush=True) - o_unnorm_kernel_comp, lse_kernel_comp = run_fmha_unnorm(q, k_comp, v_comp, kernel, compiled, stream) + o_kernel_comp, lse_kernel_comp = run_fmha(q, k_comp, v_comp, kernel, compiled, stream) - # Run SWA KV (re-compile with different s_k if needed, or same if n_swa==n_comp) + # Run SWA KV print('Running SWA KV...', flush=True) - o_unnorm_kernel_swa, lse_kernel_swa = run_fmha_unnorm(q, k_swa, v_swa, kernel, compiled, stream) + o_kernel_swa, lse_kernel_swa = run_fmha(q, k_swa, v_swa, kernel, compiled, stream) - # Merge with sink weights (Python) — use stable normalized merge formula - lse_comp_per_row = lse_comp[:, 
0] # (m,) — reference per-row LSE - lse_swa_per_row = lse_swa[:, 0] # (m,) — reference per-row LSE - o_norm_kernel_comp = o_unnorm_kernel_comp.float() / torch.exp(lse_comp_per_row.unsqueeze(1)) - o_norm_kernel_swa = o_unnorm_kernel_swa.float() / torch.exp(lse_swa_per_row.unsqueeze(1)) - - # Stable merge (same as reference) - lse_max_kern = torch.max(lse_comp_per_row.unsqueeze(1), lse_swa_per_row.unsqueeze(1)) - exp_lse_comp_kern = torch.exp(lse_comp_per_row.unsqueeze(1) - lse_max_kern) - exp_lse_swa_kern = torch.exp(lse_swa_per_row.unsqueeze(1) - lse_max_kern) + # Merge with sink weights using standard formula: + # numerator = exp(lse1) * O1_norm + exp(sink) * exp(lse2) * O2_norm + # denominator = exp(lse1) + exp(sink) * exp(lse2) + lse_comp_val = torch.tensor(lse_kernel_comp, dtype=torch.float32, device='cuda') + lse_swa_val = torch.tensor(lse_kernel_swa, dtype=torch.float32, device='cuda') + exp_lse_kern_comp = torch.exp(lse_comp_val) + exp_lse_kern_swa = torch.exp(lse_swa_val) exp_sink_kern = torch.exp(attn_sink[0]) - kern_numerator = exp_lse_comp_kern * o_norm_kernel_comp + exp_sink_kern * exp_lse_swa_kern * o_norm_kernel_swa - kern_denominator = (exp_lse_comp_kern + exp_sink_kern * exp_lse_swa_kern).clamp(min=1e-30) + # Using kernel's scalar LSE (row 0 only) for all rows + kern_numerator = exp_lse_kern_comp * o_kernel_comp.float() + exp_sink_kern * exp_lse_kern_swa * o_kernel_swa.float() + kern_denominator = (exp_lse_kern_comp + exp_sink_kern * exp_lse_kern_swa).clamp(min=1e-30) kern_output = kern_numerator / kern_denominator # Compare with reference @@ -184,14 +182,14 @@ def test(): print(f' kern[0,:4]={kern_output[0,:4].tolist()}') print(f' ref[0,:4]={ref_merge[0,:4].tolist()}') - # Also check individual attention passes + # Also check individual attention passes (normalized O) cos_comp = torch.nn.functional.cosine_similarity( - o_unnorm_kernel_comp.flatten().unsqueeze(0).float(), - o_unnorm_comp.flatten().unsqueeze(0) + 
o_kernel_comp.flatten().unsqueeze(0).float(), + o_norm_comp.flatten().unsqueeze(0) ).item() cos_swa = torch.nn.functional.cosine_similarity( - o_unnorm_kernel_swa.flatten().unsqueeze(0).float(), - o_unnorm_swa.flatten().unsqueeze(0) + o_kernel_swa.flatten().unsqueeze(0).float(), + o_norm_swa.flatten().unsqueeze(0) ).item() print(f' Compressed KV unnorm cos: {cos_comp:.6f}') print(f' SWA KV unnorm cos: {cos_swa:.6f}')