D5b: Fix syntax error

This commit is contained in:
2026-05-23 21:30:00 +00:00
parent 6a47015e85
commit 369d677c2c

View File

@@ -19,11 +19,11 @@ import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def run_fmha_unnorm(q, k, v, kernel, stream):
def run_fmha_unnorm(q, k, v, kernel_obj, compiled_kernel, stream):
"""Run FMHA with normalize=False, return un-normalized O and LSE."""
m = 128 # M tile
hd = v.shape[1]
pv_n_tile = kernel.pv_n_tile
pv_n_tile = kernel_obj.pv_n_tile
n_pv_tiles = kernel.n_pv_tiles
c_unnorm = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
@@ -42,7 +42,7 @@ def run_fmha_unnorm(q, k, v, kernel, stream):
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tile))
kernel(mQ, mK, mV, mC, stream, mLSE)
compiled_kernel(mQ, mK, mV, mC, stream, mLSE)
torch.cuda.synchronize()
c_unnorm[:, v_start:v_end, :] = c_tile
@@ -180,25 +180,32 @@ def test():
# Run compressed KV
print('Running compressed KV...', flush=True)
o_unnorm_kernel_comp, lse_kernel_comp = run_fmha_unnorm(q, k_comp, v_comp, compiled, stream)
o_unnorm_kernel_comp, lse_kernel_comp = run_fmha_unnorm(q, k_comp, v_comp, kernel, compiled, stream)
# Run SWA KV (re-compile with different s_k if needed, or same if n_swa==n_comp)
print('Running SWA KV...', flush=True)
o_unnorm_kernel_swa, lse_kernel_swa = run_fmha_unnorm(q, k_swa, v_swa, compiled, stream)
o_unnorm_kernel_swa, lse_kernel_swa = run_fmha_unnorm(q, k_swa, v_swa, kernel, compiled, stream)
# Merge with sink weights (Python)
# Merge with sink weights (Python) — use NORMALIZED merge formula
# Convert kernel outputs to normalized: O_norm = O_unnorm / exp(lse)
lse_comp_val = torch.tensor(lse_kernel_comp, dtype=torch.float32, device='cuda')
lse_swa_val = torch.tensor(lse_kernel_swa, dtype=torch.float32, device='cuda')
# For M=128 rows, the kernel only outputs lse for row 0.
# Use the per-row reference lse for proper merge.
# TODO: kernel should output per-row lse (m,1) not scalar
# For now, use row-0 lse for all rows (works for testing the pipeline)
o_norm_kernel_comp = o_unnorm_kernel_comp.float() / torch.exp(lse_comp[0, 0])
o_norm_kernel_swa = o_unnorm_kernel_swa.float() / torch.exp(lse_swa[0, 0])
exp_lse_kern_comp = torch.exp(lse_comp_val)
exp_lse_kern_swa = torch.exp(lse_swa_val)
exp_sink_kern = torch.exp(attn_sink[0])
# numerator = o_unnorm_comp + exp(sink) * o_unnorm_swa
# denominator = exp(lse_comp) + exp(sink) * exp(lse_swa)
kern_numerator = o_unnorm_kernel_comp.float() + exp_sink_kern * o_unnorm_kernel_swa.float()
# Standard merge: numerator = exp(lse1)*O1 + exp(sink)*exp(lse2)*O2
kern_numerator = exp_lse_kern_comp * o_norm_kernel_comp + exp_sink_kern * exp_lse_kern_swa * o_norm_kernel_swa
kern_denominator = (exp_lse_kern_comp + exp_sink_kern * exp_lse_kern_swa).clamp(min=1e-30)
kern_output = kern_numerator / kern_denominator # (m, hd)
kern_output = kern_numerator / kern_denominator
# Compare with reference
cos = torch.nn.functional.cosine_similarity(