# File: nvfp4-megamoe-kernel/tests/unit/test_d5c_multitile.py
# (file-viewer metadata at time of capture: 167 lines, 7.7 KiB, Python)

"""
FMHA D5c: Sink bias + Python KV merge for multi-tile (s_k > 128).
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5c_multitile.py
"""
import torch
import math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def to_cute(t):
    """Convert tensor ``t`` to a CuTe tensor via DLPack and mark its layout dynamic.

    The leading dimension passed to ``mark_layout_dynamic`` is whatever
    ``ct.get_leading_dim(t)`` reports for ``t``.
    """
    cute_view = ct.from_dlpack(t)
    leading = ct.get_leading_dim(t)
    return cute_view.mark_layout_dynamic(leading_dim=leading)
def reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
                                 attn_sink, scale, swa_len, is_causal=False):
    """Reference attention over the concatenated [compressed | SWA] key/value set.

    Scores are ``q @ [k_comp; k_swa].T * scale`` computed in float32. The SWA
    columns (those at index ``n_comp`` and beyond) additionally receive
    ``attn_sink`` and are subject to masking.

    Args:
        q: 2-D query tensor, one row per query.
        k_comp, v_comp: compressed keys/values; ``k_comp.shape[0]`` is ``n_comp``.
        k_swa, v_swa: SWA keys/values; ``k_swa.shape[0]`` is ``n_swa``.
        attn_sink: additive bias applied to every SWA score column.
        scale: multiplier applied to the raw ``q @ k.T`` scores.
        swa_len: number of leading SWA columns that stay visible; columns
            from ``n_comp + swa_len`` onward are set to ``-inf`` when
            ``swa_len < n_swa``.
        is_causal: when True, query row ``i`` cannot see SWA column ``j`` for
            ``j > i``.

    Returns:
        The softmax-weighted values as a bfloat16 tensor. Rows whose scores
        are all masked produce zeros rather than NaN.
    """
    m, _ = q.shape
    n_comp = k_comp.shape[0]
    n_swa = k_swa.shape[0]
    k_comb = torch.cat([k_comp, k_swa], dim=0)
    v_comb = torch.cat([v_comp, v_swa], dim=0)

    scores = q.float() @ k_comb.float().T * scale
    scores[:, n_comp:] += attn_sink
    if swa_len < n_swa:
        scores[:, n_comp + swa_len:] = float('-inf')
    if is_causal:
        # Vectorized form of "mask SWA column j for query row i when j > i".
        rows = torch.arange(m, device=scores.device).unsqueeze(1)
        cols = torch.arange(n_swa, device=scores.device).unsqueeze(0)
        scores[:, n_comp:].masked_fill_(cols > rows, float('-inf'))

    max_s = scores.max(dim=-1, keepdim=True).values
    # A fully masked row has max == -inf; subtracting it would give
    # (-inf) - (-inf) = NaN and defeat the clamp below. Shift by 0 instead so
    # exp() yields zeros for that row.
    max_s = max_s.masked_fill(max_s == float('-inf'), 0.0)
    exp_s = (scores - max_s).exp()
    sum_s = exp_s.sum(dim=-1, keepdim=True).clamp(min=1e-30)
    return (exp_s / sum_s @ v_comb.float()).to(torch.bfloat16)
def python_kv_merge(segment_results):
    """Merge per-segment normalized attention outputs using their LSE weights.

    Computes ``O = sum_i(exp(lse_i) * O_i_norm) / sum_i(exp(lse_i))``. The
    per-row maximum LSE is subtracted before exponentiating so the weights
    stay in range.

    Args:
        segment_results: non-empty sequence of ``(o_norm, lse)`` pairs, where
            ``o_norm`` is indexed ``[row, feature]`` and ``lse`` holds one
            value per row.

    Returns:
        The merged output. Accumulation happens in at least float32 even when
        the segment outputs are bf16/fp16. Rows whose LSE is ``-inf`` in every
        segment come out as zeros instead of NaN.
    """
    first_out = segment_results[0][0]
    lse_stack = torch.stack([lse for _, lse in segment_results], dim=0)
    lse_max = lse_stack.max(dim=0).values
    # If every segment is fully masked for a row, lse_max is -inf and
    # (lse - lse_max) would be NaN; shifting by 0 makes every weight exp(-inf) = 0.
    lse_max = lse_max.masked_fill(lse_max == float('-inf'), 0.0)

    # zeros_like(first_out) alone would inherit a bf16/fp16 dtype and round
    # the running sum on every in-place add.
    acc_dtype = torch.promote_types(first_out.dtype, torch.float32)
    numerator = torch.zeros_like(first_out, dtype=acc_dtype)
    denominator = torch.zeros(lse_max.shape[0], dtype=torch.float32, device=lse_max.device)
    for o_norm, lse in segment_results:
        exp_lse = (lse - lse_max).exp()
        numerator += exp_lse.unsqueeze(1) * o_norm.float()
        denominator += exp_lse
    return numerator / denominator.unsqueeze(1).clamp(min=1e-30)
def run_fmha(q, k, v, kernel_obj, compiled, stream, swa_len, sink_bias=None):
    """Launch ``compiled`` on a single KV tile and return ``(O_norm, LSE)``.

    Zero-initialized output, LSE and row-sum buffers are allocated on the CUDA
    device, every tensor is wrapped with ``to_cute`` and ``compiled`` is
    invoked once; ``sink_bias`` is forwarded only when it is not None.
    ``kernel_obj`` is accepted but not used by this function.

    Returns:
        A pair ``(o_norm, lse)``: the output buffer divided by the row sums
        (clamped to at least 1e-30) in float32, and a copy of the per-row LSE.
    """
    rows, v_dim = q.shape[0], v.shape[1]
    out_buf = torch.zeros(rows, v_dim, 1, dtype=torch.bfloat16, device='cuda')
    lse_buf = torch.zeros(rows, 1, 1, dtype=torch.float32, device='cuda')
    row_sum_buf = torch.zeros(rows, 1, 1, dtype=torch.float32, device='cuda')

    # Inputs get a trailing unit dimension before being wrapped.
    positional = [
        to_cute(q.unsqueeze(-1)),
        to_cute(k.unsqueeze(-1)),
        to_cute(v.unsqueeze(-1)),
        to_cute(out_buf),
        stream,
        to_cute(lse_buf),
    ]
    row_sums_cute = to_cute(row_sum_buf)

    keyword = {'swa_len': swa_len}
    if sink_bias is not None:
        keyword['sink_bias'] = to_cute(sink_bias)
    keyword['row_sums'] = row_sums_cute

    compiled(*positional, **keyword)
    torch.cuda.synchronize()

    safe_row_sums = row_sum_buf[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
    o_norm = out_buf[:, :, 0].float() / safe_row_sums
    return o_norm, lse_buf[:, 0, 0].clone()
def _run_d5c_segment(q, k_seg, v_seg, attn_sink, stream, head_dim, n_comp, swa_len):
    """Compile and launch an unnormalized FmhaKernel on one KV segment; return (O_norm, LSE)."""
    m = q.shape[0]
    kernel = FmhaKernel(head_dim=head_dim, s_k=k_seg.shape[0], normalize=False,
                        apply_swa_mask=True, is_causal=False,
                        n_comp=n_comp, apply_sink_bias=True)
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    rs = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    def launch(fn, *leading):
        # The same argument list is used for cute.compile and for the compiled
        # kernel; fresh CuTe wrappers are built for each call.
        return fn(*leading,
                  to_cute(q.unsqueeze(-1)), to_cute(k_seg.unsqueeze(-1)),
                  to_cute(v_seg.unsqueeze(-1)), to_cute(c), stream, to_cute(lse),
                  swa_len=swa_len, sink_bias=to_cute(attn_sink), row_sums=to_cute(rs))

    compiled = launch(cute.compile, kernel)
    launch(compiled)
    torch.cuda.synchronize()
    # The kernel is built with normalize=False, so divide by the row sums here.
    o_norm = c[:, :, 0].float() / rs[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
    return o_norm, lse[:, 0, 0]


def test_d5c_multitile():
    """Check that per-tile FMHA results merged in Python match the full reference.

    The combined KV sequence (compressed tokens followed by SWA tokens) is
    split into tiles of ``seg_size`` keys. Each tile is run through the kernel
    with sink bias and SWA masking, compared against a per-tile torch
    reference (printed only), and the tile outputs are merged with
    ``python_kv_merge``. The test fails unless the merged output reaches a
    cosine similarity of at least 0.99 against the full reference.
    """
    print("=== D5c Multi-Tile: Sink Bias + Python KV Merge ===\n")
    hd = 64
    m = 128
    n_comp = 96
    n_swa = 160
    n_total = n_comp + n_swa
    swa_len = 100
    scale = 1.0 / math.sqrt(hd)
    seg_size = 128

    torch.manual_seed(42)
    q = torch.randn(m, hd, dtype=torch.bfloat16, device='cuda')
    k_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    k_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
    v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
    attn_sink_val = 0.5
    attn_sink = torch.tensor([attn_sink_val], dtype=torch.float32, device='cuda')
    k_combined = torch.cat([k_comp, k_swa], dim=0)
    v_combined = torch.cat([v_comp, v_swa], dim=0)

    # Full reference over all n_total keys.
    ref = reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
                                       attn_sink_val, scale, swa_len)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Absolute index one past the last visible SWA key (96 + 100 = 196 here).
    swa_end = n_comp + swa_len
    segment_results = []
    for idx, k_start in enumerate(range(0, n_total, seg_size)):
        k_seg = k_combined[k_start:k_start + seg_size]
        v_seg = v_combined[k_start:k_start + seg_size]
        # Compressed keys that fall inside this tile (96 for tile 0, 0 for tile 1).
        seg_n_comp = max(n_comp - k_start, 0)
        # Visible SWA keys counted from the tile's first SWA column
        # (100 for tile 0, 96 + 100 - 128 = 68 for tile 1).
        seg_swa_len = swa_end - k_start - seg_n_comp

        # Per-tile torch reference: sink bias on SWA columns, then SWA masking.
        scores = q.float() @ k_seg.float().T * scale
        scores[:, seg_n_comp:] += attn_sink_val
        valid_end = seg_n_comp + seg_swa_len
        if valid_end < k_seg.shape[0]:
            scores[:, valid_end:] = float('-inf')
        ref_seg = (torch.softmax(scores, dim=-1) @ v_seg.float()).to(torch.bfloat16)
        lse_ref = torch.logsumexp(scores, dim=-1)

        o_seg, lse_kern = _run_d5c_segment(q, k_seg, v_seg, attn_sink, stream,
                                           hd, seg_n_comp, seg_swa_len)
        cos_seg = torch.nn.functional.cosine_similarity(
            o_seg.flatten().unsqueeze(0), ref_seg.flatten().unsqueeze(0).float()).item()
        print(f'Seg{idx}: cos {cos_seg:.6f} LSE_kern[0]={lse_kern[0].item():.4f} LSE_ref[0]={lse_ref[0].item():.4f}')
        segment_results.append((o_seg, lse_kern))

    # Merge
    o_merged = python_kv_merge(segment_results)
    cos = torch.nn.functional.cosine_similarity(
        o_merged.flatten().unsqueeze(0), ref.flatten().unsqueeze(0).float()).item()
    status = "PASS" if cos >= 0.99 else "FAIL"
    print(f'\nMerged: cos {cos:.6f} {status}')
    if cos < 0.99:
        print(f' kernel[0,:4]={o_merged[0,:4].tolist()}')
        print(f' ref[0,:4]={ref[0,:4].tolist()}')
    # Printing alone cannot fail a test run; make the verdict binding.
    assert cos >= 0.99, f"merged output cosine similarity {cos:.6f} is below 0.99"
# Script entry point: run the multi-tile test when this file is executed directly.
if __name__ == '__main__':
    test_d5c_multitile()