D5b: Use normalized O + LSE for merge (correct formula), always output LSE

This commit is contained in:
2026-05-23 21:35:40 +00:00
parent 84200ca557
commit b77ad244a2
2 changed files with 33 additions and 35 deletions

View File

@@ -444,12 +444,12 @@ class FmhaKernel:
)
c_pipe.producer_tail()
# D5a: Write LSE (log-softmax) when normalize=False
# lse = ln(row_sum) + attn_max
# row_max is in the scale_log2 domain: max(S * scale * log2(e))
# attn_max = row_max * ln(2) (converting log2 domain to natural log domain)
# So lse = ln(row_sum) + row_max * ln(2)
if const_expr(not self.normalize):
# D5a: Write LSE (log-softmax) — always when mLSE is provided
# lse = ln(row_sum) + row_max * ln(2)
# This is needed for the SWA+sink merge formula:
# numerator = exp(lse1) * O1_norm + exp(sink) * exp(lse2) * O2_norm
# denominator = exp(lse1) + exp(sink) * exp(lse2)
if mLSE is not None:
_row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
_row_max_safe = Float32(0.0)

View File

@@ -19,14 +19,14 @@ import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def run_fmha_unnorm(q, k, v, kernel_obj, compiled_kernel, stream):
"""Run FMHA with normalize=False, return un-normalized O and LSE."""
def run_fmha(q, k, v, kernel_obj, compiled_kernel, stream):
"""Run FMHA (normalize=True) with LSE output, return normalized O and LSE."""
m = 128 # M tile
hd = v.shape[1]
pv_n_tile = kernel_obj.pv_n_tile
n_pv_tiles = kernel_obj.n_pv_tiles
c_unnorm = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
for nt in range(n_pv_tiles):
@@ -45,13 +45,13 @@ def run_fmha_unnorm(q, k, v, kernel_obj, compiled_kernel, stream):
compiled_kernel(mQ, mK, mV, mC, stream, mLSE)
torch.cuda.synchronize()
c_unnorm[:, v_start:v_end, :] = c_tile
c_out[:, v_start:v_end, :] = c_tile
if nt == 0:
lse_tensor = lse_tile
o_unnorm = c_unnorm[:, :, 0] # (m, hd)
lse = lse_tensor[0, 0, 0].item() # scalar (M=1 decode)
return o_unnorm, lse
o_norm = c_out[:, :, 0] # (m, hd) — normalized
lse = lse_tensor[0, 0, 0].item() # scalar (row 0)
return o_norm, lse
def test():
@@ -131,9 +131,9 @@ def test():
).item()
print(f" Row {i}: norm_vs_unnorm cos = {row_cos:.6f}")
# === Kernel: Run FMHA twice (normalize=False) and merge ===
# === Kernel: Run FMHA (normalize=True) with LSE and merge ===
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaKernel(head_dim=hd, s_k=n_comp, normalize=False)
kernel = FmhaKernel(head_dim=hd, s_k=n_comp) # normalize=True (default)
# Compile
print('Compiling kernel...', flush=True)
@@ -149,26 +149,24 @@ def test():
# Run compressed KV
print('Running compressed KV...', flush=True)
o_unnorm_kernel_comp, lse_kernel_comp = run_fmha_unnorm(q, k_comp, v_comp, kernel, compiled, stream)
o_kernel_comp, lse_kernel_comp = run_fmha(q, k_comp, v_comp, kernel, compiled, stream)
# Run SWA KV (re-compile with different s_k if needed, or same if n_swa==n_comp)
# Run SWA KV
print('Running SWA KV...', flush=True)
o_unnorm_kernel_swa, lse_kernel_swa = run_fmha_unnorm(q, k_swa, v_swa, kernel, compiled, stream)
o_kernel_swa, lse_kernel_swa = run_fmha(q, k_swa, v_swa, kernel, compiled, stream)
# Merge with sink weights (Python) — use stable normalized merge formula
lse_comp_per_row = lse_comp[:, 0] # (m,) — reference per-row LSE
lse_swa_per_row = lse_swa[:, 0] # (m,) — reference per-row LSE
o_norm_kernel_comp = o_unnorm_kernel_comp.float() / torch.exp(lse_comp_per_row.unsqueeze(1))
o_norm_kernel_swa = o_unnorm_kernel_swa.float() / torch.exp(lse_swa_per_row.unsqueeze(1))
# Stable merge (same as reference)
lse_max_kern = torch.max(lse_comp_per_row.unsqueeze(1), lse_swa_per_row.unsqueeze(1))
exp_lse_comp_kern = torch.exp(lse_comp_per_row.unsqueeze(1) - lse_max_kern)
exp_lse_swa_kern = torch.exp(lse_swa_per_row.unsqueeze(1) - lse_max_kern)
# Merge with sink weights using standard formula:
# numerator = exp(lse1) * O1_norm + exp(sink) * exp(lse2) * O2_norm
# denominator = exp(lse1) + exp(sink) * exp(lse2)
lse_comp_val = torch.tensor(lse_kernel_comp, dtype=torch.float32, device='cuda')
lse_swa_val = torch.tensor(lse_kernel_swa, dtype=torch.float32, device='cuda')
exp_lse_kern_comp = torch.exp(lse_comp_val)
exp_lse_kern_swa = torch.exp(lse_swa_val)
exp_sink_kern = torch.exp(attn_sink[0])
kern_numerator = exp_lse_comp_kern * o_norm_kernel_comp + exp_sink_kern * exp_lse_swa_kern * o_norm_kernel_swa
kern_denominator = (exp_lse_comp_kern + exp_sink_kern * exp_lse_swa_kern).clamp(min=1e-30)
# Using kernel's scalar LSE (row 0 only) for all rows
kern_numerator = exp_lse_kern_comp * o_kernel_comp.float() + exp_sink_kern * exp_lse_kern_swa * o_kernel_swa.float()
kern_denominator = (exp_lse_kern_comp + exp_sink_kern * exp_lse_kern_swa).clamp(min=1e-30)
kern_output = kern_numerator / kern_denominator
# Compare with reference
@@ -184,14 +182,14 @@ def test():
print(f' kern[0,:4]={kern_output[0,:4].tolist()}')
print(f' ref[0,:4]={ref_merge[0,:4].tolist()}')
# Also check individual attention passes
# Also check individual attention passes (normalized O)
cos_comp = torch.nn.functional.cosine_similarity(
o_unnorm_kernel_comp.flatten().unsqueeze(0).float(),
o_unnorm_comp.flatten().unsqueeze(0)
o_kernel_comp.flatten().unsqueeze(0).float(),
o_norm_comp.flatten().unsqueeze(0)
).item()
cos_swa = torch.nn.functional.cosine_similarity(
o_unnorm_kernel_swa.flatten().unsqueeze(0).float(),
o_unnorm_swa.flatten().unsqueeze(0)
o_kernel_swa.flatten().unsqueeze(0).float(),
o_norm_swa.flatten().unsqueeze(0)
).item()
print(f' Compressed KV unnorm cos: {cos_comp:.6f}')
print(f' SWA KV unnorm cos: {cos_swa:.6f}')