# Extracted from patch 3abcc7ff096d2fe36f92fd985ec0b840191cfa17
# Author:  biondizzle
# Date:    Tue, 26 May 2026 15:20:45 +0000
# Subject: D5c: multi-tile test using Python KV merge with sink bias
# File:    tests/unit/test_d5c_multitile.py
"""
FMHA D5c: Sink bias + Python KV merge for multi-tile (s_k > 128).

Verifies the full DSV4 attention pipeline:
1. Concatenate KV: [compressed_K; swa_K] (total s_k > 128)
2. Split into 128-token segments
3. Run FMHA per segment (with sink bias on SWA positions, D3/D4 masking)
4. Normalize each segment's O using its row_sum
5. Merge segments using Python KV merge formula:
   O = sum_i(exp(lse_i) * O_i_norm) / sum_i(exp(lse_i))

This is the production path for DSV4 Pro (s_k=1152, 9 KV tiles)
until the D1.5 TMEM round-trip fix enables in-kernel O rescale.

Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5c_multitile.py
"""
import math

import torch

# The CUTLASS / CUDA-driver / dsv4 imports are deferred to the functions that
# need them so the pure-PyTorch helpers in this module (reference attention,
# KV merge) stay importable on hosts without the GPU toolchain.

_NEG_INF = float('-inf')
_TINY = 1e-30  # lower clamp that keeps divisions finite


def _to_cute(t):
    """Wrap torch tensor *t* via ``ct.from_dlpack`` with a dynamic layout.

    The leading dimension passed to ``mark_layout_dynamic`` is whatever
    ``ct.get_leading_dim(t)`` reports for the tensor.
    """
    import cutlass.torch as ct

    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


def reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
                                 attn_sink, scale, swa_len, is_causal=False):
    """FP32 reference: single softmax over combined KV with sink bias on SWA.

    Args:
        q: (m, hd) queries.
        k_comp, v_comp: (n_comp, hd) compressed keys / values.
        k_swa, v_swa: (n_swa, hd) sliding-window keys / values.
        attn_sink: bias added to every SWA score (scalar or broadcastable).
        scale: multiplier applied to the raw QK^T scores.
        swa_len: number of leading SWA positions that are valid; SWA
            positions at index >= swa_len are masked out.
        is_causal: when true, query row i additionally ignores SWA
            positions j > i.

    Returns:
        (m, hd) attention output cast to bfloat16.
    """
    m, _ = q.shape
    n_comp = k_comp.shape[0]
    n_swa = k_swa.shape[0]

    k_combined = torch.cat([k_comp, k_swa], dim=0)
    v_combined = torch.cat([v_comp, v_swa], dim=0)
    scores = torch.matmul(q.float(), k_combined.float().T) * scale

    # Basic slicing returns a view, so in-place edits below land in `scores`.
    swa_scores = scores[:, n_comp:]
    swa_scores += attn_sink
    if swa_len < n_swa:
        swa_scores[:, swa_len:] = _NEG_INF
    if is_causal:
        # Strict upper triangle == every (i, j) pair with j > i.
        future = torch.triu(
            torch.ones(m, n_swa, dtype=torch.bool, device=scores.device),
            diagonal=1,
        )
        swa_scores.masked_fill_(future, _NEG_INF)

    max_s = scores.max(dim=-1, keepdim=True).values
    exp_s = (scores - max_s).exp()
    sum_s = exp_s.sum(dim=-1, keepdim=True).clamp(min=_TINY)
    o = torch.matmul(exp_s / sum_s, v_combined.float())
    return o.to(torch.bfloat16)


def run_segment(q, k_seg, v_seg, kernel, compiled, stream,
                sink_bias=None, n_comp_in_seg=0, swa_len=999999):
    """Run one 128-token segment of FMHA, return normalized O and LSE.

    The value head dimension is processed in ``kernel.n_pv_tiles`` column
    tiles of width ``kernel.pv_n_tile``; each tile is one launch of
    ``compiled``.

    Args:
        q: (m, hd) queries.
        k_seg, v_seg: keys / values of this segment.
        kernel: object exposing ``pv_n_tile`` and ``n_pv_tiles``.
        compiled: compiled kernel callable.
        stream: CUDA stream handle forwarded to ``compiled``.
        sink_bias: optional bias tensor; forwarded only when not ``None``.
        n_comp_in_seg: accepted for interface compatibility; currently unused.
        swa_len: forwarded to the kernel as ``swa_len``.

    Returns:
        ``(o_norm, lse)``: the float32 output divided by the kernel's row
        sums, and the per-row LSE taken from the first PV tile.

    Raises:
        ValueError: if ``kernel.n_pv_tiles`` is zero, so no tile ever ran.
    """
    m = q.shape[0]
    hd = v_seg.shape[1]
    pv_n_tile = kernel.pv_n_tile

    o_seg = torch.zeros(m, hd, dtype=torch.bfloat16, device='cuda')
    lse_seg = None
    rs_seg = None

    # Q, K and the sink bias do not change between PV tiles: wrap them once.
    mQ = _to_cute(q.unsqueeze(-1))
    mK = _to_cute(k_seg.unsqueeze(-1))
    mSB = _to_cute(sink_bias) if sink_bias is not None else None

    for nt in range(kernel.n_pv_tiles):
        v_start = nt * pv_n_tile
        v_end = v_start + pv_n_tile
        v_tile = v_seg[:, v_start:v_end].contiguous()
        c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
        lse_tile = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
        rs_tile = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

        mV = _to_cute(v_tile.unsqueeze(-1))
        mC = _to_cute(c_tile)
        mLSE = _to_cute(lse_tile)
        mRS = _to_cute(rs_tile)

        if mSB is not None:
            compiled(mQ, mK, mV, mC, stream, mLSE,
                     swa_len=swa_len, sink_bias=mSB, row_sums=mRS)
        else:
            compiled(mQ, mK, mV, mC, stream, mLSE,
                     swa_len=swa_len, row_sums=mRS)

        torch.cuda.synchronize()
        o_seg[:, v_start:v_end] = c_tile[:, :, 0]
        if nt == 0:
            # Note: LSE and row_sum are the same across PV tiles (same softmax)
            lse_seg = lse_tile[:, 0, 0].clone()
            rs_seg = rs_tile[:, 0, 0].clone()

    if rs_seg is None:
        raise ValueError("kernel.n_pv_tiles must be >= 1 to run a segment")

    # Normalize using row_sum
    o_norm = o_seg.float() / rs_seg.unsqueeze(1).clamp(min=_TINY)
    return o_norm, lse_seg


def python_kv_merge(segment_results):
    """Merge multiple FMHA segments using Python KV merge formula.

    O = sum_i(exp(lse_i) * O_i_norm) / sum_i(exp(lse_i))

    The per-row maximum LSE is subtracted before exponentiating for
    numerical stability.  Rows whose LSE is ``-inf`` in every segment
    (fully masked everywhere) yield zeros instead of NaN, provided the
    segment outputs themselves are finite.

    Args:
        segment_results: non-empty sequence of ``(o_norm, lse)`` pairs with
            ``o_norm`` of shape (m, hd) and ``lse`` of shape (m,).

    Returns:
        (m, hd) float32 merged output.
    """
    lse_stack = torch.stack([lse for _, lse in segment_results], dim=0)  # (n_seg, m)
    o_stack = torch.stack([o.float() for o, _ in segment_results], dim=0)  # (n_seg, m, hd)

    lse_max = lse_stack.max(dim=0).values  # (m,)
    # (-inf) - (-inf) is NaN; shifting such rows by 0 keeps their weights at 0.
    lse_shift = lse_max.masked_fill(torch.isneginf(lse_max), 0.0)

    weights = (lse_stack - lse_shift).exp()  # (n_seg, m)
    numerator = (weights.unsqueeze(-1) * o_stack).sum(dim=0)  # (m, hd)
    denominator = weights.sum(dim=0)  # (m,)

    return numerator / denominator.unsqueeze(1).clamp(min=_TINY)


def test_d5c_multitile():
    """Compare segmented FMHA + Python KV merge against the FP32 reference.

    Prints the cosine similarity / max-abs error and asserts that the cosine
    similarity reaches 0.99.  Returns early (with a SKIP message) when no
    CUDA device is available.
    """
    print("=== D5c Multi-Tile: Sink Bias + Python KV Merge ===\n")

    if not torch.cuda.is_available():
        print("D5c multi-tile: SKIP (no CUDA device available)")
        return

    import cuda.bindings.driver as cuda
    import cutlass.cute as cute
    from dsv4.kernels.attention.fmha import FmhaKernel

    hd = 64
    m = 128
    n_comp = 96   # compressed KV tokens
    n_swa = 160   # SWA tokens
    n_total = n_comp + n_swa  # 256, 2 KV tiles
    swa_len = 100  # valid SWA fill (within the 160 SWA window)
    scale = 1.0 / math.sqrt(hd)
    cos_threshold = 0.99
    torch.manual_seed(42)

    q = torch.randn(m, hd, dtype=torch.bfloat16, device='cuda')
    k_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    k_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
    v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')

    attn_sink_val = 0.5
    attn_sink = torch.tensor([attn_sink_val], dtype=torch.float32, device='cuda')

    # Reference
    ref = reference_combined_attention(
        q, k_comp, v_comp, k_swa, v_swa,
        attn_sink_val, scale, swa_len
    )

    # Combined KV
    k_combined = torch.cat([k_comp, k_swa], dim=0)  # (256, hd)
    v_combined = torch.cat([v_comp, v_swa], dim=0)  # (256, hd)

    # Split into 128-token segments and run FMHA per segment
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    seg_size = 128
    n_segs = (n_total + seg_size - 1) // seg_size

    # Compile kernel for s_k=128 (single KV tile, no O rescale)
    kernel = FmhaKernel(head_dim=hd, s_k=seg_size, normalize=False,
                        apply_swa_mask=True, is_causal=False, n_comp=n_comp)

    # Pre-compile with dummy data
    dummy_k = k_combined[:seg_size].unsqueeze(-1)
    dummy_v = v_combined[:seg_size, :kernel.pv_n_tile].contiguous().unsqueeze(-1)
    dummy_c = torch.zeros(m, kernel.pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
    dummy_lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    dummy_rs = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    compiled = cute.compile(
        kernel,
        _to_cute(q.unsqueeze(-1)), _to_cute(dummy_k), _to_cute(dummy_v),
        _to_cute(dummy_c), stream, _to_cute(dummy_lse),
        swa_len=swa_len, sink_bias=_to_cute(attn_sink),
        row_sums=_to_cute(dummy_rs),
    )

    # Absolute range of valid SWA positions: [swa_start_abs, swa_end_abs)
    swa_start_abs = n_comp
    swa_end_abs = n_comp + swa_len

    # Run each segment
    segment_results = []
    for seg_idx in range(n_segs):
        # This segment covers absolute positions k_start .. k_end - 1
        k_start = seg_idx * seg_size
        k_end = min(k_start + seg_size, n_total)
        k_seg = k_combined[k_start:k_end]
        v_seg = v_combined[k_start:k_end]

        # Sink bias applies to absolute positions >= n_comp, i.e. to local
        # positions >= max(0, n_comp - k_start) (not clamped to segment length).
        n_comp_local = max(0, n_comp - k_start)

        # Valid SWA positions that fall inside this segment.
        valid_start = max(swa_start_abs, k_start)
        valid_end = min(swa_end_abs, k_end)
        n_valid = valid_end - valid_start

        # NOTE(review): when the segment holds no compressed tokens
        # (n_comp_local == 0) this passes the global swa_len even though only
        # n_valid SWA positions of the segment lie inside the valid window,
        # and the kernel was built with the global n_comp.  Whether that is
        # right depends on how FmhaKernel interprets n_comp / swa_len for a
        # segment — confirm against the kernel before relying on this test.
        if n_valid <= 0:
            swa_len_local = 0  # No valid SWA in this segment
        elif n_comp_local > 0:
            # SWA positions in this segment:
            # n_comp_local .. n_comp_local + n_valid - 1
            swa_len_local = n_comp_local + n_valid
        else:
            swa_len_local = swa_len

        # For segments that are entirely compressed (no SWA), don't apply sink bias
        has_swa = k_end > n_comp

        o_norm, lse = run_segment(
            q, k_seg, v_seg, kernel, compiled, stream,
            sink_bias=attn_sink if has_swa else None,
            n_comp_in_seg=n_comp_local,
            swa_len=swa_len_local
        )
        segment_results.append((o_norm, lse))

    # Merge segments
    o_merged = python_kv_merge(segment_results)

    # Compare
    cos = torch.nn.functional.cosine_similarity(
        o_merged.flatten().unsqueeze(0),
        ref.flatten().unsqueeze(0).float()
    ).item()
    max_abs = (o_merged - ref.float()).abs().max().item()

    passed = cos >= cos_threshold
    status = "PASS" if passed else "FAIL"
    print(f'D5c multi-tile: cos {cos:.6f} max_abs {max_abs:.4f} {status}')

    if not passed:
        print(f'  kernel[0,:4]={o_merged[0,:4].tolist()}')
        print(f'  ref[0,:4]={ref[0,:4].tolist()}')

    # Without this the test could only ever print FAIL, never actually fail.
    assert passed, (
        f"D5c multi-tile cosine similarity {cos:.6f} is below {cos_threshold}"
    )


if __name__ == '__main__':
    test_d5c_multitile()