- Move dead dsv4/ modules to dsv4/_archive/ (52 files):
  - model/{dsv4,mtp,layer,layer_schedule}
  - layers/{embedding,attention,ffn,norm} (kept live: linear, mhc, router, moe, shared_expert, grouped_linear)
  - cache/*, kernels/cache/*, kernels/indexer/{csa_indexer,score_topk,compute_valid_lens}
  - kernels/router/{nvfp4_fused_router,dense_router_decode_kernel,dense_router_prefill}
  - ops/{topk,topk_select,rope,router}, loader/{hf_checkpoint,layout_convert}
  - reference/{attention,compressor,csa_attention,moe_pipeline}
  - kernels/compressor/{compress_tail,csa_hca}
- Restore dsv4/ops/{router,custom_ops}.py (needed by live layers)
- Fix dsv4/kernels/{indexer,compressor,attention}/__init__.py (removed broken imports)
- Remove preload_all() from loader.py (dead; it referenced a nonexistent .cu file)
- Fix loader.py docstring (fused_amax_quantize_nvfp4 → quantize_nvfp4_from_buffer)
- Move broken tests to tests/e2e_archive/:
  - test_fused_router, production_values_test, e2e/{one_layer,model_construction,csa_hca}
- vLLM has 0 imports of dsv4 (Step 0 confirmed)
129 lines · 4.9 KiB · Python
"""E2 smoke test: one full DSV4 layer with SWA attention.
|
|
|
|
Tests the integration path:
|
|
X → mHC.pre → RMSNorm → SWA attention → mHC.post → RMSNorm → FFN
|
|
|
|
SWA is the simplest path (no compressor, no indexer), so it's the
|
|
first integration checkpoint. Layer 0 of the Flash variant uses SWA.
|
|
|
|
This test verifies:
|
|
1. LayerCacheHandle construction + gather_swa_kv
|
|
2. AttentionSubBlock._forward_swa
|
|
3. TransformerLayer.forward
|
|
4. Shape/dtype correctness
|
|
|
|
Numerical verification against a pure PyTorch reference is deferred
|
|
to when the compressor + indexer are wired (CSA layer test).
|
|
"""

import math
import unittest

import torch


def test_swa_cache_gather():
    """Test gather_swa_kv: write KV to cache, read it back as dense BF16.

    Builds a LayerCacheHandle for Flash layer 0 (an SWA layer) and writes one
    token of KV with ``write_swa``. It then checks that ``gather_swa_kv``
    returns a BF16 K tensor of shape ``(1, sliding_window, head_dim)`` whose
    position 0 is non-zero.

    Raises:
        unittest.SkipTest: if CUDA is unavailable. Every tensor and the state
            cache pool below are allocated on ``'cuda'``.
    """
    # All allocations target 'cuda'; skip instead of erroring on CPU-only hosts.
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA is required for the DSV4 SWA cache test")

    from dsv4.cache.schema import build_schema
    from dsv4.cache.state_cache import StateCachePool
    from dsv4.cache.handle import LayerCacheHandle
    from dsv4.model.config import DSV4Config
    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode

    config = DSV4Config.flash()
    # Layer 0 of Flash is SWA
    spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA, ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    schema = build_schema(config, spec)

    state = StateCachePool(schema, max_requests=2, device='cuda')

    # Build handle: a single request (id 0) in state slot 0 at position 0.
    slot = 0
    slots = torch.tensor([slot], dtype=torch.int32, device='cuda')
    positions = torch.tensor([0], dtype=torch.int32, device='cuda')
    request_ids = torch.tensor([0], dtype=torch.int32, device='cuda')

    handle = LayerCacheHandle(
        paged=None, state=state, schema=schema,
        request_slots=slots, positions=positions, request_ids=request_ids,
        block_table=None, block_lens=None,
    )
    # num_query_heads is attached after construction, not passed to the constructor.
    handle.num_query_heads = config.num_query_heads

    # Write one token of KV to the cache
    hd = config.head_dim
    raw_kv = torch.randn(1, hd, dtype=torch.bfloat16, device='cuda')
    handle.write_swa(raw_kv)

    # Read it back; only the K half is checked below.
    k_swa, _v_swa = handle.gather_swa_kv()
    assert k_swa.shape == (1, config.sliding_window, hd), f"shape {k_swa.shape}"
    assert k_swa.dtype == torch.bfloat16

    # The one written position should match (approximately — FP8 quantization loss).
    # Position 0 in the ring buffer should hold the dequantized KV; for now we
    # only check that something non-zero was written there.
    assert k_swa[0, 0].abs().sum() > 0, "SWA position 0 is zero — write failed?"

    print(f" gather_swa_kv: shape={k_swa.shape}, pos0_norm={k_swa[0, 0].float().norm():.4f}")


def test_swa_attention_forward():
    """Smoke-test the SWA attention path: build the sub-block, then run FMHA.

    ``AttentionSubBlock.forward`` needs loaded weights for its Nvfp4Linear
    projections, which is deferred to the checkpoint loader integration. This
    test therefore only constructs the sub-block and writes one token of KV
    into the cache. It then drives ``swa_only_fmha`` directly with a random
    query and checks the output shape is ``(T, num_query_heads * head_dim)``.

    Raises:
        unittest.SkipTest: if CUDA is unavailable. Every tensor and the state
            cache pool below are allocated on ``'cuda'``.
    """
    # All allocations target 'cuda'; skip instead of erroring on CPU-only hosts.
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA is required for the DSV4 SWA attention test")

    from dsv4.model.config import DSV4Config
    from dsv4.model.layer_schedule import LayerSpec, AttentionType, FFNType, RouterMode
    from dsv4.layers.attention import AttentionSubBlock
    from dsv4.cache.schema import build_schema
    from dsv4.cache.state_cache import StateCachePool
    from dsv4.cache.handle import LayerCacheHandle
    from dsv4.kernels.attention import swa_only_fmha

    config = DSV4Config.flash()
    spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA, ffn=FFNType.MOE, router_mode=RouterMode.HASH)

    # Construction-only smoke check. The sub-block's forward needs weight
    # loading, so it is not called; the FMHA kernel is tested directly below.
    AttentionSubBlock(config, spec)

    schema = build_schema(config, spec)
    state = StateCachePool(schema, max_requests=2, device='cuda')

    slots = torch.tensor([0], dtype=torch.int32, device='cuda')
    positions = torch.tensor([0], dtype=torch.int32, device='cuda')
    request_ids = torch.tensor([0], dtype=torch.int32, device='cuda')

    cache = LayerCacheHandle(
        paged=None, state=state, schema=schema,
        request_slots=slots, positions=positions, request_ids=request_ids,
        block_table=None, block_lens=None,
    )
    # num_query_heads is attached after construction, not passed to the constructor.
    cache.num_query_heads = config.num_query_heads

    T = 1

    # Write some KV first so the gathered SWA window holds a written token.
    raw_kv = torch.randn(T, config.head_dim, dtype=torch.bfloat16, device='cuda')
    cache.write_swa(raw_kv)

    # Bypass the projections: feed a random flattened query straight to FMHA.
    q_dim = config.num_query_heads * config.head_dim
    q = torch.randn(T, q_dim, dtype=torch.bfloat16, device='cuda')
    attn_out = swa_only_fmha(q, cache, sliding_window=config.sliding_window)
    assert attn_out.shape == (T, q_dim)

    print(f" swa_only_fmha: shape={attn_out.shape}, dtype={attn_out.dtype}")


def test():
    """Run both E2 SWA smoke checks in order, framed by a banner.

    Any failure propagates from the individual check, so the final
    "PASSED" banner is printed only when every check returns normally.
    """
    rule = "=" * 60

    print(rule)
    print("E2: One Layer Smoke Test — SWA Path")
    print(rule)

    for check in (test_swa_cache_gather, test_swa_attention_forward):
        check()

    print("\n" + rule)
    print("E2 SWA SMOKE TEST PASSED")
    print(rule)


if __name__ == "__main__":
    # Script entry point: run the smoke checks outside pytest with banner output.
    test()