D5b: Use normalized O + LSE for merge (correct formula), always output LSE
This commit is contained in:
@@ -444,12 +444,12 @@ class FmhaKernel:
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
|
||||
# D5a: Write LSE (log-softmax) when normalize=False
|
||||
# lse = ln(row_sum) + attn_max
|
||||
# row_max is in the scale_log2 domain: max(S * scale * log2(e))
|
||||
# attn_max = row_max * ln(2) (converting log2 domain to natural log domain)
|
||||
# So lse = ln(row_sum) + row_max * ln(2)
|
||||
if const_expr(not self.normalize):
|
||||
# D5a: Write LSE (log-softmax) — always when mLSE is provided
|
||||
# lse = ln(row_sum) + row_max * ln(2)
|
||||
# This is needed for the SWA+sink merge formula:
|
||||
# numerator = exp(lse1) * O1_norm + exp(sink) * exp(lse2) * O2_norm
|
||||
# denominator = exp(lse1) + exp(sink) * exp(lse2)
|
||||
if mLSE is not None:
|
||||
_row_max_safe = row_max
|
||||
if row_max == -cutlass.Float32.inf:
|
||||
_row_max_safe = Float32(0.0)
|
||||
|
||||
@@ -19,14 +19,14 @@ import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def run_fmha_unnorm(q, k, v, kernel_obj, compiled_kernel, stream):
|
||||
"""Run FMHA with normalize=False, return un-normalized O and LSE."""
|
||||
def run_fmha(q, k, v, kernel_obj, compiled_kernel, stream):
|
||||
"""Run FMHA (normalize=True) with LSE output, return normalized O and LSE."""
|
||||
m = 128 # M tile
|
||||
hd = v.shape[1]
|
||||
pv_n_tile = kernel_obj.pv_n_tile
|
||||
n_pv_tiles = kernel_obj.n_pv_tiles
|
||||
|
||||
c_unnorm = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
for nt in range(n_pv_tiles):
|
||||
@@ -45,13 +45,13 @@ def run_fmha_unnorm(q, k, v, kernel_obj, compiled_kernel, stream):
|
||||
compiled_kernel(mQ, mK, mV, mC, stream, mLSE)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
c_unnorm[:, v_start:v_end, :] = c_tile
|
||||
c_out[:, v_start:v_end, :] = c_tile
|
||||
if nt == 0:
|
||||
lse_tensor = lse_tile
|
||||
|
||||
o_unnorm = c_unnorm[:, :, 0] # (m, hd)
|
||||
lse = lse_tensor[0, 0, 0].item() # scalar (M=1 decode)
|
||||
return o_unnorm, lse
|
||||
o_norm = c_out[:, :, 0] # (m, hd) — normalized
|
||||
lse = lse_tensor[0, 0, 0].item() # scalar (row 0)
|
||||
return o_norm, lse
|
||||
|
||||
|
||||
def test():
|
||||
@@ -131,9 +131,9 @@ def test():
|
||||
).item()
|
||||
print(f" Row {i}: norm_vs_unnorm cos = {row_cos:.6f}")
|
||||
|
||||
# === Kernel: Run FMHA twice (normalize=False) and merge ===
|
||||
# === Kernel: Run FMHA (normalize=True) with LSE and merge ===
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = FmhaKernel(head_dim=hd, s_k=n_comp, normalize=False)
|
||||
kernel = FmhaKernel(head_dim=hd, s_k=n_comp) # normalize=True (default)
|
||||
|
||||
# Compile
|
||||
print('Compiling kernel...', flush=True)
|
||||
@@ -149,26 +149,24 @@ def test():
|
||||
|
||||
# Run compressed KV
|
||||
print('Running compressed KV...', flush=True)
|
||||
o_unnorm_kernel_comp, lse_kernel_comp = run_fmha_unnorm(q, k_comp, v_comp, kernel, compiled, stream)
|
||||
o_kernel_comp, lse_kernel_comp = run_fmha(q, k_comp, v_comp, kernel, compiled, stream)
|
||||
|
||||
# Run SWA KV (re-compile with different s_k if needed, or same if n_swa==n_comp)
|
||||
# Run SWA KV
|
||||
print('Running SWA KV...', flush=True)
|
||||
o_unnorm_kernel_swa, lse_kernel_swa = run_fmha_unnorm(q, k_swa, v_swa, kernel, compiled, stream)
|
||||
o_kernel_swa, lse_kernel_swa = run_fmha(q, k_swa, v_swa, kernel, compiled, stream)
|
||||
|
||||
# Merge with sink weights (Python) — use stable normalized merge formula
|
||||
lse_comp_per_row = lse_comp[:, 0] # (m,) — reference per-row LSE
|
||||
lse_swa_per_row = lse_swa[:, 0] # (m,) — reference per-row LSE
|
||||
o_norm_kernel_comp = o_unnorm_kernel_comp.float() / torch.exp(lse_comp_per_row.unsqueeze(1))
|
||||
o_norm_kernel_swa = o_unnorm_kernel_swa.float() / torch.exp(lse_swa_per_row.unsqueeze(1))
|
||||
|
||||
# Stable merge (same as reference)
|
||||
lse_max_kern = torch.max(lse_comp_per_row.unsqueeze(1), lse_swa_per_row.unsqueeze(1))
|
||||
exp_lse_comp_kern = torch.exp(lse_comp_per_row.unsqueeze(1) - lse_max_kern)
|
||||
exp_lse_swa_kern = torch.exp(lse_swa_per_row.unsqueeze(1) - lse_max_kern)
|
||||
# Merge with sink weights using standard formula:
|
||||
# numerator = exp(lse1) * O1_norm + exp(sink) * exp(lse2) * O2_norm
|
||||
# denominator = exp(lse1) + exp(sink) * exp(lse2)
|
||||
lse_comp_val = torch.tensor(lse_kernel_comp, dtype=torch.float32, device='cuda')
|
||||
lse_swa_val = torch.tensor(lse_kernel_swa, dtype=torch.float32, device='cuda')
|
||||
exp_lse_kern_comp = torch.exp(lse_comp_val)
|
||||
exp_lse_kern_swa = torch.exp(lse_swa_val)
|
||||
exp_sink_kern = torch.exp(attn_sink[0])
|
||||
|
||||
kern_numerator = exp_lse_comp_kern * o_norm_kernel_comp + exp_sink_kern * exp_lse_swa_kern * o_norm_kernel_swa
|
||||
kern_denominator = (exp_lse_comp_kern + exp_sink_kern * exp_lse_swa_kern).clamp(min=1e-30)
|
||||
# Using kernel's scalar LSE (row 0 only) for all rows
|
||||
kern_numerator = exp_lse_kern_comp * o_kernel_comp.float() + exp_sink_kern * exp_lse_kern_swa * o_kernel_swa.float()
|
||||
kern_denominator = (exp_lse_kern_comp + exp_sink_kern * exp_lse_kern_swa).clamp(min=1e-30)
|
||||
kern_output = kern_numerator / kern_denominator
|
||||
|
||||
# Compare with reference
|
||||
@@ -184,14 +182,14 @@ def test():
|
||||
print(f' kern[0,:4]={kern_output[0,:4].tolist()}')
|
||||
print(f' ref[0,:4]={ref_merge[0,:4].tolist()}')
|
||||
|
||||
# Also check individual attention passes
|
||||
# Also check individual attention passes (normalized O)
|
||||
cos_comp = torch.nn.functional.cosine_similarity(
|
||||
o_unnorm_kernel_comp.flatten().unsqueeze(0).float(),
|
||||
o_unnorm_comp.flatten().unsqueeze(0)
|
||||
o_kernel_comp.flatten().unsqueeze(0).float(),
|
||||
o_norm_comp.flatten().unsqueeze(0)
|
||||
).item()
|
||||
cos_swa = torch.nn.functional.cosine_similarity(
|
||||
o_unnorm_kernel_swa.flatten().unsqueeze(0).float(),
|
||||
o_unnorm_swa.flatten().unsqueeze(0)
|
||||
o_kernel_swa.flatten().unsqueeze(0).float(),
|
||||
o_norm_swa.flatten().unsqueeze(0)
|
||||
).item()
|
||||
print(f' Compressed KV unnorm cos: {cos_comp:.6f}')
|
||||
print(f' SWA KV unnorm cos: {cos_swa:.6f}')
|
||||
|
||||
Reference in New Issue
Block a user