"""E2 test: CSA/HCA attention path — gather compressed + SWA, run FMHA.

This test creates a cache with some compressed entries, writes SWA entries,
then runs ``dsv4_attention`` over the concatenated [compressed, SWA] KV to
verify the end-to-end path.

Steps:
  1. Create cache pools and handle
  2. Write compressed entries to the paged pool (simulating flush output)
  3. Write SWA entries
  4. Gather compressed + SWA KV
  5. Run sparse (CSA) / dense (HCA) FMHA
  6. Verify output shape/dtype and numerical sanity

Both tests need a CUDA device; without one they are skipped.
"""
import math
import unittest

import torch

# Device on which every cache tensor and input in these tests is allocated.
_DEVICE = 'cuda'


def _require_cuda():
    """Skip the calling test when no CUDA device is available.

    Raises ``unittest.SkipTest``; pytest reports it as a skip, so a CPU-only
    machine does not turn these GPU tests into errors.
    """
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA device required for the DSV4 cache/attention path")


def _fill_synthetic_entries(paged, num_filled_blocks: int, entries_per_block: int,
                            fp8_dim: int, rope_dim: int) -> None:
    """Fill the first ``num_filled_blocks`` physical blocks of ``paged`` with random data.

    This simulates compressor flush output. Every entry of each filled block gets:
      * ``entries_fp8``: ``fp8_dim`` raw uint8 bytes (abs(randn * 20) clamped to
        [0, 255]). Presumably the pool reinterprets these as FP8 codes — TODO confirm.
      * ``entries_rope``: ``rope_dim`` BF16 values drawn from randn.
      * ``inv_scale``: a scalar in [0.01, 0.11).

    Values are generated on the CPU and copied into the pool tensors entry by
    entry, in the same order as the original inline loops.
    """
    for block in range(num_filled_blocks):
        for entry in range(entries_per_block):
            paged.entries_fp8[block, entry] = (
                (torch.randn(fp8_dim) * 20).abs().clamp(0, 255).to(torch.uint8)
            )
            paged.entries_rope[block, entry] = torch.randn(rope_dim, dtype=torch.bfloat16)
            paged.inv_scale[block, entry] = 0.01 + torch.rand(1).item() * 0.1


def _make_block_table(num_filled_blocks: int, table_width: int, device=_DEVICE):
    """Build ``(block_table, block_lens)`` for a single request.

    Logical blocks ``0..num_filled_blocks-1`` map to the same physical block
    ids. The row is padded with -1 up to ``table_width`` columns.

    Returns:
        block_table: int32 tensor of shape ``(1, table_width)``.
        block_lens:  int32 tensor ``[num_filled_blocks]``.

    Raises:
        ValueError: if ``num_filled_blocks`` is negative or exceeds ``table_width``.
    """
    if not 0 <= num_filled_blocks <= table_width:
        raise ValueError(
            f"num_filled_blocks={num_filled_blocks} must be in [0, {table_width}]"
        )
    row = list(range(num_filled_blocks)) + [-1] * (table_width - num_filled_blocks)
    block_table = torch.tensor([row], dtype=torch.int32, device=device)
    block_lens = torch.tensor([num_filled_blocks], dtype=torch.int32, device=device)
    return block_table, block_lens


def _make_handle(config, schema, paged, state, block_table, block_lens, position: int):
    """Create a ``LayerCacheHandle`` for one request (slot 0, request id 0) at ``position``."""
    from dsv4.cache.handle import LayerCacheHandle

    cache = LayerCacheHandle(
        paged=paged, state=state, schema=schema,
        request_slots=torch.tensor([0], dtype=torch.int32, device=_DEVICE),
        positions=torch.tensor([position], dtype=torch.int32, device=_DEVICE),
        request_ids=torch.tensor([0], dtype=torch.int32, device=_DEVICE),
        block_table=block_table, block_lens=block_lens,
    )
    cache.num_query_heads = config.num_query_heads
    return cache


def _run_fmha_and_check(q, k_comp, v_comp, k_swa, v_swa, *, head_dim: int, swa_len: int):
    """Concat [compressed, SWA] KV, run ``dsv4_attention``, and sanity-check the output.

    The compressed and SWA keys are concatenated along dim 1, so one softmax
    covers both segments (the "D5c" single-softmax layout). ``n_comp`` tells
    the kernel where the compressed prefix ends.

    Asserts that the output has shape ``(q.shape[0], q.shape[1], head_dim)``,
    dtype bfloat16, and only finite values. Returns the output tensor.
    """
    from dsv4.kernels.attention.production import dsv4_attention

    k_full = torch.cat([k_comp, k_swa], dim=1)  # (1, n_comp + swa_len, head_dim)
    v_full = torch.cat([v_comp, v_swa], dim=1)
    output = dsv4_attention(
        q, k_full, v_full,
        scale=1.0 / math.sqrt(head_dim),
        swa_len=swa_len,
        is_causal=True,
        n_comp=k_comp.shape[1],
    )
    n_h, num_tokens = q.shape[0], q.shape[1]
    assert output.shape == (n_h, num_tokens, head_dim), f"shape {output.shape}"
    assert output.dtype == torch.bfloat16, f"dtype {output.dtype}"
    assert torch.isfinite(output.float()).all(), "Output has NaN/Inf"
    return output


def test_csa_gather_and_fmha():
    """Test CSA gather + FMHA with synthetic compressed entries."""
    _require_cuda()
    from dsv4.cache.schema import build_schema
    from dsv4.cache.state_cache import StateCachePool
    from dsv4.cache.paged_cache import PagedKVPool
    from dsv4.model.config import DSV4Config
    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode

    config = DSV4Config.flash()
    spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA, ffn=FFNType.MOE,
                     router_mode=RouterMode.DENSE)
    schema = build_schema(config, spec)

    hd = config.head_dim              # 512
    rd = config.rope_dim              # 64
    fp8_dim = hd - rd                 # 448
    epb = schema.entries_per_block    # 32

    # 1. Cache pools.
    num_blocks = 16
    num_filled = 4
    state = StateCachePool(schema, max_requests=2, device=_DEVICE)
    paged = PagedKVPool(schema, num_blocks=num_blocks, device=_DEVICE)

    # 2. Synthetic compressed entries in physical blocks 0..num_filled-1.
    _fill_synthetic_entries(paged, num_filled, epb, fp8_dim, rd)

    # Block table of shape [1, num_blocks]: logical blocks 0..3 -> physical
    # blocks 0..3, remaining columns padded with -1.
    block_table, block_lens = _make_block_table(num_filled, num_blocks)

    # NOTE(review): position 128 is kept from the original test. Depending on
    # the CSA compression ratio, num_filled * epb entries may represent more
    # tokens than position 128 implies — confirm the intended position.
    cache = _make_handle(config, schema, paged, state, block_table, block_lens,
                         position=128)

    # 3. SWA write for a single token.
    T = 1
    raw_kv = torch.randn(T, hd, dtype=torch.bfloat16, device=_DEVICE)
    cache.write_swa(raw_kv)

    # 4a. Dense gather of every compressed entry (HCA-style).
    k_all, _ = cache.gather_all_compressed_kv()
    expected_entries = num_filled * epb  # 128
    assert k_all.shape == (1, expected_entries, hd), f"shape {k_all.shape}"
    print(f" gather_all_compressed_kv: shape={k_all.shape}")

    # 4b. Sparse gather of top-k entries.
    top_k = 64
    selected_indices = torch.randint(0, expected_entries, (T, top_k),
                                     dtype=torch.int64, device=_DEVICE)
    k_comp, v_comp = cache.gather_compressed_kv(selected_indices)
    assert k_comp.shape == (1, top_k, hd), f"shape {k_comp.shape}"
    print(f" gather_compressed_kv: shape={k_comp.shape}")

    # 4c. SWA window.
    k_swa, v_swa = cache.gather_swa_kv()
    assert k_swa.shape == (1, config.sliding_window, hd), f"shape {k_swa.shape}"
    print(f" gather_swa_kv: shape={k_swa.shape}")

    # 5-6. Sparse FMHA over [compressed, SWA] + output checks.
    n_h = config.num_query_heads  # 64
    q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device=_DEVICE)
    output = _run_fmha_and_check(q, k_comp, v_comp, k_swa, v_swa,
                                 head_dim=hd, swa_len=config.sliding_window)
    print(f" sparse FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")


def test_hca_gather_and_fmha():
    """Test HCA dense gather + FMHA."""
    _require_cuda()
    from dsv4.cache.schema import build_schema
    from dsv4.cache.state_cache import StateCachePool
    from dsv4.cache.paged_cache import PagedKVPool
    from dsv4.model.config import DSV4Config
    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode

    config = DSV4Config.pro()
    # Layer 0 of Pro is HCA
    spec = LayerSpec(layer_idx=0, attn=AttentionType.HCA, ffn=FFNType.MOE,
                     router_mode=RouterMode.HASH)
    schema = build_schema(config, spec)

    hd = config.head_dim              # 512
    rd = config.rope_dim              # 64
    fp8_dim = hd - rd
    epb = schema.entries_per_block    # 1 (HCA: 128/128=1 entry per block)

    # 1-2. Pools, with synthetic compressed entries in the first num_filled blocks.
    num_blocks = 32
    num_filled = 8
    state = StateCachePool(schema, max_requests=2, device=_DEVICE)
    paged = PagedKVPool(schema, num_blocks=num_blocks, device=_DEVICE)
    _fill_synthetic_entries(paged, num_filled, epb, fp8_dim, rd)

    block_table, block_lens = _make_block_table(num_filled, num_blocks)
    cache = _make_handle(config, schema, paged, state, block_table, block_lens,
                         position=1024)

    # 3. SWA write for a single token.
    raw_kv = torch.randn(1, hd, dtype=torch.bfloat16, device=_DEVICE)
    cache.write_swa(raw_kv)

    # 4. Gather all compressed (HCA = dense) + SWA, checked like the CSA test.
    k_comp, v_comp = cache.gather_all_compressed_kv()
    assert k_comp.shape == (1, num_filled * epb, hd), f"shape {k_comp.shape}"
    k_swa, v_swa = cache.gather_swa_kv()
    assert k_swa.shape == (1, config.sliding_window, hd), f"shape {k_swa.shape}"

    # 5-6. Dense FMHA over [compressed, SWA] + output checks.
    n_h = config.num_query_heads  # 128
    q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device=_DEVICE)
    output = _run_fmha_and_check(q, k_comp, v_comp, k_swa, v_swa,
                                 head_dim=hd, swa_len=config.sliding_window)
    print(f" HCA FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")


def test():
    """Script entry point: run the CSA and HCA checks with banner output."""
    print("=" * 60)
    print("E2: CSA/HCA Gather + FMHA Integration")
    print("=" * 60)
    print("\n--- CSA: gather compressed + SWA, sparse FMHA ---")
    test_csa_gather_and_fmha()
    print("\n--- HCA: gather all compressed + SWA, dense FMHA ---")
    test_hca_gather_and_fmha()
    print("\n" + "=" * 60)
    print("E2 CSA/HCA INTEGRATION TEST PASSED")
    print("=" * 60)


if __name__ == '__main__':
    test()