From c4b40dd06ca97ceae2956c0d30781109c592ef1b Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Sat, 30 May 2026 21:19:28 +0000
Subject: [PATCH] =?UTF-8?q?E2:=20CSA/HCA=20integration=20test=20=E2=80=94?=
 =?UTF-8?q?=20gather=20+=20FMHA=20end-to-end?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Tests:
- CSA: gather_compressed_kv (top-k) + gather_swa_kv + sparse FMHA
- HCA: gather_all_compressed_kv + gather_swa_kv + dense FMHA
- Verifies shapes, dtypes, and numerical sanity (no NaN/Inf)
---
 tests/e2e/test_csa_hca_integration.py | 203 ++++++++++++++++++++++++++
 1 file changed, 203 insertions(+)
 create mode 100644 tests/e2e/test_csa_hca_integration.py

diff --git a/tests/e2e/test_csa_hca_integration.py b/tests/e2e/test_csa_hca_integration.py
new file mode 100644
index 00000000..75410561
--- /dev/null
+++ b/tests/e2e/test_csa_hca_integration.py
@@ -0,0 +1,203 @@
+"""E2 test: CSA/HCA attention paths — gather compressed + SWA KV, run FMHA.
+
+Each test fills a paged pool with synthetic compressed entries, writes one SWA
+entry, gathers K/V and runs ``dsv4_attention`` to check the path end to end.
+
+Steps:
+1. Create cache pools and handle
+2. Write compressed entries to the paged pool (simulating flush output)
+3. Write SWA entries
+4. Gather compressed + SWA KV
+5. Run FMHA (top-k sparse for CSA, dense for HCA)
+6. Verify output shape/dtype and numerical sanity
+
+Requires a CUDA device and the ``dsv4`` package.
+"""
+import math
+
+import torch
+
+_DEVICE = 'cuda'
+_SEED = 0  # fixed so a numerical failure can be reproduced
+
+
+def _build_cache(config, spec, *, num_blocks, num_filled, position):
+    """Build pools and a one-request cache handle over synthetic entries.
+
+    Args:
+        config: DSV4Config the schema is derived from.
+        spec: LayerSpec selecting the attention type under test.
+        num_blocks: Blocks allocated in the paged pool; also the width of
+            the block table.
+        num_filled: Leading physical blocks populated with random entries
+            and listed, in order, in the block table.
+        position: Token position recorded for the single request.
+
+    Returns:
+        Tuple ``(cache, schema)``: the LayerCacheHandle and its schema.
+    """
+    from dsv4.cache.schema import build_schema
+    from dsv4.cache.state_cache import StateCachePool
+    from dsv4.cache.paged_cache import PagedKVPool
+    from dsv4.cache.handle import LayerCacheHandle
+
+    schema = build_schema(config, spec)
+    rope_dim = config.rope_dim  # 64
+    fp8_dim = config.head_dim - rope_dim  # 448
+
+    state = StateCachePool(schema, max_requests=2, device=_DEVICE)
+    paged = PagedKVPool(schema, num_blocks=num_blocks, device=_DEVICE)
+
+    # Synthetic stand-in for flush output. Per entry: FP8 half (random uint8
+    # bytes), BF16 RoPE half, and an inverse scale in [0.01, 0.11).
+    # NOTE(review): the bytes are random rather than encoded FP8 values —
+    # confirm no value produced here maps to a NaN/Inf code in the pool's format.
+    for block in range(num_filled):
+        for entry in range(schema.entries_per_block):
+            raw = (torch.randn(fp8_dim) * 20).abs().clamp(0, 255)
+            paged.entries_fp8[block, entry] = raw.to(torch.uint8)
+            paged.entries_rope[block, entry] = torch.randn(rope_dim, dtype=torch.bfloat16)
+            paged.inv_scale[block, entry] = 0.01 + torch.rand(1).item() * 0.1
+
+    def int32(values):
+        return torch.tensor(values, dtype=torch.int32, device=_DEVICE)
+
+    # Single request: physical blocks 0..num_filled-1, row padded with -1.
+    table_row = list(range(num_filled)) + [-1] * (num_blocks - num_filled)
+    cache = LayerCacheHandle(
+        paged=paged, state=state, schema=schema,
+        request_slots=int32([0]), positions=int32([position]), request_ids=int32([0]),
+        block_table=int32([table_row]), block_lens=int32([num_filled]),
+    )
+    cache.num_query_heads = config.num_query_heads
+    return cache, schema
+
+
+def test_csa_gather_and_fmha():
+    """CSA path: top-k compressed gather + SWA gather, then sparse FMHA."""
+    from dsv4.model.config import DSV4Config
+    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode
+
+    torch.manual_seed(_SEED)
+
+    config = DSV4Config.flash()
+    spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA, ffn=FFNType.MOE, router_mode=RouterMode.DENSE)
+    num_filled = 4
+    cache, schema = _build_cache(config, spec, num_blocks=16, num_filled=num_filled, position=128)
+
+    hd = config.head_dim  # 512
+    epb = schema.entries_per_block  # 32
+
+    # Write SWA entries
+    T = 1
+    raw_kv = torch.randn(T, hd, dtype=torch.bfloat16, device=_DEVICE)
+    cache.write_swa(raw_kv)
+
+    # Test gather_all_compressed_kv (HCA-style dense gather)
+    k_all, _ = cache.gather_all_compressed_kv()
+    expected_entries = num_filled * epb  # 128 entries
+    assert k_all.shape == (1, expected_entries, hd), f"shape {k_all.shape}"
+    print(f" gather_all_compressed_kv: shape={k_all.shape}")
+
+    # Test gather_compressed_kv with top-k indices
+    top_k = 64
+    selected_indices = torch.randint(0, expected_entries, (T, top_k), dtype=torch.int64, device=_DEVICE)
+    k_comp, v_comp = cache.gather_compressed_kv(selected_indices)
+    assert k_comp.shape == (1, top_k, hd), f"shape {k_comp.shape}"
+    print(f" gather_compressed_kv: shape={k_comp.shape}")
+
+    # Test gather_swa_kv
+    k_swa, v_swa = cache.gather_swa_kv()
+    assert k_swa.shape == (1, config.sliding_window, hd), f"shape {k_swa.shape}"
+    print(f" gather_swa_kv: shape={k_swa.shape}")
+
+    # Late import: the gather checks above do not need the kernel module.
+    from dsv4.kernels.attention.production import dsv4_attention
+
+    n_h = config.num_query_heads  # 64
+    q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device=_DEVICE)
+
+    # Concat: [compressed, swa] — single softmax (D5c insight)
+    k_full = torch.cat([k_comp, k_swa], dim=1)  # (1, top_k + swa_len, hd)
+    v_full = torch.cat([v_comp, v_swa], dim=1)
+
+    output = dsv4_attention(
+        q, k_full, v_full,
+        scale=1.0 / math.sqrt(hd),
+        swa_len=config.sliding_window,  # 128
+        is_causal=True,
+        n_comp=k_comp.shape[1],  # 64
+    )
+    assert output.shape == (n_h, T, hd), f"shape {output.shape}"
+    assert output.dtype == torch.bfloat16, f"dtype {output.dtype}"
+    assert torch.isfinite(output.float()).all(), "Output has NaN/Inf"
+    print(f" sparse FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")
+
+
+def test_hca_gather_and_fmha():
+    """HCA path: dense compressed gather + SWA gather, then FMHA."""
+    from dsv4.model.config import DSV4Config
+    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode
+
+    torch.manual_seed(_SEED)
+
+    config = DSV4Config.pro()
+    # Layer 0 of Pro is HCA
+    spec = LayerSpec(layer_idx=0, attn=AttentionType.HCA, ffn=FFNType.MOE, router_mode=RouterMode.HASH)
+    # HCA: 128/128=1 entry per block
+    cache, _ = _build_cache(config, spec, num_blocks=32, num_filled=8, position=1024)
+
+    hd = config.head_dim  # 512
+
+    # Write SWA
+    raw_kv = torch.randn(1, hd, dtype=torch.bfloat16, device=_DEVICE)
+    cache.write_swa(raw_kv)
+
+    # Gather all compressed (HCA = dense)
+    k_comp, v_comp = cache.gather_all_compressed_kv()
+    k_swa, v_swa = cache.gather_swa_kv()
+
+    from dsv4.kernels.attention.production import dsv4_attention
+
+    n_h = config.num_query_heads  # 128
+    q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device=_DEVICE)
+    k_full = torch.cat([k_comp, k_swa], dim=1)
+    v_full = torch.cat([v_comp, v_swa], dim=1)
+
+    output = dsv4_attention(
+        q, k_full, v_full,
+        scale=1.0 / math.sqrt(hd),
+        swa_len=config.sliding_window,
+        is_causal=True,
+        n_comp=k_comp.shape[1],
+    )
+    assert output.shape == (n_h, 1, hd), f"shape {output.shape}"
+    assert output.dtype == torch.bfloat16, f"dtype {output.dtype}"
+    assert torch.isfinite(output.float()).all(), "Output has NaN/Inf"
+    print(f" HCA FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")
+
+
+def test():
+    """Script entry point: run both cases in order with banner output."""
+    print("=" * 60)
+    print("E2: CSA/HCA Gather + FMHA Integration")
+    print("=" * 60)
+
+    print("\n--- CSA: gather compressed + SWA, sparse FMHA ---")
+    test_csa_gather_and_fmha()
+
+    print("\n--- HCA: gather all compressed + SWA, dense FMHA ---")
+    test_hca_gather_and_fmha()
+
+    print("\n" + "=" * 60)
+    print("E2 CSA/HCA INTEGRATION TEST PASSED")
+    print("=" * 60)
+
+
+# "test" matches pytest's default test-function prefix; opt out so a pytest run
+# does not execute both cases a second time through this wrapper.
+test.__test__ = False
+
+
+if __name__ == '__main__':
+    test()