Files
biondizzle f3b551956d Cleanup Step 2: Archive Lineage P code, fix broken imports
- Move dead dsv4/ modules to dsv4/_archive/ (52 files)
  - model/{dsv4,mtp,layer,layer_schedule}
  - layers/{embedding,attention,ffn,norm} (kept linear,mhc,router,moe,shared_expert,grouped_linear - live)
  - cache/*, kernels/cache/*, kernels/indexer/{csa_indexer,score_topk,compute_valid_lens}
  - kernels/router/{nvfp4_fused_router,dense_router_decode_kernel,dense_router_prefill}
  - ops/{topk,topk_select,rope,router}, loader/{hf_checkpoint,layout_convert}
  - reference/{attention,compressor,csa_attention,moe_pipeline}
  - kernels/compressor/{compress_tail,csa_hca}
- Restore dsv4/ops/{router,custom_ops}.py (needed by live layers)
- Fix dsv4/kernels/{indexer,compressor,attention}/__init__.py (removed broken imports)
- Remove preload_all() from loader.py (dead, referenced nonexistent .cu file)
- Fix loader.py docstring (fused_amax_quantize_nvfp4 → quantize_nvfp4_from_buffer)
- Move broken tests to tests/e2e_archive/
  - test_fused_router, production_values_test, e2e/{one_layer,model_construction,csa_hca}
- vLLM has 0 imports of dsv4 (Step 0 confirmed)
2026-06-02 19:27:07 +00:00

163 lines
5.6 KiB
Python

"""In-graph flush orchestration.
A flush emits compressed entries for requests whose tail_len has reached
the compression threshold. The compression math itself lives in the
compressor kernels (dsv4.kernels.compressor); this module handles the
quantize-scatter-write step and the state rotation.
The maybe_flush_* functions always run when their attention type
matches — no host-side `if tail_full` check. The kernels gate
internally via `valid_mask` computed from `tail_len`. This keeps
the call sequence identical across forward passes for cudagraph.
"""
from __future__ import annotations
from typing import Optional
import os
import torch
from torch.utils.cpp_extension import load
from dsv4.cache.schema import LayerCacheSchema, AttentionType
# Cached handle to the compiled "flush_write" extension; populated on first use.
_flush_mod = None


def _get_flush_module():
    """Return the ``flush_write`` CUDA extension, compiling it on first use.

    The first call JIT-builds ``../kernels/cuda/flush_write.cu`` (relative to
    this file) with ``torch.utils.cpp_extension.load``. It uses ``-O3`` and
    targets ``compute_100a`` / ``sm_100a``. The loaded module is stored in the
    module-level ``_flush_mod``. Later calls return that object without
    rebuilding.
    """
    global _flush_mod
    if _flush_mod is None:
        cuda_src_dir = os.path.join(
            os.path.dirname(__file__), "..", "kernels", "cuda"
        )
        _flush_mod = load(
            name="flush_write",
            sources=[os.path.join(cuda_src_dir, "flush_write.cu")],
            extra_cuda_cflags=[
                "-O3",
                "--generate-code=arch=compute_100a,code=[sm_100a]",
            ],
            verbose=False,
        )
    return _flush_mod
def maybe_flush_csa(
    handle,
    schema: LayerCacheSchema,
    m: int,
) -> None:
    """For CSA: emit compressed entries for requests whose tail is full.

    Steps:
    1. Determine which requests have tail_len >= m (valid_mask).
    2. Run the CSA compressor on the tail buffers.
    3. Scatter the compressed entry + indexer key into the paged pool.
    4. Rotate a-stream -> b-stream and clear the a-stream.

    Every step is launched unconditionally. Requests that are not ready are
    masked out on the device via ``valid_mask`` instead of being skipped with
    a host-side branch. A check such as ``valid_mask.any().item()`` forces a
    device-to-host sync on every call. It is also not permitted while a CUDA
    graph is being captured. Skipping launches would make the recorded kernel
    sequence depend on runtime data.

    Args:
        handle: Per-layer cache handle. Reads ``handle.state`` (``tail_ka``,
            ``tail_za``, ``tail_kb``, ``tail_zb``, ``tail_len``),
            ``handle.paged`` (paged-pool tensors), ``handle.request_slots``,
            ``handle.positions`` and ``handle.block_table``.
        schema: Cache layout; supplies ``entries_per_block``, ``rope_dim``,
            ``entry_head_dim`` and ``indexer_head_dim``.
        m: Tail length at which a request's CSA tail is considered full.
    """
    # Lazy import. NOTE(review): the commit header above records that the
    # kernels/compressor/{compress_tail,csa_hca} modules were moved to
    # dsv4/_archive/, so this import may fail now — confirm before reviving.
    from dsv4.kernels.compressor import csa_compress_tail

    state = handle.state
    paged = handle.paged
    slots = handle.request_slots  # [B]
    mod = _get_flush_module()

    # Step 1: which requests have a full tail buffer.
    # tail_len is indexed by request slot; request_slots picks this batch.
    valid_mask = state.tail_len[slots] >= m  # [B] bool

    # Step 2: compress the tail. The compressor is not given valid_mask; rows
    # for requests that are not ready are expected to be ignored by the
    # flush-write / rotate kernels below, which do receive valid_mask.
    entry, indexer_key = csa_compress_tail(
        tail_ka=state.tail_ka,
        tail_za=state.tail_za,
        tail_kb=state.tail_kb,
        tail_zb=state.tail_zb,
        tail_len=state.tail_len,
        request_slots=slots,
        m=m,
    )
    # entry:       [B, head_dim] BF16
    # indexer_key: [B, indexer_head_dim] BF16

    # Step 3: quantize and scatter into the paged pool. The kernel resolves
    # the target block / slot_in_block from the per-request position.
    # TODO(review): positions[:B] yields one position per request only when
    # every request contributes exactly one token (pure decode). For prefill
    # or mixed batches this must be the position of each request's last
    # appended token (e.g. the max position per request).
    mod.flush_write_csa(
        entry, indexer_key, valid_mask, slots,
        handle.positions[:slots.shape[0]],  # one position per request
        handle.block_table,
        paged.entries_fp8, paged.entries_rope, paged.inv_scale,
        paged.indexer_keys_fp4, paged.indexer_scale,
        schema.entries_per_block, m, schema.rope_dim,
        schema.entry_head_dim, schema.indexer_head_dim,
    )

    # Step 4: rotate state, gated on valid_mask — the current a-stream
    # becomes the next b-stream.
    mod.csa_rotate_state(
        valid_mask, slots,
        state.tail_ka, state.tail_za, state.tail_kb, state.tail_zb,
        state.tail_len, m, schema.entry_head_dim,
    )
def maybe_flush_hca(
    handle,
    schema: LayerCacheSchema,
    m_prime: int,
) -> None:
    """For HCA: emit one entry per request whose tail_len >= m'.

    Steps:
    1. Determine which requests have tail_len >= m_prime (valid_mask).
    2. Run the HCA compressor on the a-stream tail buffers.
    3. Scatter the compressed entry into the paged pool.
    4. Reset the tail of flushed requests (HCA has no b-stream rotation).

    Every step is launched unconditionally. Requests that are not ready are
    masked out on the device via ``valid_mask`` instead of being skipped with
    a host-side branch. A check such as ``valid_mask.any().item()`` forces a
    device-to-host sync on every call. It is also not permitted while a CUDA
    graph is being captured.

    Args:
        handle: Per-layer cache handle. Reads ``handle.state`` (``tail_ka``,
            ``tail_za``, ``tail_len``), ``handle.paged`` (paged-pool tensors),
            ``handle.request_slots``, ``handle.positions`` and
            ``handle.block_table``.
        schema: Cache layout; supplies ``entries_per_block``, ``rope_dim`` and
            ``entry_head_dim``.
        m_prime: Tail length at which a request's HCA tail is considered full.
    """
    # Lazy import. NOTE(review): the commit header above records that the
    # kernels/compressor/{compress_tail,csa_hca} modules were moved to
    # dsv4/_archive/, so this import may fail now — confirm before reviving.
    from dsv4.kernels.compressor import hca_compress_tail

    state = handle.state
    paged = handle.paged
    slots = handle.request_slots  # [B]
    mod = _get_flush_module()

    # Step 1: which requests have a full tail buffer.
    valid_mask = state.tail_len[slots] >= m_prime  # [B] bool

    # Step 2: compress the tail. Rows for requests that are not ready are
    # expected to be ignored by the write / reset kernels, which receive
    # valid_mask.
    entry = hca_compress_tail(
        tail_ka=state.tail_ka,
        tail_za=state.tail_za,
        tail_len=state.tail_len,
        request_slots=slots,
        m=m_prime,
    )
    # entry: [B, head_dim] BF16

    # Step 3: quantize and scatter into the paged pool.
    # TODO(review): positions[:B] yields one position per request only for
    # pure-decode batches; prefill / mixed batches need each request's last
    # appended position instead.
    mod.flush_write_hca(
        entry, valid_mask, slots,
        handle.positions[:slots.shape[0]],  # one position per request
        handle.block_table,
        paged.entries_fp8, paged.entries_rope, paged.inv_scale,
        schema.entries_per_block, m_prime, schema.rope_dim,
        schema.entry_head_dim,
    )

    # Step 4: reset the tail, gated on valid_mask — no b-stream rotation
    # for HCA.
    mod.hca_reset_state(
        valid_mask, slots,
        state.tail_ka, state.tail_za, state.tail_len,
        m_prime, schema.entry_head_dim,
    )