Files
nvfp4-megamoe-kernel/tests/e2e/test_csa_hca_integration.py
biondizzle c4b40dd06c E2: CSA/HCA integration test — gather + FMHA end-to-end
Tests:
- CSA: gather_compressed_kv (top-k) + gather_swa_kv + sparse FMHA
- HCA: gather_all_compressed_kv + gather_swa_kv + dense FMHA
- Verifies shapes, dtypes, and numerical sanity (no NaN/Inf)
2026-05-30 21:19:28 +00:00

206 lines
7.7 KiB
Python

"""E2 tests: CSA and HCA attention paths — gather compressed + SWA KV, run FMHA.
Each test creates a cache with some compressed entries, writes SWA entries,
then gathers both through ``LayerCacheHandle`` and runs ``dsv4_attention`` over
the concatenated ``[compressed, SWA]`` keys/values to check end-to-end sanity.
Steps:
1. Create cache pools and handle
2. Write compressed entries to the paged pool (simulating flush output)
3. Write SWA entries
4. Gather compressed + SWA KV (CSA: top-k via ``gather_compressed_kv``;
   HCA: every entry via ``gather_all_compressed_kv``)
5. Run FMHA (sparse for CSA, dense for HCA)
6. Verify output shape/dtype and numerical sanity (no NaN/Inf)
"""
import torch
import math
def test_csa_gather_and_fmha():
    """Test the CSA path: top-k compressed gather + SWA gather + sparse FMHA.

    Fills the first ``num_filled`` physical blocks of a ``PagedKVPool`` with
    synthetic compressed entries and writes one SWA token through a
    ``LayerCacheHandle``. It then checks that:

    * ``gather_all_compressed_kv`` returns ``num_filled * entries_per_block``
      entries of width ``head_dim`` (K and V alike);
    * ``gather_compressed_kv`` returns one finite row per selected index;
    * ``gather_swa_kv`` returns ``sliding_window`` rows;
    * ``dsv4_attention`` over ``[compressed, SWA]`` yields a finite bf16
      tensor of shape ``(num_query_heads, T, head_dim)``.

    Every tensor is allocated on ``'cuda'``, so the test raises
    ``unittest.SkipTest`` (reported as a skip by pytest) when CUDA is absent.
    """
    import unittest

    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA is required for the CSA integration test")

    from dsv4.cache.schema import build_schema
    from dsv4.cache.state_cache import StateCachePool
    from dsv4.cache.paged_cache import PagedKVPool
    from dsv4.cache.handle import LayerCacheHandle
    from dsv4.kernels.attention.production import dsv4_attention
    from dsv4.model.config import DSV4Config
    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode

    torch.manual_seed(0)  # reproducible synthetic entries and top-k indices

    config = DSV4Config.flash()
    spec = LayerSpec(
        layer_idx=2, attn=AttentionType.CSA, ffn=FFNType.MOE, router_mode=RouterMode.DENSE,
    )
    schema = build_schema(config, spec)
    hd = config.head_dim  # 512 per the original comment
    rd = config.rope_dim  # 64 per the original comment
    fp8_dim = hd - rd  # non-RoPE part of each entry, stored as 8-bit codes
    epb = schema.entries_per_block  # 32 per the original comment

    # Create pools; only the first `num_filled` physical blocks receive data.
    num_blocks = 16
    num_filled = 4
    state = StateCachePool(schema, max_requests=2, device='cuda')
    paged = PagedKVPool(schema, num_blocks=num_blocks, device='cuda')

    # Populate the filled blocks with synthetic compressed entries using one
    # bulk write per tensor instead of one host->device copy per entry.
    # These trailing-dim checks keep the guarantee the old per-entry
    # assignment of length-fp8_dim / length-rd vectors gave implicitly.
    assert paged.entries_fp8.shape[-1] == fp8_dim, f"fp8 shape {paged.entries_fp8.shape}"
    assert paged.entries_rope.shape[-1] == rd, f"rope shape {paged.entries_rope.shape}"
    fp8_dst = paged.entries_fp8[:num_filled]
    rope_dst = paged.entries_rope[:num_filled]
    scale_dst = paged.inv_scale[:num_filled]
    # FP8 half: random non-negative byte codes.
    fp8_dst.copy_(
        (torch.randn(fp8_dst.shape, device=fp8_dst.device) * 20).abs().clamp(0, 255).to(torch.uint8)
    )
    # BF16 RoPE half.
    rope_dst.copy_(torch.randn(rope_dst.shape, dtype=torch.bfloat16, device=rope_dst.device))
    # Per-entry inverse scales in [0.01, 0.11).
    scale_dst.copy_(0.01 + torch.rand(scale_dst.shape, device=scale_dst.device) * 0.1)

    # Block table: logical blocks 0..num_filled-1 -> physical blocks
    # 0..num_filled-1; remaining slots padded with -1 (presumably "unmapped").
    block_table = torch.tensor(
        [list(range(num_filled)) + [-1] * (num_blocks - num_filled)],
        dtype=torch.int32, device='cuda',
    )
    block_lens = torch.tensor([num_filled], dtype=torch.int32, device='cuda')

    # Build a single-request handle.
    slots = torch.tensor([0], dtype=torch.int32, device='cuda')
    # NOTE(review): position 128 ("past the initial tokens") is kept from the
    # original test; confirm whether it must agree with the number of
    # compressed entries populated above.
    positions = torch.tensor([128], dtype=torch.int32, device='cuda')
    request_ids = torch.tensor([0], dtype=torch.int32, device='cuda')
    cache = LayerCacheHandle(
        paged=paged, state=state, schema=schema,
        request_slots=slots, positions=positions, request_ids=request_ids,
        block_table=block_table, block_lens=block_lens,
    )
    cache.num_query_heads = config.num_query_heads

    # Write SWA entries for T query tokens.
    T = 1
    raw_kv = torch.randn(T, hd, dtype=torch.bfloat16, device='cuda')
    cache.write_swa(raw_kv)

    # HCA-style dense gather over every populated compressed entry.
    k_all, v_all = cache.gather_all_compressed_kv()
    expected_entries = num_filled * epb
    assert k_all.shape == (1, expected_entries, hd), f"shape {k_all.shape}"
    assert v_all.shape == k_all.shape, f"v shape {v_all.shape}"
    print(f" gather_all_compressed_kv: shape={k_all.shape}")

    # Top-k gather. A real top-k selection never repeats an index, so draw
    # distinct indices per query row (randint could yield duplicates).
    top_k = 64
    selected_indices = torch.stack([
        torch.randperm(expected_entries, device='cuda')[:top_k] for _ in range(T)
    ])
    k_comp, v_comp = cache.gather_compressed_kv(selected_indices)
    assert k_comp.shape == (1, top_k, hd), f"shape {k_comp.shape}"
    assert v_comp.shape == k_comp.shape, f"v shape {v_comp.shape}"
    assert torch.isfinite(k_comp.float()).all(), "gathered K has NaN/Inf"
    assert torch.isfinite(v_comp.float()).all(), "gathered V has NaN/Inf"
    print(f" gather_compressed_kv: shape={k_comp.shape}")

    # SWA gather returns the full sliding window.
    k_swa, v_swa = cache.gather_swa_kv()
    assert k_swa.shape == (1, config.sliding_window, hd), f"shape {k_swa.shape}"
    assert v_swa.shape == k_swa.shape, f"v shape {v_swa.shape}"
    print(f" gather_swa_kv: shape={k_swa.shape}")

    # Sparse FMHA: concat [compressed, SWA] and run a single softmax (D5c insight).
    n_h = config.num_query_heads
    q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    k_full = torch.cat([k_comp, k_swa], dim=1)  # (1, top_k + swa_len, hd)
    v_full = torch.cat([v_comp, v_swa], dim=1)
    output = dsv4_attention(
        q, k_full, v_full,
        scale=1.0 / math.sqrt(hd),
        swa_len=config.sliding_window,
        is_causal=True,
        n_comp=k_comp.shape[1],
    )
    assert output.shape == (n_h, T, hd), f"shape {output.shape}"
    assert output.dtype == torch.bfloat16
    assert torch.isfinite(output.float()).all(), "Output has NaN/Inf"
    print(f" sparse FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")
def test_hca_gather_and_fmha():
    """Test the HCA path: dense compressed gather + SWA gather + dense FMHA.

    Fills the first ``num_filled`` physical blocks of a ``PagedKVPool`` for an
    HCA layer of the Pro config with synthetic compressed entries and writes
    one SWA token. It then gathers *all* compressed entries plus the SWA window
    and checks that ``dsv4_attention`` over ``[compressed, SWA]`` yields a
    finite bf16 tensor of shape ``(num_query_heads, T, head_dim)``. Gather
    shapes are checked with the same expectations the CSA test uses.

    Every tensor is allocated on ``'cuda'``, so the test raises
    ``unittest.SkipTest`` (reported as a skip by pytest) when CUDA is absent.
    """
    import unittest

    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA is required for the HCA integration test")

    from dsv4.cache.schema import build_schema
    from dsv4.cache.state_cache import StateCachePool
    from dsv4.cache.paged_cache import PagedKVPool
    from dsv4.cache.handle import LayerCacheHandle
    from dsv4.kernels.attention.production import dsv4_attention
    from dsv4.model.config import DSV4Config
    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode

    torch.manual_seed(0)  # reproducible synthetic entries

    config = DSV4Config.pro()
    # Layer 0 of Pro is HCA (per the original test setup).
    spec = LayerSpec(
        layer_idx=0, attn=AttentionType.HCA, ffn=FFNType.MOE, router_mode=RouterMode.HASH,
    )
    schema = build_schema(config, spec)
    hd = config.head_dim  # 512 per the original comment
    rd = config.rope_dim  # 64 per the original comment
    fp8_dim = hd - rd  # non-RoPE part of each entry, stored as 8-bit codes
    epb = schema.entries_per_block  # 1 per the original comment (HCA: 128/128)

    # Create pools; only the first `num_filled` physical blocks receive data.
    num_blocks = 32
    num_filled = 8
    state = StateCachePool(schema, max_requests=2, device='cuda')
    paged = PagedKVPool(schema, num_blocks=num_blocks, device='cuda')

    # Fill blocks with synthetic data using one bulk write per tensor. These
    # trailing-dim checks keep the guarantee the old per-entry assignment of
    # length-fp8_dim / length-rd vectors gave implicitly.
    assert paged.entries_fp8.shape[-1] == fp8_dim, f"fp8 shape {paged.entries_fp8.shape}"
    assert paged.entries_rope.shape[-1] == rd, f"rope shape {paged.entries_rope.shape}"
    fp8_dst = paged.entries_fp8[:num_filled]
    rope_dst = paged.entries_rope[:num_filled]
    scale_dst = paged.inv_scale[:num_filled]
    fp8_dst.copy_(
        (torch.randn(fp8_dst.shape, device=fp8_dst.device) * 20).abs().clamp(0, 255).to(torch.uint8)
    )
    rope_dst.copy_(torch.randn(rope_dst.shape, dtype=torch.bfloat16, device=rope_dst.device))
    scale_dst.copy_(0.01 + torch.rand(scale_dst.shape, device=scale_dst.device) * 0.1)

    # Logical blocks 0..num_filled-1 -> same physical blocks; rest padded with -1.
    block_table = torch.tensor(
        [list(range(num_filled)) + [-1] * (num_blocks - num_filled)],
        dtype=torch.int32, device='cuda',
    )
    block_lens = torch.tensor([num_filled], dtype=torch.int32, device='cuda')
    slots = torch.tensor([0], dtype=torch.int32, device='cuda')
    positions = torch.tensor([1024], dtype=torch.int32, device='cuda')
    request_ids = torch.tensor([0], dtype=torch.int32, device='cuda')
    cache = LayerCacheHandle(
        paged=paged, state=state, schema=schema,
        request_slots=slots, positions=positions, request_ids=request_ids,
        block_table=block_table, block_lens=block_lens,
    )
    cache.num_query_heads = config.num_query_heads

    # Write SWA for T query tokens.
    T = 1
    raw_kv = torch.randn(T, hd, dtype=torch.bfloat16, device='cuda')
    cache.write_swa(raw_kv)

    # Gather all compressed entries (HCA = dense) plus the SWA window.
    k_comp, v_comp = cache.gather_all_compressed_kv()
    k_swa, v_swa = cache.gather_swa_kv()
    assert k_comp.shape == (1, num_filled * epb, hd), f"shape {k_comp.shape}"
    assert v_comp.shape == k_comp.shape, f"v shape {v_comp.shape}"
    assert torch.isfinite(k_comp.float()).all(), "gathered K has NaN/Inf"
    assert torch.isfinite(v_comp.float()).all(), "gathered V has NaN/Inf"
    assert k_swa.shape == (1, config.sliding_window, hd), f"shape {k_swa.shape}"
    assert v_swa.shape == k_swa.shape, f"v shape {v_swa.shape}"

    # Dense FMHA over [compressed, SWA].
    n_h = config.num_query_heads
    q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    k_full = torch.cat([k_comp, k_swa], dim=1)
    v_full = torch.cat([v_comp, v_swa], dim=1)
    output = dsv4_attention(
        q, k_full, v_full,
        scale=1.0 / math.sqrt(hd),
        swa_len=config.sliding_window,
        is_causal=True,
        n_comp=k_comp.shape[1],
    )
    assert output.shape == (n_h, T, hd), f"shape {output.shape}"
    assert output.dtype == torch.bfloat16, f"dtype {output.dtype}"
    assert torch.isfinite(output.float()).all(), "Output has NaN/Inf"
    print(f" HCA FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")
def test():
    """Run the CSA and HCA integration checks in sequence, with banners.

    Script entry point used by the ``__main__`` guard below. Under pytest the
    two ``test_*`` functions are collected individually, so this aggregator
    is excluded from collection (see ``__test__`` below).
    """
    print("=" * 60)
    print("E2: CSA/HCA Gather + FMHA Integration")
    print("=" * 60)
    print("\n--- CSA: gather compressed + SWA, sparse FMHA ---")
    test_csa_gather_and_fmha()
    print("\n--- HCA: gather all compressed + SWA, dense FMHA ---")
    test_hca_gather_and_fmha()
    print("\n" + "=" * 60)
    print("E2 CSA/HCA INTEGRATION TEST PASSED")
    print("=" * 60)


# pytest collects every module-level function whose name starts with "test",
# so without this flag it would run `test` as a third test that re-executes
# both GPU checks above. `test` is meant only as the script entry point.
test.__test__ = False


if __name__ == '__main__':
    test()