D5c: multi-tile test using Python KV merge with sink bias
This commit is contained in:
238
tests/unit/test_d5c_multitile.py
Normal file
238
tests/unit/test_d5c_multitile.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
FMHA D5c: Sink bias + Python KV merge for multi-tile (s_k > 128).

Verifies the full DSV4 attention pipeline:
  1. Concatenate KV: [compressed_K; swa_K] (total s_k > 128)
  2. Split into 128-token segments
  3. Run FMHA per segment (with sink bias on SWA positions, D3/D4 masking)
  4. Merge segments using Python KV merge formula:
       O = sum_i(exp(lse_i) * O_i_norm) / sum_i(exp(lse_i))
  5. Normalize using row_sum (each segment's O_i is divided by its row_sum
     to give O_i_norm; this happens per segment, before the merge in step 4)

This is the production path for DSV4 Pro (s_k=1152, 9 KV tiles)
until the D1.5 TMEM round-trip fix enables in-kernel O rescale.

Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5c_multitile.py
|
||||
"""
|
||||
import torch
|
||||
import math
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
                                 attn_sink, scale, swa_len, is_causal=False):
    """FP32 reference: single softmax over combined KV with sink bias on SWA.

    The compressed and SWA keys/values are concatenated along dim 0 (compressed
    first) and one softmax is taken over all key columns.

    Args:
        q: 2-D query tensor; its shape is unpacked as ``(m, hd)``.
        k_comp, v_comp: Compressed keys/values; ``k_comp.shape[0]`` is the
            number of compressed columns.
        k_swa, v_swa: SWA keys/values; ``k_swa.shape[0]`` is the number of SWA
            columns.
        attn_sink: Value added to the scaled score of every SWA column
            (a Python float or anything that broadcasts over those columns).
        scale: Multiplier applied to ``q @ k.T``.
        swa_len: Number of leading SWA columns that are valid; SWA columns at
            index ``swa_len`` or later are masked to ``-inf``.
        is_causal: When true, SWA column ``j`` is additionally masked for query
            row ``i`` whenever ``j > i``. Compressed columns are never masked.

    Returns:
        The attention output cast to ``torch.bfloat16``.
    """
    # Unpacking two values keeps the original requirement that q is 2-D.
    m, _ = q.shape
    n_comp = k_comp.shape[0]
    n_swa = k_swa.shape[0]

    k_combined = torch.cat([k_comp, k_swa], dim=0)
    v_combined = torch.cat([v_comp, v_swa], dim=0)

    scores = torch.matmul(q.float(), k_combined.float().T) * scale
    scores[:, n_comp:] += attn_sink

    if swa_len < n_swa:
        scores[:, n_comp + swa_len:] = float('-inf')

    if is_causal:
        # Vectorized form of "mask SWA column j for row i when j > i";
        # avoids one tensor write per (i, j) pair.
        row_idx = torch.arange(m, device=scores.device).unsqueeze(1)
        col_idx = torch.arange(n_swa, device=scores.device).unsqueeze(0)
        # The slice is a view, so the in-place fill updates `scores` itself.
        scores[:, n_comp:].masked_fill_(col_idx > row_idx, float('-inf'))

    # Numerically stable softmax: subtract the row max before exponentiating.
    max_s = scores.max(dim=-1, keepdim=True).values
    exp_s = (scores - max_s).exp()
    sum_s = exp_s.sum(dim=-1, keepdim=True).clamp(min=1e-30)

    o = torch.matmul(exp_s / sum_s, v_combined.float())
    return o.to(torch.bfloat16)
|
||||
|
||||
|
||||
def run_segment(q, k_seg, v_seg, kernel, compiled, stream,
                sink_bias=None, n_comp_in_seg=0, swa_len=999999):
    """Run one 128-token segment of FMHA, return normalized O and LSE.

    The value matrix is processed in column tiles of width
    ``kernel.pv_n_tile``; ``compiled`` is invoked once per tile and the tile
    outputs are written into the matching columns of the segment output.

    Args:
        q: Query tensor; ``q.shape[0]`` is used as the row count and
            ``q.device`` as the device for every buffer allocated here.
        k_seg: Keys of this segment.
        v_seg: Values of this segment; sliced along dim 1 per PV tile.
        kernel: Object exposing ``pv_n_tile`` and ``n_pv_tiles``.
        compiled: Callable launched per PV tile (the compiled kernel).
        stream: Stream handle forwarded to ``compiled``.
        sink_bias: Optional tensor forwarded as ``sink_bias=``; when ``None``
            the keyword is omitted from the call entirely.
        n_comp_in_seg: Accepted for interface compatibility; this function
            never reads it.
        swa_len: Forwarded unchanged to ``compiled`` as ``swa_len=``.

    Returns:
        ``(o_norm, lse_seg)``: the FP32 segment output divided by the row sums
        reported for the first PV tile, and the LSE reported for the first PV
        tile.

    Raises:
        ValueError: If ``kernel.n_pv_tiles`` yields no tiles, so no LSE or
            row sum was ever produced.
    """
    m = q.shape[0]
    hd = v_seg.shape[1]
    pv_n_tile = kernel.pv_n_tile
    device = q.device

    def to_cute(t):
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    # Q, K and the sink bias are identical for every PV tile: wrap them once
    # instead of on every loop iteration.
    mQ = to_cute(q.unsqueeze(-1))
    mK = to_cute(k_seg.unsqueeze(-1))
    mSB = to_cute(sink_bias) if sink_bias is not None else None

    # Per-segment output; LSE / row sums are captured from the first tile.
    o_seg = torch.zeros(m, hd, dtype=torch.bfloat16, device=device)
    lse_seg = None
    rs_seg = None

    for nt in range(kernel.n_pv_tiles):
        v_start = nt * pv_n_tile
        v_end = v_start + pv_n_tile
        v_tile = v_seg[:, v_start:v_end].contiguous()

        # Fresh output buffers for this tile, each with a trailing unit dim.
        c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device=device)
        lse_tile = torch.zeros(m, 1, 1, dtype=torch.float32, device=device)
        rs_tile = torch.zeros(m, 1, 1, dtype=torch.float32, device=device)

        mV = to_cute(v_tile.unsqueeze(-1))
        mC = to_cute(c_tile)
        mLSE = to_cute(lse_tile)
        mRS = to_cute(rs_tile)

        if mSB is not None:
            compiled(mQ, mK, mV, mC, stream, mLSE,
                     swa_len=swa_len, sink_bias=mSB, row_sums=mRS)
        else:
            compiled(mQ, mK, mV, mC, stream, mLSE,
                     swa_len=swa_len, row_sums=mRS)

        # Wait for the launch to finish before reading the tile buffers.
        torch.cuda.synchronize()
        o_seg[:, v_start:v_end] = c_tile[:, :, 0]

        if nt == 0:
            # Note: LSE and row_sum are the same across PV tiles (same softmax),
            # so only the first tile's values are kept.
            lse_seg = lse_tile[:, 0, 0].clone()
            rs_seg = rs_tile[:, 0, 0].clone()

    if rs_seg is None:
        raise ValueError("kernel.n_pv_tiles must be at least 1 to run a segment")

    # Normalize using row_sum (clamped so an all-zero row sum cannot divide by 0).
    o_norm = o_seg.float() / rs_seg.unsqueeze(1).clamp(min=1e-30)
    return o_norm, lse_seg
|
||||
|
||||
|
||||
def python_kv_merge(segment_results):
    """Merge multiple FMHA segments using Python KV merge formula.

        O = sum_i(exp(lse_i) * O_i_norm) / sum_i(exp(lse_i))

    The per-row maximum LSE is subtracted before exponentiating for numerical
    stability; this cancels between numerator and denominator.

    Args:
        segment_results: Sequence of ``(o_norm, lse)`` pairs, one per segment.
            ``lse`` is broadcast over dim 1 of ``o_norm``, so ``o_norm`` is
            expected to be ``(m, hd)`` and ``lse`` ``(m,)``.

    Returns:
        The merged output as an FP32 tensor with the shape of one ``o_norm``.
        A row whose LSE is ``-inf`` in every segment (no unmasked key anywhere)
        comes out as zeros rather than NaN.

    Raises:
        ValueError: If ``segment_results`` is empty.
    """
    segment_results = list(segment_results)
    if not segment_results:
        raise ValueError("python_kv_merge requires at least one segment result")

    # Compute max LSE for numerical stability.
    lse_stack = torch.stack([lse.float() for _, lse in segment_results], dim=0)  # (n_seg, m)
    lse_max = lse_stack.max(dim=0).values  # (m,)
    # If every segment reports -inf for a row, (-inf) - (-inf) would be NaN;
    # shifting by 0 instead makes all weights for that row exp(-inf) == 0.
    lse_max = torch.where(torch.isneginf(lse_max), torch.zeros_like(lse_max), lse_max)

    weights = (lse_stack - lse_max).exp()  # (n_seg, m)

    # Accumulate in FP32 regardless of the dtype the segments arrive in.
    o_stack = torch.stack([o_norm.float() for o_norm, _ in segment_results], dim=0)
    numerator = (weights.unsqueeze(-1) * o_stack).sum(dim=0)  # (m, hd)
    denominator = weights.sum(dim=0)  # (m,)

    o_merged = numerator / denominator.unsqueeze(1).clamp(min=1e-30)
    return o_merged
|
||||
|
||||
|
||||
def test_d5c_multitile():
    """Check segmented FMHA + Python KV merge against the FP32 reference.

    Builds random Q and [compressed; SWA] KV (256 keys in total), runs the
    compiled kernel on each 128-key segment, merges the per-segment results
    with ``python_kv_merge`` and compares the merged output with
    ``reference_combined_attention`` by cosine similarity.

    Prints a PASS/FAIL summary (plus a few values on failure) and then asserts
    that the cosine similarity is at least 0.99, so a mismatch fails the run
    under pytest instead of only being printed.
    """
    print("=== D5c Multi-Tile: Sink Bias + Python KV Merge ===\n")

    hd = 64
    m = 128
    n_comp = 96  # compressed KV tokens
    n_swa = 160  # SWA tokens
    n_total = n_comp + n_swa  # 256, 2 KV tiles
    swa_len = 100  # valid SWA fill (within the 160 SWA window)
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(42)

    q = torch.randn(m, hd, dtype=torch.bfloat16, device='cuda')
    k_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    k_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
    v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')

    attn_sink_val = 0.5
    attn_sink = torch.tensor([attn_sink_val], dtype=torch.float32, device='cuda')

    # Reference
    ref = reference_combined_attention(
        q, k_comp, v_comp, k_swa, v_swa,
        attn_sink_val, scale, swa_len
    )

    # Combined KV
    k_combined = torch.cat([k_comp, k_swa], dim=0)  # (256, hd)
    v_combined = torch.cat([v_comp, v_swa], dim=0)  # (256, hd)

    # Split into 128-token segments and run FMHA per segment
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    seg_size = 128
    n_segs = (n_total + seg_size - 1) // seg_size

    # Compile kernel for s_k=128 (single KV tile, no O rescale)
    kernel = FmhaKernel(head_dim=hd, s_k=seg_size, normalize=False,
                        apply_swa_mask=True, is_causal=False, n_comp=n_comp)

    # Pre-compile with dummy data
    def to_cute(t):
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    _q = q.unsqueeze(-1)
    _k = k_combined[:seg_size].unsqueeze(-1)
    _v = v_combined[:seg_size, :kernel.pv_n_tile].contiguous().unsqueeze(-1)
    _c = torch.zeros(m, kernel.pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
    _lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    _rs = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    _mQ = to_cute(_q)
    _mK = to_cute(_k)
    _mV = to_cute(_v)
    _mC = to_cute(_c)
    _mLSE = to_cute(_lse)
    _mRS = to_cute(_rs)
    _mSB = to_cute(attn_sink)
    compiled = cute.compile(kernel, _mQ, _mK, _mV, _mC, stream, _mLSE,
                            swa_len=swa_len, sink_bias=_mSB, row_sums=_mRS)

    # Run each segment
    segment_results = []
    for seg_idx in range(n_segs):
        k_start = seg_idx * seg_size
        k_end = min(k_start + seg_size, n_total)
        k_seg = k_combined[k_start:k_end]
        v_seg = v_combined[k_start:k_end]

        # Determine sink bias and swa_len for this segment.
        # Sink bias applies to positions >= n_comp (absolute).
        # In this segment, the first local position maps to absolute position k_start,
        # so sink bias applies to local positions >= max(0, n_comp - k_start).
        n_comp_local = max(0, n_comp - k_start)  # compressed tokens in this segment

        # Absolute valid SWA range: n_comp to n_comp + swa_len - 1.
        swa_start_abs = n_comp
        swa_end_abs = n_comp + swa_len

        # This segment covers absolute positions k_start to k_end - 1, so the
        # valid SWA part of it is [valid_start, valid_end).
        valid_start = max(swa_start_abs, k_start)
        valid_end = min(swa_end_abs, k_end)

        # NOTE(review): with n_comp_local == 0 and a partially valid segment
        # (segment 1 here: 68 of its 128 positions are valid SWA) the last
        # branch passes the global swa_len (100), not the per-segment count.
        # Whether that is right depends on how the kernel interprets swa_len
        # together with its n_comp -- TODO confirm against FmhaKernel.
        if valid_end > valid_start and n_comp_local > 0:
            # SWA positions in this segment:
            # n_comp_local to n_comp_local + (valid_end - valid_start) - 1
            swa_len_local = n_comp_local + (valid_end - valid_start)
        elif valid_end <= valid_start:
            swa_len_local = 0  # No valid SWA in this segment
        else:
            swa_len_local = swa_len  # All SWA is valid

        # Nothing to run for an empty segment.
        if k_seg.shape[0] == 0:
            continue

        # For segments that are entirely compressed (no SWA), don't apply sink bias
        has_swa = (k_start + seg_size) > n_comp

        # NOTE(review): run_segment accepts n_comp_in_seg but does not read it.
        o_norm, lse = run_segment(
            q, k_seg, v_seg, kernel, compiled, stream,
            sink_bias=attn_sink if has_swa else None,
            n_comp_in_seg=n_comp_local,
            swa_len=swa_len_local
        )
        segment_results.append((o_norm, lse))

    # Merge segments
    o_merged = python_kv_merge(segment_results)

    # Compare
    cos = torch.nn.functional.cosine_similarity(
        o_merged.flatten().unsqueeze(0),
        ref.flatten().unsqueeze(0).float()
    ).item()
    max_abs = (o_merged - ref.float()).abs().max().item()

    status = "PASS" if cos >= 0.99 else "FAIL"
    print(f'D5c multi-tile: cos {cos:.6f} max_abs {max_abs:.4f} {status}')

    if cos < 0.99:
        print(f' kernel[0,:4]={o_merged[0,:4].tolist()}')
        print(f' ref[0,:4]={ref[0,:4].tolist()}')

    # Printing alone never fails a pytest run; assert after the diagnostics so
    # they are still shown when the check fails.
    assert cos >= 0.99, f"D5c multi-tile cosine similarity {cos:.6f} is below 0.99"
|
||||
|
||||
|
||||
# Script entry point: run the multi-tile check directly, outside pytest.
if "__main__" == __name__:
    test_d5c_multitile()
|
||||
Reference in New Issue
Block a user