Files
nvfp4-megamoe-kernel/tests/unit/test_d5c_fused.py
biondizzle e64392f1ac D5c: add apply_sink_bias flag (independent of n_comp)
For all-SWA segments (n_comp=0), sink bias still needs to be applied
to all positions. The apply_sink_bias flag controls compilation of
the sink bias code path, independent of n_comp offset.
2026-05-26 15:26:52 +00:00

305 lines
11 KiB
Python

"""
FMHA D5c: Fused sparse + SWA attention via combined KV + sink bias.
Mathematical insight: the sink merge is equivalent to a single attention
pass over the concatenated KV with a logit bias (attn_sink) applied to
the SWA portion. No two passes needed, no merge epilogue.
S = [S_comp, S_swa + attn_sink]
O = softmax(S) @ [V_comp; V_swa]
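This is identical to the two-pass sink merge, where O_sparse and O_swa are the
per-branch softmax-normalized outputs and lse_sparse / lse_swa are the
per-row log-sum-exps of the corresponding (scaled, masked) logits:
O = (exp(lse_sparse)*O_sparse + exp(sink)*exp(lse_swa)*O_swa)
    / (exp(lse_sparse) + exp(sink)*exp(lse_swa))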
The kernel changes are minimal:
1. K = [compressed_K; swa_K], V = [compressed_V; swa_V]
2. n_comp = length of compressed KV (sink bias applies to positions >= n_comp)
3. attn_sink = per-head logit bias for SWA positions
4. D3/D4 masking applies to SWA region (positions >= n_comp)
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5c_fused.py
"""
import torch
import math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
                                 attn_sink, scale, swa_len, is_causal=False):
    """FP32 reference: single softmax over combined KV with sink bias on SWA.

    Builds logits over ``[k_comp; k_swa]``, adds ``attn_sink`` to the SWA
    columns, applies the D3 (``swa_len``) and optional D4 (causal) masks to
    the SWA columns only, then returns ``softmax(S) @ [v_comp; v_swa]``.

    Args:
        q: query matrix, 2-D (m, hd).
        k_comp, v_comp: compressed keys / values, (n_comp, hd) each.
        k_swa, v_swa: sliding-window keys / values, (n_swa, hd) each.
        attn_sink: logit bias added to every SWA column (a scalar, or anything
            broadcastable against the (m, n_swa) SWA logits).
        scale: multiplier applied to ``q @ k.T``.
        swa_len: number of valid SWA positions; SWA columns at index
            ``>= swa_len`` are masked out (a negative value masks all of them).
        is_causal: if True, SWA column ``j`` is masked for query row ``i``
            whenever ``j > i``.

    Returns:
        (m, hd) bfloat16 attention output. A row whose logits are all masked
        (only possible when there is no compressed KV) yields zeros, not NaN.
    """
    m, _ = q.shape
    n_comp = k_comp.shape[0]
    n_swa = k_swa.shape[0]
    # Concatenate KV
    k_combined = torch.cat([k_comp, k_swa], dim=0)  # (n_comp + n_swa, hd)
    v_combined = torch.cat([v_comp, v_swa], dim=0)
    # Compute combined logits
    scores = torch.matmul(q.float(), k_combined.float().T) * scale  # (m, n_comp + n_swa)
    # View onto the SWA columns; the in-place ops below write through to `scores`.
    swa_scores = scores[:, n_comp:]
    # Add sink bias to SWA positions
    swa_scores += attn_sink
    # D3: SWA length mask (only SWA region). Clamp so a negative swa_len can
    # never reach back into the compressed columns.
    if swa_len < n_swa:
        swa_scores[:, max(swa_len, 0):] = float('-inf')
    # D4: causal mask (only SWA region): mask k_coord > m_coord, vectorized
    # instead of an element-by-element Python loop.
    if is_causal:
        rows = torch.arange(m, device=scores.device).unsqueeze(1)
        cols = torch.arange(n_swa, device=scores.device).unsqueeze(0)
        swa_scores.masked_fill_(cols > rows, float('-inf'))
    # Softmax + PV
    max_s = scores.max(dim=-1, keepdim=True).values
    # A fully masked row has max -inf; shifting by it would give -inf - -inf = NaN.
    max_s = torch.where(torch.isneginf(max_s), torch.zeros_like(max_s), max_s)
    exp_s = (scores - max_s).exp()
    sum_s = exp_s.sum(dim=-1, keepdim=True).clamp(min=1e-30)
    p = exp_s / sum_s
    o = torch.matmul(p, v_combined.float())
    return o.to(torch.bfloat16)
def reference_sink_merge(q, k_comp, v_comp, k_swa, v_swa,
                         attn_sink, scale, swa_len, is_causal=False):
    """FP32 reference: separate attention + sink merge (original D5b formula).

    Runs one unmasked softmax attention over the compressed KV and one masked
    (D3 ``swa_len`` + optional D4 causal) softmax attention over the SWA KV,
    then merges the two normalized outputs with weights
    ``exp(lse_comp)`` and ``exp(attn_sink) * exp(lse_swa)``.

    The merge is evaluated in the log domain (weights shifted by their per-row
    max), which is mathematically identical to the textbook formula but does
    not overflow fp32 when logits exceed ~88.

    Args:
        q: query matrix, 2-D (m, hd).
        k_comp, v_comp: compressed keys / values, (n_comp, hd) each.
        k_swa, v_swa: sliding-window keys / values, (n_swa, hd) each.
        attn_sink: scalar logit bias for the SWA branch (anything accepted by
            ``float()``).
        scale: multiplier applied to ``q @ k.T``.
        swa_len: number of valid SWA positions; later SWA columns are masked
            (a negative value masks all of them).
        is_causal: if True, SWA column ``j`` is masked for query row ``i``
            whenever ``j > i``.

    Returns:
        (m, hd) bfloat16 merged output. A branch whose row is fully masked gets
        zero weight; if both are, the row is zero rather than NaN.
    """
    m, _ = q.shape
    n_swa = k_swa.shape[0]
    qf = q.float()
    # Compressed KV attention (no mask)
    attn_comp = torch.matmul(qf, k_comp.float().T) * scale
    o_norm_comp = torch.softmax(attn_comp, dim=-1) @ v_comp.float()
    lse_comp = torch.logsumexp(attn_comp, dim=-1, keepdim=True)  # (m, 1)
    # SWA KV attention (with swa_len mask)
    attn_swa = torch.matmul(qf, k_swa.float().T) * scale
    if swa_len < n_swa:
        attn_swa[:, max(swa_len, 0):] = float('-inf')
    if is_causal:
        # Mask k_coord > m_coord, vectorized.
        rows = torch.arange(m, device=attn_swa.device).unsqueeze(1)
        cols = torch.arange(n_swa, device=attn_swa.device).unsqueeze(0)
        attn_swa.masked_fill_(cols > rows, float('-inf'))
    o_norm_swa = torch.softmax(attn_swa, dim=-1) @ v_swa.float()
    lse_swa = torch.logsumexp(attn_swa, dim=-1, keepdim=True)
    # A fully masked row has lse = -inf and a NaN softmax; 0 * NaN would still
    # be NaN, so zero such rows explicitly (their merge weight is zero anyway).
    o_norm_comp = torch.where(torch.isneginf(lse_comp), torch.zeros_like(o_norm_comp), o_norm_comp)
    o_norm_swa = torch.where(torch.isneginf(lse_swa), torch.zeros_like(o_norm_swa), o_norm_swa)
    # Sink merge (normalized formula), in the log domain to avoid exp overflow.
    log_w_comp = lse_comp
    log_w_swa = lse_swa + float(attn_sink)
    log_w_max = torch.maximum(log_w_comp, log_w_swa)
    log_w_max = torch.where(torch.isneginf(log_w_max), torch.zeros_like(log_w_max), log_w_max)
    w_comp = (log_w_comp - log_w_max).exp()
    w_swa = (log_w_swa - log_w_max).exp()
    numerator = w_comp * o_norm_comp + w_swa * o_norm_swa
    denominator = (w_comp + w_swa).clamp(min=1e-30)
    o = numerator / denominator
    return o.to(torch.bfloat16)
def test_d5c_combined():
    """D5c end-to-end: combined [compressed; SWA] KV with sink bias, no causal mask.

    Checks that the two FP32 references (single combined softmax and two-pass
    sink merge) agree. Then it compiles and runs ``FmhaKernel`` on the
    concatenated KV and compares its row-sum-normalized output to the combined
    reference. The test fails if the cosine similarity is below 0.99.
    """
    print("=== Stage D5c: Fused Sparse+SWA via Combined KV + Sink Bias ===\n")
    hd = 64
    m = 128  # query rows
    n_comp = 64  # compressed KV length (fit in single 128-wide KV tile)
    n_swa = 64  # SWA window length
    n_total = n_comp + n_swa  # combined KV length = 128 (1 KV tile)
    swa_len = 40  # actual SWA fill
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(42)
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
    v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
    v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
    # Sink weight (in log domain, per-head). n_h=1 so it's a scalar.
    attn_sink_val = 0.5
    attn_sink = torch.tensor([attn_sink_val], dtype=torch.float32, device='cuda')
    # === FP32 References ===
    qf = q[:, :, 0]
    # Reference 1: Combined softmax with sink bias (our kernel's approach)
    ref_combined = reference_combined_attention(
        qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
        attn_sink_val, scale, swa_len
    )
    # Reference 2: Separate attention + sink merge (original D5b formula)
    ref_merge = reference_sink_merge(
        qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
        attn_sink_val, scale, swa_len
    )
    # Verify the two references agree
    cos_ref = torch.nn.functional.cosine_similarity(
        ref_combined.flatten().unsqueeze(0).float(),
        ref_merge.flatten().unsqueeze(0).float()
    ).item()
    print(f"Reference: combined softmax vs sink merge cos = {cos_ref:.6f}")
    assert cos_ref > 0.999, f"References don't match: cos={cos_ref}"
    # === Kernel ===
    # Concatenate KV for the kernel
    k_combined = torch.cat([k_comp, k_swa], dim=0)  # (n_total, hd, 1)
    v_combined = torch.cat([v_comp, v_swa], dim=0)  # (n_total, hd)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = FmhaKernel(
        head_dim=hd,
        s_k=n_total,  # combined KV length
        normalize=False,  # D5a: emit un-normalized O + LSE
        apply_swa_mask=True,  # D3: mask SWA positions
        is_causal=False,  # D4: no causal mask for this test
        n_comp=n_comp,  # D5c: compressed KV length (sink bias starts here)
        apply_sink_bias=True,  # D5c: enable sink bias logit modification
    )
    # Allocate output
    c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    row_sum_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    # Prepare CuTe tensors
    def to_cute(t):
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    mQ = to_cute(q)
    mK = to_cute(k_combined)
    mV = to_cute(v_combined.unsqueeze(-1).contiguous())
    mC = to_cute(c_out)
    mLSE = to_cute(lse_out)
    mRowSums = to_cute(row_sum_out)
    # Compile
    print('Compiling D5c kernel (combined KV + sink bias)...', flush=True)
    mSinkBias = to_cute(attn_sink)
    compiled = cute.compile(
        kernel, mQ, mK, mV, mC, stream, mLSE,
        swa_len=swa_len, sink_bias=mSinkBias, row_sums=mRowSums,
    )
    # Run
    print('Running D5c kernel...', flush=True)
    compiled(
        mQ, mK, mV, mC, stream, mLSE,
        swa_len=swa_len, sink_bias=mSinkBias, row_sums=mRowSums,
    )
    torch.cuda.synchronize()
    # Check results
    # Kernel outputs UN-NORMALIZED O (normalize=False). Normalize using per-row row_sum.
    # O_norm[i] = O_unnorm[i] / row_sum[i]
    o_kernel_unnorm = c_out[:, :, 0].float()  # (m, hd)
    row_sums = row_sum_out[:, 0, 0].float()  # (m,)
    # Normalize each row by its row_sum
    o_kernel = o_kernel_unnorm / row_sums.unsqueeze(1).clamp(min=1e-30)
    cos = torch.nn.functional.cosine_similarity(
        o_kernel.flatten().unsqueeze(0),
        ref_combined.flatten().unsqueeze(0).float()
    ).item()
    max_abs = (o_kernel - ref_combined.float()).abs().max().item()
    status = "PASS" if cos >= 0.99 else "FAIL"
    print(f'\nD5c result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')
    if cos < 0.99:
        print(f'  kernel[0,:4]={o_kernel[0,:4].tolist()}')
        print(f'  ref[0,:4]={ref_combined[0,:4].tolist()}')
        print(f'  row_sum range: {row_sums.min().item():.4f} to {row_sums.max().item():.4f}')
        print(f'  LSE range: {lse_out[:,0,0].min().item():.4f} to {lse_out[:,0,0].max().item():.4f}')
    # Without this, a mismatch is only printed and the test runner reports success.
    assert cos >= 0.99, f"D5c kernel mismatch: cos={cos:.6f} max_abs={max_abs:.4f}"
def test_d5c_with_causal():
    """D5c with causal mask on SWA branch.

    This test does the same as the non-causal D5c test, but with the D4 causal
    mask enabled on the SWA region and a partially filled window. First the
    combined-softmax reference is cross-checked against the two-pass sink
    merge. Then the kernel's row-sum-normalized output is compared against the
    combined reference. The test fails if the cosine similarity is below 0.99.
    """
    print("\n=== Stage D5c: Fused Sparse+SWA with Causal Mask ===\n")
    hd = 64
    m = 128
    n_comp = 64
    n_swa = 64
    n_total = n_comp + n_swa  # 128, single KV tile
    swa_len = 48  # partially filled SWA
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(123)
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
    v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
    v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
    attn_sink_val = 0.3
    attn_sink = torch.tensor([attn_sink_val], dtype=torch.float32, device='cuda')
    qf = q[:, :, 0]
    ref = reference_combined_attention(
        qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
        attn_sink_val, scale, swa_len, is_causal=True
    )
    # The combined-softmax identity must also hold under the causal mask.
    ref_merge = reference_sink_merge(
        qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
        attn_sink_val, scale, swa_len, is_causal=True
    )
    cos_ref = torch.nn.functional.cosine_similarity(
        ref.flatten().unsqueeze(0).float(),
        ref_merge.flatten().unsqueeze(0).float()
    ).item()
    print(f"Reference: combined softmax vs sink merge cos = {cos_ref:.6f}")
    assert cos_ref > 0.999, f"References don't match: cos={cos_ref}"
    # Concatenate KV
    k_combined = torch.cat([k_comp, k_swa], dim=0)
    v_combined = torch.cat([v_comp, v_swa], dim=0)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = FmhaKernel(
        head_dim=hd,
        s_k=n_total,
        normalize=False,
        apply_swa_mask=True,
        is_causal=True,
        n_comp=n_comp,
        apply_sink_bias=True,
    )
    c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    row_sum_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    def to_cute(t):
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    mQ = to_cute(q)
    mK = to_cute(k_combined)
    mV = to_cute(v_combined.unsqueeze(-1).contiguous())
    mC = to_cute(c_out)
    mLSE = to_cute(lse_out)
    mRowSums = to_cute(row_sum_out)
    mSinkBias = to_cute(attn_sink)
    compiled = cute.compile(
        kernel, mQ, mK, mV, mC, stream, mLSE,
        swa_len=swa_len, sink_bias=mSinkBias, row_sums=mRowSums,
    )
    compiled(
        mQ, mK, mV, mC, stream, mLSE,
        swa_len=swa_len, sink_bias=mSinkBias, row_sums=mRowSums,
    )
    torch.cuda.synchronize()
    # Kernel emits un-normalized O (normalize=False); divide by per-row row_sum.
    o_kernel_unnorm = c_out[:, :, 0].float()
    row_sums = row_sum_out[:, 0, 0].float()
    o_kernel = o_kernel_unnorm / row_sums.unsqueeze(1).clamp(min=1e-30)
    cos = torch.nn.functional.cosine_similarity(
        o_kernel.flatten().unsqueeze(0),
        ref.flatten().unsqueeze(0).float()
    ).item()
    max_abs = (o_kernel - ref.float()).abs().max().item()
    status = "PASS" if cos >= 0.99 else "FAIL"
    print(f'D5c causal result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')
    if cos < 0.99:
        print(f'  kernel[0,:4]={o_kernel[0,:4].tolist()}')
        print(f'  ref[0,:4]={ref[0,:4].tolist()}')
        print(f'  row_sum range: {row_sums.min().item():.4f} to {row_sums.max().item():.4f}')
    # Without this, a mismatch is only printed and the test runner reports success.
    assert cos >= 0.99, f"D5c causal kernel mismatch: cos={cos:.6f} max_abs={max_abs:.4f}"
if __name__ == '__main__':
    # Run the D5c checks in order when executed as a script (non-pytest entry).
    for run_check in (test_d5c_combined, test_d5c_with_causal):
        run_check()