Tests:
- CSA: gather_compressed_kv (top-k) + gather_swa_kv + sparse FMHA
- HCA: gather_all_compressed_kv + gather_swa_kv + dense FMHA
- Verifies shapes, dtypes, and numerical sanity (no NaN/Inf)
206 lines
7.7 KiB
Python
206 lines
7.7 KiB
Python
"""E2 test: CSA attention path — gather compressed + SWA, run sparse FMHA.
|
|
|
|
This test creates a cache with some compressed entries, writes SWA entries,
|
|
then runs the sparse_fmha_with_swa function to verify end-to-end correctness.
|
|
|
|
Steps:
|
|
1. Create cache pools and handle
|
|
2. Write compressed entries to the paged pool (simulating flush output)
|
|
3. Write SWA entries
|
|
4. Gather compressed + SWA KV
|
|
5. Run sparse FMHA
|
|
6. Verify output shape/dtype and numerical sanity
|
|
"""
|
|
import torch
|
|
import math
|
|
|
|
|
|
def test_csa_gather_and_fmha():
    """Test CSA gather + FMHA with synthetic compressed entries.

    Builds a Flash-config CSA layer cache and fills the first ``n_filled`` paged
    blocks with synthetic compressed entries. It then writes one SWA token and
    checks the following:

    * ``gather_all_compressed_kv`` returns every filled entry.
    * ``gather_compressed_kv`` returns one row per selected top-k index.
    * ``gather_swa_kv`` returns a full sliding window.
    * ``dsv4_attention`` over ``[compressed, SWA]`` yields a finite bfloat16
      tensor of shape ``(num_query_heads, T, head_dim)``.
    """
    from dsv4.cache.schema import build_schema
    from dsv4.cache.state_cache import StateCachePool
    from dsv4.cache.paged_cache import PagedKVPool
    from dsv4.cache.handle import LayerCacheHandle
    from dsv4.kernels.attention.production import dsv4_attention
    from dsv4.model.config import DSV4Config
    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode

    # Deterministic synthetic data so a failure can be reproduced.
    torch.manual_seed(0)
    device = 'cuda'

    config = DSV4Config.flash()
    spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA, ffn=FFNType.MOE, router_mode=RouterMode.DENSE)
    schema = build_schema(config, spec)

    hd = config.head_dim  # 512
    rd = config.rope_dim  # 64
    fp8_dim = hd - rd  # 448
    epb = schema.entries_per_block  # 32

    # Create pools
    num_blocks = 16
    n_filled = 4  # physical blocks 0..n_filled-1 receive synthetic entries
    state = StateCachePool(schema, max_requests=2, device=device)
    paged = PagedKVPool(schema, num_blocks=num_blocks, device=device)

    # Populate the paged pool with synthetic compressed entries.
    for b in range(n_filled):
        for e in range(epb):
            # FP8 half: random bytes, generated directly on the device.
            # Capped at 0x77 so no byte lands on an FP8 Inf/NaN encoding
            # (0x7C / 0x7F). This assumes e4m3fn or e5m2 storage — confirm.
            paged.entries_fp8[b, e] = (
                (torch.randn(fp8_dim, device=device) * 20).abs().clamp(0, 0x77).to(torch.uint8)
            )
            # BF16 RoPE half
            paged.entries_rope[b, e] = torch.randn(rd, dtype=torch.bfloat16, device=device)
            paged.inv_scale[b, e] = 0.01 + torch.rand(1).item() * 0.1

    # Block table for one request: physical blocks 0..n_filled-1, rest unused (-1).
    block_table = torch.tensor(
        [list(range(n_filled)) + [-1] * (num_blocks - n_filled)],
        dtype=torch.int32, device=device,
    )
    block_lens = torch.tensor([n_filled], dtype=torch.int32, device=device)

    # Build handle
    slots = torch.tensor([0], dtype=torch.int32, device=device)
    positions = torch.tensor([128], dtype=torch.int32, device=device)  # past the initial tokens
    request_ids = torch.tensor([0], dtype=torch.int32, device=device)

    cache = LayerCacheHandle(
        paged=paged, state=state, schema=schema,
        request_slots=slots, positions=positions, request_ids=request_ids,
        block_table=block_table, block_lens=block_lens,
    )
    cache.num_query_heads = config.num_query_heads

    # Write SWA entries
    T = 1
    raw_kv = torch.randn(T, hd, dtype=torch.bfloat16, device=device)
    cache.write_swa(raw_kv)

    # Test gather_all_compressed_kv (HCA-style dense gather)
    k_all, v_all = cache.gather_all_compressed_kv()
    expected_entries = n_filled * epb  # 128 entries
    assert k_all.shape == (1, expected_entries, hd), f"shape {k_all.shape}"
    print(f"  gather_all_compressed_kv: shape={k_all.shape}")

    # Test gather_compressed_kv with top-k indices. A real top-k selection never
    # repeats an entry, so draw distinct indices per query token (randperm, not randint).
    top_k = 64
    selected_indices = torch.stack(
        [torch.randperm(expected_entries, device=device)[:top_k] for _ in range(T)]
    )  # (T, top_k), int64
    k_comp, v_comp = cache.gather_compressed_kv(selected_indices)
    assert k_comp.shape == (1, top_k, hd), f"shape {k_comp.shape}"
    print(f"  gather_compressed_kv: shape={k_comp.shape}")

    # Test gather_swa_kv
    k_swa, v_swa = cache.gather_swa_kv()
    assert k_swa.shape == (1, config.sliding_window, hd), f"shape {k_swa.shape}"
    print(f"  gather_swa_kv: shape={k_swa.shape}")

    # Test sparse FMHA: concat [compressed, SWA] and run
    n_h = config.num_query_heads  # 64
    q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device=device)

    # Concat: [compressed, swa] — single softmax (D5c insight)
    k_full = torch.cat([k_comp, k_swa], dim=1)  # (1, top_k + swa_len, hd)
    v_full = torch.cat([v_comp, v_swa], dim=1)

    n_comp = k_comp.shape[1]  # 64
    swa_len = config.sliding_window  # 128

    output = dsv4_attention(
        q, k_full, v_full,
        scale=1.0 / math.sqrt(hd),
        swa_len=swa_len,
        is_causal=True,
        n_comp=n_comp,
    )
    assert output.shape == (n_h, T, hd), f"shape {output.shape}"
    assert output.dtype == torch.bfloat16, f"dtype {output.dtype}"
    # Check not NaN/Inf
    assert torch.isfinite(output.float()).all(), "Output has NaN/Inf"
    print(f"  sparse FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")
|
|
|
|
|
|
def test_hca_gather_and_fmha():
    """Test HCA dense gather + FMHA.

    Builds a Pro-config HCA layer cache (layer 0) and fills the first
    ``n_filled`` paged blocks with synthetic compressed entries. It then writes
    one SWA token and gathers *all* compressed entries plus the SWA window.
    Finally it checks that dense ``dsv4_attention`` over ``[compressed, SWA]``
    yields a finite bfloat16 tensor of shape ``(num_query_heads, 1, head_dim)``.
    """
    from dsv4.cache.schema import build_schema
    from dsv4.cache.state_cache import StateCachePool
    from dsv4.cache.paged_cache import PagedKVPool
    from dsv4.cache.handle import LayerCacheHandle
    from dsv4.kernels.attention.production import dsv4_attention
    from dsv4.model.config import DSV4Config
    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode

    # Deterministic synthetic data so a failure can be reproduced.
    torch.manual_seed(0)
    device = 'cuda'

    config = DSV4Config.pro()
    # Layer 0 of Pro is HCA
    spec = LayerSpec(layer_idx=0, attn=AttentionType.HCA, ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    schema = build_schema(config, spec)

    hd = config.head_dim  # 512
    rd = config.rope_dim  # 64
    fp8_dim = hd - rd
    epb = schema.entries_per_block  # 1 (HCA: 128/128=1 entry per block)

    num_blocks = 32
    n_filled = 8  # physical blocks 0..n_filled-1 receive synthetic entries
    state = StateCachePool(schema, max_requests=2, device=device)
    paged = PagedKVPool(schema, num_blocks=num_blocks, device=device)

    # Fill blocks with synthetic data.
    for b in range(n_filled):
        for e in range(epb):
            # Capped at 0x77 so no byte lands on an FP8 Inf/NaN encoding
            # (0x7C / 0x7F). This assumes e4m3fn or e5m2 storage — confirm.
            paged.entries_fp8[b, e] = (
                (torch.randn(fp8_dim, device=device) * 20).abs().clamp(0, 0x77).to(torch.uint8)
            )
            paged.entries_rope[b, e] = torch.randn(rd, dtype=torch.bfloat16, device=device)
            paged.inv_scale[b, e] = 0.01 + torch.rand(1).item() * 0.1

    # Block table for one request: physical blocks 0..n_filled-1, rest unused (-1).
    block_table = torch.tensor(
        [list(range(n_filled)) + [-1] * (num_blocks - n_filled)],
        dtype=torch.int32, device=device,
    )
    block_lens = torch.tensor([n_filled], dtype=torch.int32, device=device)

    slots = torch.tensor([0], dtype=torch.int32, device=device)
    positions = torch.tensor([1024], dtype=torch.int32, device=device)
    request_ids = torch.tensor([0], dtype=torch.int32, device=device)

    cache = LayerCacheHandle(
        paged=paged, state=state, schema=schema,
        request_slots=slots, positions=positions, request_ids=request_ids,
        block_table=block_table, block_lens=block_lens,
    )
    cache.num_query_heads = config.num_query_heads

    # Write SWA
    raw_kv = torch.randn(1, hd, dtype=torch.bfloat16, device=device)
    cache.write_swa(raw_kv)

    # Gather all compressed (HCA = dense). Check the shapes the same way the CSA test does.
    k_comp, v_comp = cache.gather_all_compressed_kv()
    assert k_comp.shape == (1, n_filled * epb, hd), f"shape {k_comp.shape}"
    k_swa, v_swa = cache.gather_swa_kv()
    assert k_swa.shape == (1, config.sliding_window, hd), f"shape {k_swa.shape}"

    n_h = config.num_query_heads  # 128
    q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device=device)
    k_full = torch.cat([k_comp, k_swa], dim=1)
    v_full = torch.cat([v_comp, v_swa], dim=1)

    output = dsv4_attention(
        q, k_full, v_full,
        scale=1.0 / math.sqrt(hd),
        swa_len=config.sliding_window,
        is_causal=True,
        n_comp=k_comp.shape[1],
    )
    assert output.shape == (n_h, 1, hd), f"shape {output.shape}"
    assert output.dtype == torch.bfloat16, f"dtype {output.dtype}"
    assert torch.isfinite(output.float()).all(), "Output has NaN/Inf"
    print(f"  HCA FMHA: shape={output.shape}, max_abs={output.float().abs().max().item():.4f}")
|
|
|
|
|
|
def test():
    """Script entry point: run the CSA and HCA integration checks with banners."""
    rule = "=" * 60

    print(rule)
    print("E2: CSA/HCA Gather + FMHA Integration")
    print(rule)

    # (section header, check) pairs, run in order; any failing assert propagates.
    sections = (
        ("\n--- CSA: gather compressed + SWA, sparse FMHA ---", test_csa_gather_and_fmha),
        ("\n--- HCA: gather all compressed + SWA, dense FMHA ---", test_hca_gather_and_fmha),
    )
    for header, run_check in sections:
        print(header)
        run_check()

    print("\n" + rule)
    print("E2 CSA/HCA INTEGRATION TEST PASSED")
    print(rule)
|
|
|
|
|
|
# Run the full E2 suite only when executed as a script, not on import.
if __name__ == "__main__":
    test()
|