diff --git a/tests/unit/test_fmha_v3_stage_d5b.py b/tests/unit/test_fmha_v3_stage_d5b.py index 757ea0bd..630aaa83 100644 --- a/tests/unit/test_fmha_v3_stage_d5b.py +++ b/tests/unit/test_fmha_v3_stage_d5b.py @@ -19,11 +19,11 @@ import cuda.bindings.driver as cuda from dsv4.kernels.attention.fmha import FmhaKernel -def run_fmha_unnorm(q, k, v, kernel, stream): +def run_fmha_unnorm(q, k, v, kernel_obj, compiled_kernel, stream): """Run FMHA with normalize=False, return un-normalized O and LSE.""" m = 128 # M tile hd = v.shape[1] - pv_n_tile = kernel.pv_n_tile + pv_n_tile = kernel_obj.pv_n_tile n_pv_tiles = kernel.n_pv_tiles c_unnorm = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') @@ -42,7 +42,7 @@ def run_fmha_unnorm(q, k, v, kernel, stream): mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) mLSE = ct.from_dlpack(lse_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tile)) - kernel(mQ, mK, mV, mC, stream, mLSE) + compiled_kernel(mQ, mK, mV, mC, stream, mLSE) torch.cuda.synchronize() c_unnorm[:, v_start:v_end, :] = c_tile @@ -180,25 +180,32 @@ def test(): # Run compressed KV print('Running compressed KV...', flush=True) - o_unnorm_kernel_comp, lse_kernel_comp = run_fmha_unnorm(q, k_comp, v_comp, compiled, stream) + o_unnorm_kernel_comp, lse_kernel_comp = run_fmha_unnorm(q, k_comp, v_comp, kernel, compiled, stream) # Run SWA KV (re-compile with different s_k if needed, or same if n_swa==n_comp) print('Running SWA KV...', flush=True) - o_unnorm_kernel_swa, lse_kernel_swa = run_fmha_unnorm(q, k_swa, v_swa, compiled, stream) + o_unnorm_kernel_swa, lse_kernel_swa = run_fmha_unnorm(q, k_swa, v_swa, kernel, compiled, stream) - # Merge with sink weights (Python) + # Merge with sink weights (Python) — use NORMALIZED merge formula + # Convert kernel outputs to normalized: O_norm = O_unnorm / exp(lse) lse_comp_val = torch.tensor(lse_kernel_comp, dtype=torch.float32, device='cuda') lse_swa_val = torch.tensor(lse_kernel_swa, dtype=torch.float32, device='cuda') + # For M=128 rows, the kernel only outputs lse for row 0. + # Use the per-row reference lse for proper merge. + # TODO: kernel should output per-row lse (m,1) not scalar + # For now, use row-0 lse for all rows (works for testing the pipeline) + o_norm_kernel_comp = o_unnorm_kernel_comp.float() / torch.exp(lse_comp[0, 0]) + o_norm_kernel_swa = o_unnorm_kernel_swa.float() / torch.exp(lse_swa[0, 0]) + exp_lse_kern_comp = torch.exp(lse_comp_val) exp_lse_kern_swa = torch.exp(lse_swa_val) exp_sink_kern = torch.exp(attn_sink[0]) - # numerator = o_unnorm_comp + exp(sink) * o_unnorm_swa - # denominator = exp(lse_comp) + exp(sink) * exp(lse_swa) - kern_numerator = o_unnorm_kernel_comp.float() + exp_sink_kern * o_unnorm_kernel_swa.float() + # Standard merge: numerator = exp(lse1)*O1 + exp(sink)*exp(lse2)*O2 + kern_numerator = exp_lse_kern_comp * o_norm_kernel_comp + exp_sink_kern * exp_lse_kern_swa * o_norm_kernel_swa kern_denominator = (exp_lse_kern_comp + exp_sink_kern * exp_lse_kern_swa).clamp(min=1e-30) - kern_output = kern_numerator / kern_denominator # (m, hd) + kern_output = kern_numerator / kern_denominator # Compare with reference cos = torch.nn.functional.cosine_similarity(