Files
nvfp4-megamoe-kernel/tests/e2e_archive/test_csa_hca_integration.py
biondizzle f3b551956d Cleanup Step 2: Archive Lineage P code, fix broken imports
- Move dead dsv4/ modules to dsv4/_archive/ (52 files)
  - model/{dsv4,mtp,layer,layer_schedule}
  - layers/{embedding,attention,ffn,norm} (kept linear,mhc,router,moe,shared_expert,grouped_linear - live)
  - cache/*, kernels/cache/*, kernels/indexer/{csa_indexer,score_topk,compute_valid_lens}
  - kernels/router/{nvfp4_fused_router,dense_router_decode_kernel,dense_router_prefill}
  - ops/{topk,topk_select,rope,router}, loader/{hf_checkpoint,layout_convert}
  - reference/{attention,compressor,csa_attention,moe_pipeline}
  - kernels/compressor/{compress_tail,csa_hca}
- Restore dsv4/ops/{router,custom_ops}.py (needed by live layers)
- Fix dsv4/kernels/{indexer,compressor,attention}/__init__.py (removed broken imports)
- Remove preload_all() from loader.py (dead, referenced nonexistent .cu file)
- Fix loader.py docstring (fused_amax_quantize_nvfp4 → quantize_nvfp4_from_buffer)
- Move broken tests to tests/e2e_archive/
  - test_fused_router, production_values_test, e2e/{one_layer,model_construction,csa_hca}
- vLLM has 0 imports of dsv4 (Step 0 confirmed)
2026-06-02 19:27:07 +00:00

206 lines
7.7 KiB
Python

"""E2 test: CSA attention path — gather compressed + SWA, run sparse FMHA.
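This test creates a cache with some compressed entries, writes SWA entries,
then runs dsv4_attention on the concatenated [compressed, SWA] KV as an
end-to-end smoke check: only output shape, dtype and finiteness are verified,
not numerical agreement with a reference. An analogous HCA (dense gather)
check is included as well.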
Steps:
1. Create cache pools and handle
2. Write compressed entries to the paged pool (simulating flush output)
3. Write SWA entries
4. Gather compressed + SWA KV
5. Run sparse FMHA
6. Verify output shape/dtype and numerical sanity
"""
import torch
import math
def test_csa_gather_and_fmha():
    """Test CSA gather + FMHA with synthetic compressed entries.

    Builds a single-request cache for a CSA layer of the Flash config, fills
    the first ``num_filled`` paged blocks with random compressed entries,
    writes one SWA token, checks the shapes returned by the dense, top-k and
    SWA gathers, then runs ``dsv4_attention`` on the concatenated
    ``[compressed, SWA]`` KV. Only output shape, dtype and finiteness are
    asserted; values are not compared against a reference.
    """
    # NOTE(review): the archiving commit moved dsv4/cache/* and
    # dsv4/model/layer_schedule into dsv4/_archive/, so these imports are
    # expected to fail until the paths are updated — confirm before reviving.
    from dsv4.cache.schema import build_schema
    from dsv4.cache.state_cache import StateCachePool
    from dsv4.cache.paged_cache import PagedKVPool
    from dsv4.cache.handle import LayerCacheHandle
    from dsv4.model.config import DSV4Config
    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode

    # Fixed seed so a failure (e.g. a non-finite output) can be reproduced.
    torch.manual_seed(0)

    config = DSV4Config.flash()
    spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA, ffn=FFNType.MOE, router_mode=RouterMode.DENSE)
    schema = build_schema(config, spec)
    hd = config.head_dim  # 512
    rd = config.rope_dim  # 64
    fp8_dim = hd - rd  # 448
    epb = schema.entries_per_block  # 32

    # Create pools.
    num_blocks = 16
    num_filled = 4  # physical blocks that receive synthetic compressed entries
    state = StateCachePool(schema, max_requests=2, device='cuda')
    paged = PagedKVPool(schema, num_blocks=num_blocks, device='cuda')

    # Populate the paged pool with synthetic compressed entries (simulating
    # flush output): FP8 half as raw uint8 bytes, BF16 RoPE half, and a
    # per-entry inverse scale.
    for b in range(num_filled):
        for e in range(epb):
            paged.entries_fp8[b, e] = (torch.randn(fp8_dim) * 20).abs().clamp(0, 255).to(torch.uint8)
            paged.entries_rope[b, e] = torch.randn(rd, dtype=torch.bfloat16)
            paged.inv_scale[b, e] = 0.01 + torch.rand(1).item() * 0.1

    # Block table row: physical blocks 0..num_filled-1, padded with -1 out to
    # num_blocks columns (derived so it cannot drift from num_blocks).
    block_table = torch.tensor(
        [list(range(num_filled)) + [-1] * (num_blocks - num_filled)],
        dtype=torch.int32, device='cuda',
    )
    block_lens = torch.tensor([num_filled], dtype=torch.int32, device='cuda')

    # Build a single-request handle.
    slots = torch.tensor([0], dtype=torch.int32, device='cuda')
    positions = torch.tensor([128], dtype=torch.int32, device='cuda')  # past the initial tokens
    request_ids = torch.tensor([0], dtype=torch.int32, device='cuda')
    cache = LayerCacheHandle(
        paged=paged, state=state, schema=schema,
        request_slots=slots, positions=positions, request_ids=request_ids,
        block_table=block_table, block_lens=block_lens,
    )
    cache.num_query_heads = config.num_query_heads

    # Write SWA entries.
    T = 1
    raw_kv = torch.randn(T, hd, dtype=torch.bfloat16, device='cuda')
    cache.write_swa(raw_kv)

    # gather_all_compressed_kv (HCA-style dense gather over every filled block).
    k_all, _ = cache.gather_all_compressed_kv()
    expected_entries = num_filled * epb  # 128 entries
    assert k_all.shape == (1, expected_entries, hd), f"shape {k_all.shape}"
    print(f"  gather_all_compressed_kv: shape={k_all.shape}")

    # gather_compressed_kv with random top-k indices into the compressed entries.
    top_k = 64
    selected_indices = torch.randint(0, expected_entries, (T, top_k), dtype=torch.int64, device='cuda')
    k_comp, v_comp = cache.gather_compressed_kv(selected_indices)
    assert k_comp.shape == (1, top_k, hd), f"shape {k_comp.shape}"
    print(f"  gather_compressed_kv: shape={k_comp.shape}")

    # gather_swa_kv returns a full sliding window.
    k_swa, v_swa = cache.gather_swa_kv()
    assert k_swa.shape == (1, config.sliding_window, hd), f"shape {k_swa.shape}"
    print(f"  gather_swa_kv: shape={k_swa.shape}")

    # Sparse FMHA: concat [compressed, SWA] and run a single softmax over both
    # (D5c insight).
    from dsv4.kernels.attention.production import dsv4_attention
    n_h = config.num_query_heads  # 64
    q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    k_full = torch.cat([k_comp, k_swa], dim=1)  # (1, top_k + swa_len, hd)
    v_full = torch.cat([v_comp, v_swa], dim=1)
    n_comp = k_comp.shape[1]  # 64
    swa_len = config.sliding_window  # 128
    output = dsv4_attention(
        q, k_full, v_full,
        scale=1.0 / math.sqrt(hd),
        swa_len=swa_len,
        is_causal=True,
        n_comp=n_comp,
    )
    assert output.shape == (n_h, T, hd), f"shape {output.shape}"
    assert output.dtype == torch.bfloat16
    # Numerical sanity: no NaN/Inf anywhere in the output.
    assert torch.isfinite(output.float()).all(), "Output has NaN/Inf"
    print(f"  sparse FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")
def test_hca_gather_and_fmha():
    """Test HCA dense gather + FMHA.

    Builds a single-request cache for layer 0 of the Pro config (an HCA
    layer), fills the first ``num_filled`` paged blocks with random compressed
    entries, writes one SWA token, densely gathers all compressed entries plus
    the SWA window, and runs ``dsv4_attention`` on the concatenation. Only
    shapes, dtype and finiteness are asserted, mirroring the CSA test.
    """
    # NOTE(review): the archiving commit moved dsv4/cache/* and
    # dsv4/model/layer_schedule into dsv4/_archive/, so these imports are
    # expected to fail until the paths are updated — confirm before reviving.
    from dsv4.cache.schema import build_schema
    from dsv4.cache.state_cache import StateCachePool
    from dsv4.cache.paged_cache import PagedKVPool
    from dsv4.cache.handle import LayerCacheHandle
    from dsv4.model.config import DSV4Config
    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode

    # Fixed seed so a failure (e.g. a non-finite output) can be reproduced.
    torch.manual_seed(0)

    config = DSV4Config.pro()
    # Layer 0 of Pro is HCA.
    spec = LayerSpec(layer_idx=0, attn=AttentionType.HCA, ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    schema = build_schema(config, spec)
    hd = config.head_dim  # 512
    rd = config.rope_dim  # 64
    fp8_dim = hd - rd
    epb = schema.entries_per_block  # 1 (HCA: 128/128=1 entry per block)

    num_blocks = 32
    num_filled = 8  # physical blocks that receive synthetic compressed entries
    state = StateCachePool(schema, max_requests=2, device='cuda')
    paged = PagedKVPool(schema, num_blocks=num_blocks, device='cuda')

    # Fill blocks with synthetic data: FP8 half as raw uint8 bytes, BF16 RoPE
    # half, and a per-entry inverse scale.
    for b in range(num_filled):
        for e in range(epb):
            paged.entries_fp8[b, e] = (torch.randn(fp8_dim) * 20).abs().clamp(0, 255).to(torch.uint8)
            paged.entries_rope[b, e] = torch.randn(rd, dtype=torch.bfloat16)
            paged.inv_scale[b, e] = 0.01 + torch.rand(1).item() * 0.1

    # Block table row: physical blocks 0..num_filled-1, padded with -1 out to
    # num_blocks columns (derived so it cannot drift from num_blocks).
    block_table = torch.tensor(
        [list(range(num_filled)) + [-1] * (num_blocks - num_filled)],
        dtype=torch.int32, device='cuda',
    )
    block_lens = torch.tensor([num_filled], dtype=torch.int32, device='cuda')

    slots = torch.tensor([0], dtype=torch.int32, device='cuda')
    positions = torch.tensor([1024], dtype=torch.int32, device='cuda')
    request_ids = torch.tensor([0], dtype=torch.int32, device='cuda')
    cache = LayerCacheHandle(
        paged=paged, state=state, schema=schema,
        request_slots=slots, positions=positions, request_ids=request_ids,
        block_table=block_table, block_lens=block_lens,
    )
    cache.num_query_heads = config.num_query_heads

    # Write SWA.
    T = 1
    raw_kv = torch.randn(T, hd, dtype=torch.bfloat16, device='cuda')
    cache.write_swa(raw_kv)

    # Gather all compressed (HCA = dense). Shape expectation mirrors the
    # CSA test's check of gather_all_compressed_kv (block_lens * epb entries).
    k_comp, v_comp = cache.gather_all_compressed_kv()
    assert k_comp.shape == (1, num_filled * epb, hd), f"shape {k_comp.shape}"
    k_swa, v_swa = cache.gather_swa_kv()
    assert k_swa.shape == (1, config.sliding_window, hd), f"shape {k_swa.shape}"

    from dsv4.kernels.attention.production import dsv4_attention
    n_h = config.num_query_heads  # 128
    q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    k_full = torch.cat([k_comp, k_swa], dim=1)
    v_full = torch.cat([v_comp, v_swa], dim=1)
    output = dsv4_attention(
        q, k_full, v_full,
        scale=1.0 / math.sqrt(hd),
        swa_len=config.sliding_window,
        is_causal=True,
        n_comp=k_comp.shape[1],
    )
    assert output.shape == (n_h, T, hd), f"shape {output.shape}"
    assert output.dtype == torch.bfloat16
    # Numerical sanity: no NaN/Inf anywhere in the output.
    assert torch.isfinite(output.float()).all(), "Output has NaN/Inf"
    print(f"  HCA FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")
def test():
    """Run the CSA and HCA integration checks in order, framed by banners.

    Intended as the script entry point; each check raises AssertionError on
    failure, so the final PASSED banner only prints if both succeed.
    """
    rule = "=" * 60
    print(rule)
    print("E2: CSA/HCA Gather + FMHA Integration")
    print(rule)

    stages = (
        ("\n--- CSA: gather compressed + SWA, sparse FMHA ---", test_csa_gather_and_fmha),
        ("\n--- HCA: gather all compressed + SWA, dense FMHA ---", test_hca_gather_and_fmha),
    )
    for heading, run_stage in stages:
        print(heading)
        run_stage()

    print("\n" + rule)
    print("E2 CSA/HCA INTEGRATION TEST PASSED")
    print(rule)


if __name__ == '__main__':
    # Allow running the integration checks directly as a script.
    test()