Files
biondizzle f3b551956d Cleanup Step 2: Archive Lineage P code, fix broken imports
- Move dead dsv4/ modules to dsv4/_archive/ (52 files)
  - model/{dsv4,mtp,layer,layer_schedule}
  - layers/{embedding,attention,ffn,norm} (kept linear,mhc,router,moe,shared_expert,grouped_linear - live)
  - cache/*, kernels/cache/*, kernels/indexer/{csa_indexer,score_topk,compute_valid_lens}
  - kernels/router/{nvfp4_fused_router,dense_router_decode_kernel,dense_router_prefill}
  - ops/{topk,topk_select,rope,router}, loader/{hf_checkpoint,layout_convert}
  - reference/{attention,compressor,csa_attention,moe_pipeline}
  - kernels/compressor/{compress_tail,csa_hca}
- Restore dsv4/ops/{router,custom_ops}.py (needed by live layers)
- Fix dsv4/kernels/{indexer,compressor,attention}/__init__.py (removed broken imports)
- Remove preload_all() from loader.py (dead, referenced nonexistent .cu file)
- Fix loader.py docstring (fused_amax_quantize_nvfp4 → quantize_nvfp4_from_buffer)
- Move broken tests to tests/e2e_archive/
  - test_fused_router, production_values_test, e2e/{one_layer,model_construction,csa_hca}
- vLLM has 0 imports of dsv4 (Step 0 confirmed)
2026-06-02 19:27:07 +00:00

116 lines
3.8 KiB
Python

"""Python wrappers for cache gather kernels.
Provides the bridge between LayerCacheHandle and the raw CUDA gather ops.
Loaded via torch.utils.cpp_extension (JIT, sm_100a).
Three gather paths:
1. gather_compressed_kv: CSA — top-k selected compressed entries
2. gather_all_compressed_kv: HCA — all compressed entries (no indexer)
3. gather_swa_kv: SWA window entries from state cache
All outputs are dense BF16 tensors ready for FMHA consumption.
"""
import os
import torch
from torch.utils.cpp_extension import load
# Cached handle to the JIT-built extension module. Stays None until
# _get_kernel_module() compiles/loads it on first use.
_kernel_module = None
# Location of the .cu sources: the "cuda" directory that sits one level
# above the directory containing this file.
_kernel_dir = os.path.join(os.path.dirname(__file__), "..", "cuda")
def _get_kernel_module():
    """Return the ``cache_gather`` extension, building it on the first call.

    The module is compiled through ``torch.utils.cpp_extension.load`` from
    ``gather_kv.cu`` and ``gather_swa.cu`` (both looked up in ``_kernel_dir``)
    and then stored in the module-level ``_kernel_module`` so later calls
    reuse the same object instead of invoking ``load`` again.
    """
    global _kernel_module
    if _kernel_module is None:
        # Both gather kernels are compiled into a single extension.
        cuda_sources = [
            os.path.join(_kernel_dir, filename)
            for filename in ("gather_kv.cu", "gather_swa.cu")
        ]
        _kernel_module = load(
            name="cache_gather",
            sources=cuda_sources,
            # Code is generated for the sm_100a target only.
            extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
            verbose=False,
        )
    return _kernel_module
def gather_compressed_kv(
    entries_fp8: torch.Tensor,   # [num_blocks, epb, fp8_dim] uint8
    entries_rope: torch.Tensor,  # [num_blocks, epb, rope_dim] BF16
    inv_scale: torch.Tensor,     # [num_blocks, epb] FP32
    topk_indices: torch.Tensor,  # [T, top_k] int32
    block_table: torch.Tensor,   # [batch, max_logical_blocks] int32
    entries_per_block: int,
    head_dim: int,
    rope_dim: int,
) -> torch.Tensor:
    """Gather the top-k selected compressed KV entries (CSA path).

    A zero-initialised BF16 buffer is allocated on the device of
    ``entries_fp8`` and handed to the ``gather_kv`` op of the extension,
    which fills it in place.

    Args:
        entries_fp8: FP8 payload of the compressed entries, stored as uint8.
        entries_rope: BF16 RoPE part of the compressed entries.
        inv_scale: Per-entry FP32 inverse scale.
        topk_indices: Selected entry indices; its first two dimensions fix
            the leading dimensions of the result.
        block_table: Logical-to-physical block mapping per request.
        entries_per_block: Number of entries in one cache block.
        head_dim: Size of the last output dimension. Only used for the
            allocation; it is not forwarded to the kernel.
        rope_dim: Width of the RoPE part, forwarded to the kernel.

    Returns:
        (T, top_k, head_dim) BF16 tensor.
    """
    tile_shape = (topk_indices.shape[0], topk_indices.shape[1], head_dim)
    output = torch.zeros(
        tile_shape,
        dtype=torch.bfloat16,
        device=entries_fp8.device,
    )

    kernels = _get_kernel_module()
    kernels.gather_kv(
        entries_fp8,
        entries_rope,
        inv_scale,
        topk_indices,
        block_table,
        output,
        entries_per_block,
        rope_dim,
    )
    return output
def gather_all_compressed_kv(
    entries_fp8: torch.Tensor,   # [num_blocks, epb, fp8_dim] uint8
    entries_rope: torch.Tensor,  # [num_blocks, epb, rope_dim] BF16
    inv_scale: torch.Tensor,     # [num_blocks, epb] FP32
    block_table: torch.Tensor,   # [batch, max_logical_blocks] int32
    block_lens: torch.Tensor,    # [batch] int32
    entries_per_block: int,
    head_dim: int,
    rope_dim: int,
) -> torch.Tensor:
    """Gather ALL compressed KV entries for HCA (dense attention).

    The output is sized for the request with the most blocks, so every
    request in the batch shares one padded length. The buffer starts
    zero-filled; any position the kernel does not write stays zero.

    Args:
        entries_fp8: FP8 payload of the compressed entries, stored as uint8.
        entries_rope: BF16 RoPE part of the compressed entries.
        inv_scale: Per-entry FP32 inverse scale.
        block_table: Logical-to-physical block mapping per request; its
            first dimension is the batch size.
        block_lens: Number of blocks per request.
        entries_per_block: Number of entries in one cache block.
        head_dim: Size of the last output dimension. Only used for the
            allocation; it is not forwarded to the kernel.
        rope_dim: Width of the RoPE part, forwarded to the kernel.

    Returns:
        (batch, total_entries, head_dim) BF16 tensor where
        total_entries = block_lens.max() * entries_per_block.
        For an empty batch the result has shape (0, 0, head_dim) and no
        kernel is launched.
    """
    batch = block_table.size(0)
    if batch == 0:
        # block_lens.max() raises on an empty tensor, and with no requests
        # there is nothing to gather, so return an empty tile directly.
        return torch.zeros(
            0, 0, head_dim, dtype=torch.bfloat16, device=entries_fp8.device
        )

    # .item() brings the value to the host because the output shape
    # depends on it.
    max_blocks = block_lens.max().item()
    total_entries = max_blocks * entries_per_block
    output = torch.zeros(
        batch, total_entries, head_dim,
        dtype=torch.bfloat16, device=entries_fp8.device,
    )

    mod = _get_kernel_module()
    mod.gather_all_compressed(
        entries_fp8, entries_rope, inv_scale,
        block_table, block_lens, output,
        entries_per_block, rope_dim,
    )
    return output
def gather_swa_kv(
    swa_fp8: torch.Tensor,        # [max_req, n_win, fp8_dim] uint8
    swa_rope: torch.Tensor,       # [max_req, n_win, rope_dim] BF16
    swa_inv: torch.Tensor,        # [max_req, n_win] FP32
    swa_pos: torch.Tensor,        # [max_req, n_win] int32
    request_slots: torch.Tensor,  # [batch] int32
    head_dim: int,
    rope_dim: int,
) -> torch.Tensor:
    """Gather the sliding-window (SWA) entries of the selected requests.

    A zero-initialised BF16 buffer is allocated on the device of
    ``swa_fp8`` and handed to the ``gather_swa`` op of the extension,
    which fills it in place.

    Args:
        swa_fp8: FP8 payload of the window entries, stored as uint8; its
            second dimension is the window length of the result.
        swa_rope: BF16 RoPE part of the window entries.
        swa_inv: Per-entry FP32 inverse scale.
        swa_pos: Per-entry int32 positions.
        request_slots: Slot index for each request in the batch; its
            length is the batch size of the result.
        head_dim: Size of the last output dimension, also forwarded to
            the kernel.
        rope_dim: Width of the RoPE part, forwarded to the kernel.

    Returns:
        (batch, n_win, head_dim) BF16 tensor.
    """
    window_shape = (request_slots.shape[0], swa_fp8.shape[1], head_dim)
    output = torch.zeros(
        window_shape,
        dtype=torch.bfloat16,
        device=swa_fp8.device,
    )

    _get_kernel_module().gather_swa(
        swa_fp8,
        swa_rope,
        swa_inv,
        swa_pos,
        request_slots,
        output,
        head_dim,
        rope_dim,
    )
    return output