- Move dead dsv4/ modules to dsv4/_archive/ (52 files)
- model/{dsv4,mtp,layer,layer_schedule}
- layers/{embedding,attention,ffn,norm} (kept linear,mhc,router,moe,shared_expert,grouped_linear - live)
- cache/*, kernels/cache/*, kernels/indexer/{csa_indexer,score_topk,compute_valid_lens}
- kernels/router/{nvfp4_fused_router,dense_router_decode_kernel,dense_router_prefill}
- ops/{topk,topk_select,rope,router}, loader/{hf_checkpoint,layout_convert}
- reference/{attention,compressor,csa_attention,moe_pipeline}
- kernels/compressor/{compress_tail,csa_hca}
- Restore dsv4/ops/{router,custom_ops}.py (needed by live layers)
- Fix dsv4/kernels/{indexer,compressor,attention}/__init__.py (removed broken imports)
- Remove preload_all() from loader.py (dead, referenced nonexistent .cu file)
- Fix loader.py docstring (fused_amax_quantize_nvfp4 → quantize_nvfp4_from_buffer)
- Move broken tests to tests/e2e_archive/
- test_fused_router, production_values_test, e2e/{one_layer,model_construction,csa_hca}
- vLLM has 0 imports of dsv4 (Step 0 confirmed)
92 lines
3.3 KiB
Python
92 lines
3.3 KiB
Python
"""Storage for one layer's classical paged KV cache.
|
|
|
|
Layout per block:
|
|
entries_fp8: [num_blocks, entries_per_block, head_dim - rope_dim] FP8 (stored as uint8)
|
|
entries_rope: [num_blocks, entries_per_block, rope_dim] BF16
|
|
inv_scale: [num_blocks, entries_per_block] FP32
|
|
|
|
The FP8/BF16 split mirrors paper §2.3.4 ("BF16 for RoPE dims, FP8 for
|
|
the rest"). The kernel reads both halves and concatenates in registers.
|
|
|
|
For CSA layers, a parallel pool stores indexer keys at the same block
|
|
granularity — same block ID maps to a block in both pools.
|
|
"""
|
|
from __future__ import annotations
|
|
from typing import Optional
|
|
import torch
|
|
|
|
from dsv4.cache.schema import LayerCacheSchema
|
|
|
|
|
|
class PagedKVPool:
    """Per-layer classical paged KV storage.

    Every tensor is indexed by ``[physical_block_id, slot_in_block, ...]``.
    Compressed entries and indexer keys (if applicable) are indexed by the
    SAME physical_block_id so a CSA layer's two pools share the block table.

    Attributes:
        schema: The cache schema the pool was sized from.
        num_blocks: Number of physical blocks allocated.
        device: Device string every tensor was allocated on.
        entries_fp8: ``[num_blocks, entries_per_block, head_dim - rope_dim]``
            uint8 storage, viewed as float8_e4m3fn at read time.
        entries_rope: ``[num_blocks, entries_per_block, rope_dim]`` bfloat16.
        inv_scale: ``[num_blocks, entries_per_block]`` float32, initialised
            to one.
        indexer_keys_fp4: ``[num_blocks, indexer_entries_per_block,
            indexer_head_dim // 2]`` uint8 (two FP4 values per byte), or
            ``None`` when the schema has no indexer entries.
        indexer_scale: ``[num_blocks, indexer_entries_per_block,
            indexer_head_dim // 16]`` float8_e4m3fn, or ``None``.
        indexer_global_scale: ``[num_blocks]`` float32, or ``None``.
    """

    def __init__(
        self,
        schema: LayerCacheSchema,
        num_blocks: int,
        device: str = "cuda",
    ):
        """Allocate all storage tensors for one layer.

        Args:
            schema: Supplies ``entries_per_block``, ``entry_head_dim``,
                ``rope_dim``, ``indexer_entries_per_block`` and (when the
                latter is positive) ``indexer_head_dim``.
            num_blocks: Number of physical blocks to allocate.
            device: Torch device string for every tensor.

        Raises:
            ValueError: If indexer storage is requested and
                ``indexer_head_dim`` is not a multiple of 16.
        """
        self.schema = schema
        self.num_blocks = num_blocks
        self.device = device

        has_indexer = schema.indexer_entries_per_block > 0
        if has_indexer and schema.indexer_head_dim % 16 != 0:
            # The shapes below use floor division by 2 (packed FP4) and by 16
            # (one scale per 16-element group). A non-multiple of 16 would
            # silently drop the tail of every key, so reject it before
            # allocating anything.
            raise ValueError(
                "indexer_head_dim must be a multiple of 16 for NVFP4 "
                f"storage, got {schema.indexer_head_dim}"
            )

        nb = num_blocks
        epb = schema.entries_per_block
        hd = schema.entry_head_dim
        rd = schema.rope_dim
        fp8_dim = hd - rd

        # ---- Compressed entries ----
        # FP8 stored as uint8 (we view as float8_e4m3fn at read time).
        self.entries_fp8 = torch.zeros(
            (nb, epb, fp8_dim), dtype=torch.uint8, device=device,
        )
        # BF16 RoPE'd half — no quantization.
        self.entries_rope = torch.zeros(
            (nb, epb, rd), dtype=torch.bfloat16, device=device,
        )
        # Per-entry inverse scale (for FP8 dequant in attention kernel).
        self.inv_scale = torch.ones(
            (nb, epb), dtype=torch.float32, device=device,
        )

        # ---- Indexer keys (CSA only) ----
        if has_indexer:
            i_epb = schema.indexer_entries_per_block
            i_hd = schema.indexer_head_dim
            # Indexer QK is FP4 per paper §2.3.4 — but we store the keys
            # post-quant. uint8 = 2 FP4 packed per byte.
            self.indexer_keys_fp4 = torch.zeros(
                (nb, i_epb, i_hd // 2), dtype=torch.uint8, device=device,
            )
            # Per-block-vector scale for the FP4 (one E4M3 scalar per
            # 16-element group, per the NVFP4 quantization scheme).
            self.indexer_scale = torch.ones(
                (nb, i_epb, i_hd // 16),
                dtype=torch.float8_e4m3fn, device=device,
            )
            # One float32 value per physical block.
            self.indexer_global_scale = torch.ones(
                (nb,), dtype=torch.float32, device=device,
            )
        else:
            # Keep the attributes defined so callers can test for None.
            self.indexer_keys_fp4 = None
            self.indexer_scale = None
            self.indexer_global_scale = None

    def memory_bytes(self) -> int:
        """Return the total number of bytes held by this pool's tensors.

        Indexer tensors that are ``None`` contribute nothing.
        """
        tensors = (
            self.entries_fp8,
            self.entries_rope,
            self.inv_scale,
            self.indexer_keys_fp4,
            self.indexer_scale,
            self.indexer_global_scale,
        )
        return sum(
            t.numel() * t.element_size() for t in tensors if t is not None
        )
|