D5b: Python SWA+sink merge test
- Run FMHA twice (compressed KV + SWA KV, normalize=False) - Merge with sink weights in Python - Verify end-to-end correctness vs FP32 reference
This commit is contained in:
195
tests/unit/test_fmha_v3_stage_d5b.py
Normal file
195
tests/unit/test_fmha_v3_stage_d5b.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
FMHA v3 Stage D5b: SWA + Sink Merge (Python-level).
|
||||
|
||||
Tests the full DSV4 attention pipeline:
|
||||
1. Run FMHA with compressed KV (normalize=False) → o_unnorm_sparse, lse_sparse
|
||||
2. Run FMHA with SWA KV (normalize=False) → o_unnorm_swa, lse_swa
|
||||
3. Merge with sink weights in Python:
|
||||
numerator = o_unnorm_sparse + exp(attn_sink) * o_unnorm_swa
|
||||
denominator = exp(lse_sparse) + exp(attn_sink) * exp(lse_swa)
|
||||
output = numerator / denominator
|
||||
|
||||
This is the D5b milestone: end-to-end correctness with SWA + sink merge.
|
||||
Uses hd=64 TMEM-P path (SMEM-P not needed for this test).
|
||||
"""
|
||||
import torch, math
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def run_fmha_unnorm(q, k, v, kernel, stream):
|
||||
"""Run FMHA with normalize=False, return un-normalized O and LSE."""
|
||||
m = 128 # M tile
|
||||
hd = v.shape[1]
|
||||
pv_n_tile = kernel.pv_n_tile
|
||||
n_pv_tiles = kernel.n_pv_tiles
|
||||
|
||||
c_unnorm = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
for nt in range(n_pv_tiles):
|
||||
v_start = nt * pv_n_tile
|
||||
v_end = v_start + pv_n_tile
|
||||
v_tile = v[:, v_start:v_end].contiguous().unsqueeze(-1)
|
||||
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse_tile = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
|
||||
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
|
||||
mLSE = ct.from_dlpack(lse_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tile))
|
||||
|
||||
kernel(mQ, mK, mV, mC, stream, mLSE)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
c_unnorm[:, v_start:v_end, :] = c_tile
|
||||
if nt == 0:
|
||||
lse_tensor = lse_tile
|
||||
|
||||
o_unnorm = c_unnorm[:, :, 0] # (m, hd)
|
||||
lse = lse_tensor[0, 0, 0].item() # scalar (M=1 decode)
|
||||
return o_unnorm, lse
|
||||
|
||||
|
||||
def test():
|
||||
print("=== Stage D5b: SWA + Sink Merge (Python) ===\n")
|
||||
|
||||
hd = 64
|
||||
m = 128
|
||||
n_comp = 128 # compressed KV length
|
||||
n_swa = 128 # SWA KV length
|
||||
torch.manual_seed(42)
|
||||
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
|
||||
k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
# Per-head sink weight (learnable parameter)
|
||||
attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda') # (1,) for 1 head
|
||||
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
|
||||
# === FP32 Reference: Full attention with sink merge ===
|
||||
qf = q[:, :, 0].float() # (m, hd)
|
||||
kf_comp = k_comp[:, :, 0].float()
|
||||
vf_comp = v_comp.float()
|
||||
kf_swa = k_swa[:, :, 0].float()
|
||||
vf_swa = v_swa.float()
|
||||
|
||||
# Compressed KV attention
|
||||
attn_comp = qf @ kf_comp.T * scale # (m, n_comp)
|
||||
attn_comp_max = attn_comp.max(dim=-1, keepdim=True)[0]
|
||||
attn_comp_exp = torch.exp(attn_comp - attn_comp_max)
|
||||
attn_comp_sum = attn_comp_exp.sum(dim=-1, keepdim=True)
|
||||
lse_comp = torch.log(attn_comp_sum) + attn_comp_max # (m, 1)
|
||||
o_unnorm_comp = attn_comp_exp @ vf_comp # (m, hd) un-normalized
|
||||
o_norm_comp = o_unnorm_comp / attn_comp_sum # normalized
|
||||
|
||||
# SWA KV attention
|
||||
attn_swa = qf @ kf_swa.T * scale
|
||||
attn_swa_max = attn_swa.max(dim=-1, keepdim=True)[0]
|
||||
attn_swa_exp = torch.exp(attn_swa - attn_swa_max)
|
||||
attn_swa_sum = attn_swa_exp.sum(dim=-1, keepdim=True)
|
||||
lse_swa = torch.log(attn_swa_sum) + attn_swa_max # (m, 1)
|
||||
o_unnorm_swa = attn_swa_exp @ vf_swa # un-normalized
|
||||
o_norm_swa = o_unnorm_swa / attn_swa_sum # normalized
|
||||
|
||||
# Sink weight merge (reference formula from decode_sparse.py)
|
||||
# numerator = exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa
|
||||
# denominator = exp(lse_sparse) + exp(attn_sink) * exp(lse_swa)
|
||||
exp_lse_comp = lse_comp.exp() # (m, 1)
|
||||
exp_lse_swa = lse_swa.exp() # (m, 1)
|
||||
exp_sink = attn_sink.exp() # (1,)
|
||||
|
||||
numerator = (exp_lse_comp * o_norm_comp + exp_sink * exp_lse_swa * o_norm_swa)
|
||||
denominator = (exp_lse_comp + exp_sink * exp_lse_swa).clamp(min=1e-30)
|
||||
ref_output = numerator / denominator # (m, hd)
|
||||
|
||||
# Un-normalized version (for kernel output):
|
||||
# numerator = o_unnorm_sparse + exp(attn_sink) * o_unnorm_swa
|
||||
# denominator = exp(lse_sparse) + exp(attn_sink) * exp(lse_swa)
|
||||
numerator_unnorm = o_unnorm_comp + exp_sink * o_unnorm_swa
|
||||
denominator_unnorm = (exp_lse_comp + exp_sink * exp_lse_swa).clamp(min=1e-30)
|
||||
ref_output_unnorm = numerator_unnorm / denominator_unnorm
|
||||
|
||||
# Verify both formulas give the same result
|
||||
unnorm_vs_norm_cos = torch.nn.functional.cosine_similarity(
|
||||
ref_output.flatten().unsqueeze(0),
|
||||
ref_output_unnorm.flatten().unsqueeze(0)
|
||||
).item()
|
||||
print(f"Reference formula check: normalized vs unnorm cos = {unnorm_vs_norm_cos:.6f}")
|
||||
assert unnorm_vs_norm_cos > 0.999, f"Reference formulas don't match: cos={unnorm_vs_norm_cos}"
|
||||
|
||||
# === Kernel: Run FMHA twice (normalize=False) and merge ===
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = FmhaKernel(head_dim=hd, s_k=n_comp, normalize=False)
|
||||
|
||||
# Compile
|
||||
print('Compiling kernel...', flush=True)
|
||||
v_tile = v_comp[:, 0:kernel.pv_n_tile].contiguous().unsqueeze(-1)
|
||||
c_tile = torch.zeros(m, kernel.pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse_tile = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k_comp).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_comp))
|
||||
mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
|
||||
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
|
||||
mLSE = ct.from_dlpack(lse_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tile))
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
|
||||
|
||||
# Run compressed KV
|
||||
print('Running compressed KV...', flush=True)
|
||||
o_unnorm_kernel_comp, lse_kernel_comp = run_fmha_unnorm(q, k_comp, v_comp, compiled, stream)
|
||||
|
||||
# Run SWA KV (re-compile with different s_k if needed, or same if n_swa==n_comp)
|
||||
print('Running SWA KV...', flush=True)
|
||||
o_unnorm_kernel_swa, lse_kernel_swa = run_fmha_unnorm(q, k_swa, v_swa, compiled, stream)
|
||||
|
||||
# Merge with sink weights (Python)
|
||||
lse_comp_val = torch.tensor(lse_kernel_comp, dtype=torch.float32, device='cuda')
|
||||
lse_swa_val = torch.tensor(lse_kernel_swa, dtype=torch.float32, device='cuda')
|
||||
|
||||
exp_lse_kern_comp = torch.exp(lse_comp_val)
|
||||
exp_lse_kern_swa = torch.exp(lse_swa_val)
|
||||
exp_sink_kern = torch.exp(attn_sink[0])
|
||||
|
||||
# numerator = o_unnorm_comp + exp(sink) * o_unnorm_swa
|
||||
# denominator = exp(lse_comp) + exp(sink) * exp(lse_swa)
|
||||
kern_numerator = o_unnorm_kernel_comp.float() + exp_sink_kern * o_unnorm_kernel_swa.float()
|
||||
kern_denominator = (exp_lse_kern_comp + exp_sink_kern * exp_lse_kern_swa).clamp(min=1e-30)
|
||||
kern_output = kern_numerator / kern_denominator # (m, hd)
|
||||
|
||||
# Compare with reference
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
kern_output.flatten().unsqueeze(0),
|
||||
ref_output_unnorm.flatten().unsqueeze(0)
|
||||
).item()
|
||||
max_abs = (kern_output - ref_output_unnorm).abs().max().item()
|
||||
|
||||
status = "PASS" if cos >= 0.95 else "FAIL"
|
||||
print(f'\nMerge result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')
|
||||
if cos < 0.95:
|
||||
print(f' kern[0,:4]={kern_output[0,:4].tolist()}')
|
||||
print(f' ref[0,:4]={ref_output_unnorm[0,:4].tolist()}')
|
||||
|
||||
# Also check individual attention passes
|
||||
cos_comp = torch.nn.functional.cosine_similarity(
|
||||
o_unnorm_kernel_comp.flatten().unsqueeze(0).float(),
|
||||
o_unnorm_comp.flatten().unsqueeze(0)
|
||||
).item()
|
||||
cos_swa = torch.nn.functional.cosine_similarity(
|
||||
o_unnorm_kernel_swa.flatten().unsqueeze(0).float(),
|
||||
o_unnorm_swa.flatten().unsqueeze(0)
|
||||
).item()
|
||||
print(f' Compressed KV unnorm cos: {cos_comp:.6f}')
|
||||
print(f' SWA KV unnorm cos: {cos_swa:.6f}')
|
||||
print(f' LSE comp: kernel={lse_kernel_comp:.6f} ref={lse_comp[0,0].item():.6f}')
|
||||
print(f' LSE swa: kernel={lse_kernel_swa:.6f} ref={lse_swa[0,0].item():.6f}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
Reference in New Issue
Block a user