D5c: multi-tile test using Python KV merge with sink bias

This commit is contained in:
2026-05-26 15:20:45 +00:00
parent 57a8316bc1
commit 3abcc7ff09

View File

@@ -0,0 +1,238 @@
"""
FMHA D5c: Sink bias + Python KV merge for multi-tile (s_k > 128).
Verifies the full DSV4 attention pipeline:
1. Concatenate KV: [compressed_K; swa_K] (total s_k > 128)
2. Split into 128-token segments
3. Run FMHA per segment (with sink bias on SWA positions, D3/D4 masking)
4. Merge segments using Python KV merge formula:
O = sum_i(exp(lse_i) * O_i_norm) / sum_i(exp(lse_i))
5. Normalize using row_sum
This is the production path for DSV4 Pro (s_k=1152, 9 KV tiles)
until the D1.5 TMEM round-trip fix enables in-kernel O rescale.
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5c_multitile.py
"""
import torch
import math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
attn_sink, scale, swa_len, is_causal=False):
"""FP32 reference: single softmax over combined KV with sink bias on SWA."""
m, hd = q.shape
n_comp = k_comp.shape[0]
n_swa = k_swa.shape[0]
k_combined = torch.cat([k_comp, k_swa], dim=0)
v_combined = torch.cat([v_comp, v_swa], dim=0)
scores = torch.matmul(q.float(), k_combined.float().T) * scale
scores[:, n_comp:] += attn_sink
if swa_len < n_swa:
scores[:, n_comp + swa_len:] = float('-inf')
if is_causal:
for i in range(m):
for j in range(n_swa):
if j > i:
scores[i, n_comp + j] = float('-inf')
max_s = scores.max(dim=-1, keepdim=True).values
exp_s = (scores - max_s).exp()
sum_s = exp_s.sum(dim=-1, keepdim=True).clamp(min=1e-30)
o = torch.matmul(exp_s / sum_s, v_combined.float())
return o.to(torch.bfloat16)
def run_segment(q, k_seg, v_seg, kernel, compiled, stream,
sink_bias=None, n_comp_in_seg=0, swa_len=999999):
"""Run one 128-token segment of FMHA, return normalized O and LSE."""
m = q.shape[0]
hd = v_seg.shape[1]
pv_n_tile = kernel.pv_n_tile
# Allocate per-segment outputs
o_seg = torch.zeros(m, hd, dtype=torch.bfloat16, device='cuda')
lse_seg = torch.zeros(m, dtype=torch.float32, device='cuda')
def to_cute(t):
return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
for nt in range(kernel.n_pv_tiles):
v_start = nt * pv_n_tile
v_end = v_start + pv_n_tile
v_tile = v_seg[:, v_start:v_end].contiguous()
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tile = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
rs_tile = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
mQ = to_cute(q.unsqueeze(-1))
mK = to_cute(k_seg.unsqueeze(-1))
mV = to_cute(v_tile.unsqueeze(-1))
mC = to_cute(c_tile)
mLSE = to_cute(lse_tile)
mRS = to_cute(rs_tile)
if sink_bias is not None:
mSB = to_cute(sink_bias)
compiled(mQ, mK, mV, mC, stream, mLSE,
swa_len=swa_len, sink_bias=mSB, row_sums=mRS)
else:
compiled(mQ, mK, mV, mC, stream, mLSE,
swa_len=swa_len, row_sums=mRS)
torch.cuda.synchronize()
o_seg[:, v_start:v_end] = c_tile[:, :, 0]
if nt == 0:
lse_seg = lse_tile[:, 0, 0].clone()
rs_seg = rs_tile[:, 0, 0].clone()
# Note: LSE and row_sum are the same across PV tiles (same softmax)
# Normalize using row_sum
o_norm = o_seg.float() / rs_seg.unsqueeze(1).clamp(min=1e-30)
return o_norm, lse_seg
def python_kv_merge(segment_results):
"""Merge multiple FMHA segments using Python KV merge formula.
O = sum_i(exp(lse_i) * O_i_norm) / sum_i(exp(lse_i))
"""
# Compute max LSE for numerical stability
lse_stack = torch.stack([r[1] for r in segment_results], dim=0) # (n_seg, m)
lse_max = lse_stack.max(dim=0).values # (m,)
numerator = torch.zeros_like(segment_results[0][0]) # (m, hd)
denominator = torch.zeros(lse_max.shape[0], dtype=torch.float32, device=lse_max.device)
for o_norm, lse in segment_results:
exp_lse = (lse - lse_max).exp() # (m,)
numerator += exp_lse.unsqueeze(1) * o_norm.float()
denominator += exp_lse
o_merged = numerator / denominator.unsqueeze(1).clamp(min=1e-30)
return o_merged
def test_d5c_multitile():
print("=== D5c Multi-Tile: Sink Bias + Python KV Merge ===\n")
hd = 64
m = 128
n_comp = 96 # compressed KV tokens
n_swa = 160 # SWA tokens
n_total = n_comp + n_swa # 256, 2 KV tiles
swa_len = 100 # valid SWA fill (within the 160 SWA window)
scale = 1.0 / math.sqrt(hd)
torch.manual_seed(42)
q = torch.randn(m, hd, dtype=torch.bfloat16, device='cuda')
k_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
k_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
attn_sink_val = 0.5
attn_sink = torch.tensor([attn_sink_val], dtype=torch.float32, device='cuda')
# Reference
ref = reference_combined_attention(
q, k_comp, v_comp, k_swa, v_swa,
attn_sink_val, scale, swa_len
)
# Combined KV
k_combined = torch.cat([k_comp, k_swa], dim=0) # (256, hd)
v_combined = torch.cat([v_comp, v_swa], dim=0) # (256, hd)
# Split into 128-token segments and run FMHA per segment
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
seg_size = 128
n_segs = (n_total + seg_size - 1) // seg_size
# Compile kernel for s_k=128 (single KV tile, no O rescale)
kernel = FmhaKernel(head_dim=hd, s_k=seg_size, normalize=False,
apply_swa_mask=True, is_causal=False, n_comp=n_comp)
# Pre-compile with dummy data
def to_cute(t):
return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
_q = q.unsqueeze(-1)
_k = k_combined[:seg_size].unsqueeze(-1)
_v = v_combined[:seg_size, :kernel.pv_n_tile].contiguous().unsqueeze(-1)
_c = torch.zeros(m, kernel.pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
_lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
_rs = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
_mQ = to_cute(_q); _mK = to_cute(_k); _mV = to_cute(_v)
_mC = to_cute(_c); _mLSE = to_cute(_lse); _mRS = to_cute(_rs)
_mSB = to_cute(attn_sink)
compiled = cute.compile(kernel, _mQ, _mK, _mV, _mC, stream, _mLSE,
swa_len=swa_len, sink_bias=_mSB, row_sums=_mRS)
# Run each segment
segment_results = []
for seg_idx in range(n_segs):
k_start = seg_idx * seg_size
k_end = min(k_start + seg_size, n_total)
k_seg = k_combined[k_start:k_end]
v_seg = v_combined[k_start:k_end]
# Determine sink bias and swa_len for this segment
# Sink bias applies to positions >= n_comp (absolute)
# In this segment, the first local position maps to absolute position k_start
# So sink bias applies to local positions >= max(0, n_comp - k_start)
n_comp_local = max(0, n_comp - k_start) # compressed tokens in this segment
# swa_len: how many SWA positions are valid in this segment
# Absolute valid range: n_comp to n_comp + swa_len - 1
swa_start_abs = n_comp
swa_end_abs = n_comp + swa_len
# This segment covers absolute positions k_start to k_end - 1
# Valid SWA in this segment: max(swa_start_abs, k_start) to min(swa_end_abs, k_end) - 1
# swa_len_local: number of valid SWA positions in this segment
valid_start = max(swa_start_abs, k_start)
valid_end = min(swa_end_abs, k_end)
if valid_end > valid_start and n_comp_local > 0:
# SWA positions in this segment: n_comp_local to n_comp_local + (valid_end - valid_start) - 1
swa_len_local = n_comp_local + (valid_end - valid_start)
elif valid_end <= valid_start:
swa_len_local = 0 # No valid SWA in this segment
else:
swa_len_local = swa_len # All SWA is valid
# If no compressed tokens and no SWA tokens in this segment, skip
if k_seg.shape[0] == 0:
continue
# For segments that are entirely compressed (no SWA), don't apply sink bias
has_swa = (k_start + seg_size) > n_comp
o_norm, lse = run_segment(
q, k_seg, v_seg, kernel, compiled, stream,
sink_bias=attn_sink if has_swa else None,
n_comp_in_seg=n_comp_local,
swa_len=swa_len_local
)
segment_results.append((o_norm, lse))
# Merge segments
o_merged = python_kv_merge(segment_results)
# Compare
cos = torch.nn.functional.cosine_similarity(
o_merged.flatten().unsqueeze(0),
ref.flatten().unsqueeze(0).float()
).item()
max_abs = (o_merged - ref.float()).abs().max().item()
status = "PASS" if cos >= 0.99 else "FAIL"
print(f'D5c multi-tile: cos {cos:.6f} max_abs {max_abs:.4f} {status}')
if cos < 0.99:
print(f' kernel[0,:4]={o_merged[0,:4].tolist()}')
print(f' ref[0,:4]={ref[0,:4].tolist()}')
if __name__ == '__main__':
test_d5c_multitile()