E2: CSA/HCA integration test — gather + FMHA end-to-end
Tests: - CSA: gather_compressed_kv (top-k) + gather_swa_kv + sparse FMHA - HCA: gather_all_compressed_kv + gather_swa_kv + dense FMHA - Verifies shapes, dtypes, and numerical sanity (no NaN/Inf)
This commit is contained in:
205
tests/e2e/test_csa_hca_integration.py
Normal file
205
tests/e2e/test_csa_hca_integration.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""E2 test: CSA attention path — gather compressed + SWA, run sparse FMHA.
|
||||
|
||||
This test creates a cache with some compressed entries, writes SWA entries,
|
||||
then runs the sparse_fmha_with_swa function to verify end-to-end correctness.
|
||||
|
||||
Steps:
|
||||
1. Create cache pools and handle
|
||||
2. Write compressed entries to the paged pool (simulating flush output)
|
||||
3. Write SWA entries
|
||||
4. Gather compressed + SWA KV
|
||||
5. Run sparse FMHA
|
||||
6. Verify output shape/dtype and numerical sanity
|
||||
"""
|
||||
import torch
|
||||
import math
|
||||
|
||||
|
||||
def test_csa_gather_and_fmha():
|
||||
"""Test CSA gather + FMHA with synthetic compressed entries."""
|
||||
from dsv4.cache.schema import build_schema
|
||||
from dsv4.cache.state_cache import StateCachePool
|
||||
from dsv4.cache.paged_cache import PagedKVPool
|
||||
from dsv4.cache.handle import LayerCacheHandle
|
||||
from dsv4.model.config import DSV4Config
|
||||
from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode
|
||||
|
||||
config = DSV4Config.flash()
|
||||
spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA, ffn=FFNType.MOE, router_mode=RouterMode.DENSE)
|
||||
schema = build_schema(config, spec)
|
||||
|
||||
hd = config.head_dim # 512
|
||||
rd = config.rope_dim # 64
|
||||
fp8_dim = hd - rd # 448
|
||||
epb = schema.entries_per_block # 32
|
||||
|
||||
# Create pools
|
||||
num_blocks = 16
|
||||
state = StateCachePool(schema, max_requests=2, device='cuda')
|
||||
paged = PagedKVPool(schema, num_blocks=num_blocks, device='cuda')
|
||||
|
||||
# Populate paged pool with synthetic compressed entries
|
||||
# Fill first 4 blocks with random data
|
||||
for b in range(4):
|
||||
for e in range(epb):
|
||||
# FP8 half: random uint8 + inv_scale
|
||||
paged.entries_fp8[b, e] = (torch.randn(fp8_dim) * 20).abs().clamp(0, 255).to(torch.uint8)
|
||||
# BF16 RoPE half
|
||||
paged.entries_rope[b, e] = torch.randn(rd, dtype=torch.bfloat16)
|
||||
paged.inv_scale[b, e] = 0.01 + torch.rand(1).item() * 0.1
|
||||
|
||||
# Build block table: [1, 4] → physical blocks 0,1,2,3
|
||||
block_table = torch.tensor([[0, 1, 2, 3] + [-1] * 12], dtype=torch.int32, device='cuda')
|
||||
block_lens = torch.tensor([4], dtype=torch.int32, device='cuda')
|
||||
|
||||
# Build handle
|
||||
slots = torch.tensor([0], dtype=torch.int32, device='cuda')
|
||||
positions = torch.tensor([128], dtype=torch.int32, device='cuda') # past the initial tokens
|
||||
request_ids = torch.tensor([0], dtype=torch.int32, device='cuda')
|
||||
|
||||
cache = LayerCacheHandle(
|
||||
paged=paged, state=state, schema=schema,
|
||||
request_slots=slots, positions=positions, request_ids=request_ids,
|
||||
block_table=block_table, block_lens=block_lens,
|
||||
)
|
||||
cache.num_query_heads = config.num_query_heads
|
||||
|
||||
# Write SWA entries
|
||||
T = 1
|
||||
raw_kv = torch.randn(T, hd, dtype=torch.bfloat16, device='cuda')
|
||||
cache.write_swa(raw_kv)
|
||||
|
||||
# Test gather_all_compressed_kv (HCA-style dense gather)
|
||||
k_all, v_all = cache.gather_all_compressed_kv()
|
||||
expected_entries = 4 * epb # 128 entries
|
||||
assert k_all.shape == (1, expected_entries, hd), f"shape {k_all.shape}"
|
||||
print(f" gather_all_compressed_kv: shape={k_all.shape}")
|
||||
|
||||
# Test gather_compressed_kv with top-k indices
|
||||
top_k = 64
|
||||
selected_indices = torch.randint(0, expected_entries, (T, top_k), dtype=torch.int64, device='cuda')
|
||||
k_comp, v_comp = cache.gather_compressed_kv(selected_indices)
|
||||
assert k_comp.shape == (1, top_k, hd), f"shape {k_comp.shape}"
|
||||
print(f" gather_compressed_kv: shape={k_comp.shape}")
|
||||
|
||||
# Test gather_swa_kv
|
||||
k_swa, v_swa = cache.gather_swa_kv()
|
||||
assert k_swa.shape == (1, config.sliding_window, hd), f"shape {k_swa.shape}"
|
||||
print(f" gather_swa_kv: shape={k_swa.shape}")
|
||||
|
||||
# Test sparse FMHA: concat [compressed, SWA] and run
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
|
||||
n_h = config.num_query_heads # 64
|
||||
q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
# Concat: [compressed, swa] — single softmax (D5c insight)
|
||||
k_full = torch.cat([k_comp, k_swa], dim=1) # (1, top_k + swa_len, hd)
|
||||
v_full = torch.cat([v_comp, v_swa], dim=1)
|
||||
|
||||
n_comp = k_comp.shape[1] # 64
|
||||
swa_len = config.sliding_window # 128
|
||||
|
||||
output = dsv4_attention(
|
||||
q, k_full, v_full,
|
||||
scale=1.0 / math.sqrt(hd),
|
||||
swa_len=swa_len,
|
||||
is_causal=True,
|
||||
n_comp=n_comp,
|
||||
)
|
||||
assert output.shape == (n_h, T, hd), f"shape {output.shape}"
|
||||
assert output.dtype == torch.bfloat16
|
||||
# Check not NaN/Inf
|
||||
assert torch.isfinite(output.float()).all(), "Output has NaN/Inf"
|
||||
print(f" sparse FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")
|
||||
|
||||
|
||||
def test_hca_gather_and_fmha():
|
||||
"""Test HCA dense gather + FMHA."""
|
||||
from dsv4.cache.schema import build_schema
|
||||
from dsv4.cache.state_cache import StateCachePool
|
||||
from dsv4.cache.paged_cache import PagedKVPool
|
||||
from dsv4.cache.handle import LayerCacheHandle
|
||||
from dsv4.model.config import DSV4Config
|
||||
from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode
|
||||
|
||||
config = DSV4Config.pro()
|
||||
# Layer 0 of Pro is HCA
|
||||
spec = LayerSpec(layer_idx=0, attn=AttentionType.HCA, ffn=FFNType.MOE, router_mode=RouterMode.HASH)
|
||||
schema = build_schema(config, spec)
|
||||
|
||||
hd = config.head_dim # 512
|
||||
rd = config.rope_dim # 64
|
||||
fp8_dim = hd - rd
|
||||
epb = schema.entries_per_block # 1 (HCA: 128/128=1 entry per block)
|
||||
|
||||
num_blocks = 32
|
||||
state = StateCachePool(schema, max_requests=2, device='cuda')
|
||||
paged = PagedKVPool(schema, num_blocks=num_blocks, device='cuda')
|
||||
|
||||
# Fill blocks with synthetic data
|
||||
for b in range(8):
|
||||
for e in range(epb):
|
||||
paged.entries_fp8[b, e] = (torch.randn(fp8_dim) * 20).abs().clamp(0, 255).to(torch.uint8)
|
||||
paged.entries_rope[b, e] = torch.randn(rd, dtype=torch.bfloat16)
|
||||
paged.inv_scale[b, e] = 0.01 + torch.rand(1).item() * 0.1
|
||||
|
||||
block_table = torch.tensor([[0,1,2,3,4,5,6,7] + [-1]*24], dtype=torch.int32, device='cuda')
|
||||
block_lens = torch.tensor([8], dtype=torch.int32, device='cuda')
|
||||
|
||||
slots = torch.tensor([0], dtype=torch.int32, device='cuda')
|
||||
positions = torch.tensor([1024], dtype=torch.int32, device='cuda')
|
||||
request_ids = torch.tensor([0], dtype=torch.int32, device='cuda')
|
||||
|
||||
cache = LayerCacheHandle(
|
||||
paged=paged, state=state, schema=schema,
|
||||
request_slots=slots, positions=positions, request_ids=request_ids,
|
||||
block_table=block_table, block_lens=block_lens,
|
||||
)
|
||||
cache.num_query_heads = config.num_query_heads
|
||||
|
||||
# Write SWA
|
||||
raw_kv = torch.randn(1, hd, dtype=torch.bfloat16, device='cuda')
|
||||
cache.write_swa(raw_kv)
|
||||
|
||||
# Gather all compressed (HCA = dense)
|
||||
k_comp, v_comp = cache.gather_all_compressed_kv()
|
||||
k_swa, v_swa = cache.gather_swa_kv()
|
||||
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
|
||||
n_h = config.num_query_heads # 128
|
||||
q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda')
|
||||
k_full = torch.cat([k_comp, k_swa], dim=1)
|
||||
v_full = torch.cat([v_comp, v_swa], dim=1)
|
||||
|
||||
output = dsv4_attention(
|
||||
q, k_full, v_full,
|
||||
scale=1.0 / math.sqrt(hd),
|
||||
swa_len=config.sliding_window,
|
||||
is_causal=True,
|
||||
n_comp=k_comp.shape[1],
|
||||
)
|
||||
assert output.shape == (n_h, 1, hd)
|
||||
assert torch.isfinite(output.float()).all()
|
||||
print(f" HCA FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")
|
||||
|
||||
|
||||
def test():
|
||||
print("=" * 60)
|
||||
print("E2: CSA/HCA Gather + FMHA Integration")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n--- CSA: gather compressed + SWA, sparse FMHA ---")
|
||||
test_csa_gather_and_fmha()
|
||||
|
||||
print("\n--- HCA: gather all compressed + SWA, dense FMHA ---")
|
||||
test_hca_gather_and_fmha()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("E2 CSA/HCA INTEGRATION TEST PASSED")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
Reference in New Issue
Block a user