Schema fix (paper eq. 11-12):
  CSA needs m entries for the current a-stream AND m entries for the
  previous b-stream (tail_buffer_size_a=4, tail_buffer_size_b=4). After a
  flush, the current a-stream becomes the b-stream input of the next flush.
  HCA: tail_buffer_size_a=128, tail_buffer_size_b=0 (no b-stream).
  tail_zb is initialized to -1e9 so that softmax naturally masks the
  b-stream on the first flush (paper: Z^b padded with -inf, C^b with zeros).
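
To illustrate why a large negative finite value can stand in for the paper's -inf padding, here is a minimal pure-Python sketch. The `softmax` helper and the sample logits are illustrative assumptions, not code from this change; only the -1e9 sentinel comes from the commit.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


MASK = -1e9  # same sentinel the commit uses for tail_zb

za = [0.3, -1.2, 2.0, 0.7]  # hypothetical a-stream logits (m = 4)
zb = [MASK] * 4             # b-stream before any flush has filled it

weights = softmax(za + zb)
a_weights, b_weights = weights[:4], weights[4:]
```

In this sketch exp(-1e9 - peak) underflows to exactly 0.0, so the b-stream receives zero weight and the a-stream weights sum to 1. A finite sentinel also avoids the NaN that -inf would produce if every logit in a row were masked (-inf minus -inf), which may be one reason to prefer it.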

prepare_forward.py:
  Runs between captured graphs. Computes the number of new compressed
  entries from the position delta and pre-allocates blocks before the
  graph runs.
  Deterministic: new entries = entries_after - entries_before, rounded up
  to a block boundary.
  No allocation happens inside the captured graph.
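
The block count can be sketched in closed form for a single request. This is a hedged illustration, not the function in this change: `new_blocks_needed` is a hypothetical helper, and it assumes the request already holds exactly ceil(entries_before / epb) blocks, i.e. allocation has kept pace with every earlier forward. The values m=4 and epb=64 are made up for the example.

```python
def new_blocks_needed(pos_before: int, pos_after: int, m: int, epb: int) -> int:
    """Blocks to pre-allocate for one request on one compressed layer.

    m   : compression ratio (positions per compressed entry)
    epb : entries per block
    """
    entries_before = pos_before // m
    entries_after = pos_after // m
    blocks_before = -(-entries_before // epb)  # ceil division
    blocks_after = -(-entries_after // epb)
    return blocks_after - blocks_before


# Hypothetical numbers: m=4, epb=64.
prefill = new_blocks_needed(0, 1024, m=4, epb=64)       # 256 entries, 4 blocks
no_entry = new_blocks_needed(1024, 1025, m=4, epb=64)   # no new entry yet
new_block = new_blocks_needed(1027, 1028, m=4, epb=64)  # entry 257 opens block 5
fits_open = new_blocks_needed(1031, 1032, m=4, epb=64)  # entry 258 fits in block 5
```

Under the stated assumption this should agree with a branch-based count that first fills the remaining slots of the open block and then rounds the overflow up to whole blocks.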

flush_write.cu (4 kernels):
  flush_write_csa_kernel: BF16 -> FP8 E4M3 quantize + scatter of the
    compressed entry + NVFP4 (FP4) indexer key write (16-element groups,
    E4M3 scale). One block per request, 128 threads. Amax reduction ->
    inv_scale.
  flush_write_hca_kernel: same, minus the indexer (no FP4 write).
  csa_rotate_state_kernel: after a CSA flush, rotate the a-stream into the
    b-stream, clear the a-stream, reset tail_len.
  hca_reset_state_kernel: after an HCA flush, clear the a-stream, reset
    tail_len.
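
The amax-scaled FP8 step can be emulated on the host to reason about rounding error. The sketch below is pure Python and makes several assumptions: `round_to_e4m3` and `quantize_entry` are hypothetical helpers, scaling is per entry, and "inv_scale" is taken to mean 448 / amax (448 being the largest finite E4M3 magnitude). The real kernels may differ in granularity and rounding details.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 magnitude


def round_to_e4m3(x: float) -> float:
    """Round to the nearest FP8 E4M3 value (ties to even, saturating)."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    # frexp gives mag = frac * 2**exp with 0.5 <= frac < 1, so the binade
    # exponent is exp - 1; clamp at -6, the smallest normal exponent.
    binade = max(math.frexp(mag)[1] - 1, -6)
    step = 2.0 ** (binade - 3)  # 3 mantissa bits -> 8 steps per binade
    rounded = min(round(mag / step) * step, E4M3_MAX)
    return math.copysign(rounded, x)


def quantize_entry(values):
    """Amax scaling followed by E4M3 rounding; returns (codes, scale)."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return [0.0] * len(values), 1.0
    inv_scale = E4M3_MAX / amax  # assumed meaning of "inv_scale"
    codes = [round_to_e4m3(v * inv_scale) for v in values]
    return codes, 1.0 / inv_scale  # dequantize with code * scale


entry = [0.02, -1.5, 3.25, 0.75, -0.001]  # made-up sample values
codes, scale = quantize_entry(entry)
restored = [c * scale for c in codes]
```

In this emulation the largest element maps to 448 and round-trips almost exactly, while every other element lands within half an E4M3 step of its scaled value.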

flush.py: Python orchestration.
  maybe_flush_csa/hca: always runs; the kernels gate on valid_mask.
  The compressor produces an entry, the flush kernel quantizes and
  scatters it, and the state kernel rotates or resets. No host-side
  branching, so the sequence stays capturable in a CUDA graph.
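
The gating pattern can be sketched as follows: the host always issues the same call, and the per-request mask decides inside the "kernel" whether anything changes. `TailState` and `rotate_state` are hypothetical stand-ins written in plain Python; only the field names and the rotation semantics (kb<-ka, zb<-za, ka/za zeroed, tail_len=0) are taken from this change.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TailState:
    """Hypothetical per-request CSA tail state."""
    ka: List[float]
    za: List[float]
    kb: List[float]
    zb: List[float]
    tail_len: int


def rotate_state(states: List[TailState], valid_mask: List[bool]) -> None:
    """Emulates a masked rotate: one loop iteration per request.

    The caller never inspects valid_mask; the gate lives inside, which
    mirrors a device-side check in an unconditionally launched kernel.
    """
    for state, valid in zip(states, valid_mask):
        if not valid:
            continue  # invalid request: no-op
        state.kb, state.zb = state.ka, state.za
        state.ka = [0.0] * len(state.kb)
        state.za = [0.0] * len(state.zb)
        state.tail_len = 0


flushed = TailState([1.0, 2.0], [0.5, 0.6], [9.0, 9.0], [-1e9, -1e9], 2)
pending = TailState([3.0, 0.0], [0.1, 0.0], [9.0, 9.0], [-1e9, -1e9], 1)
rotate_state([flushed, pending], valid_mask=[True, False])
```

After the call, `flushed` has its old a-stream in the b-stream slots and a cleared a-stream, while `pending` is unchanged.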

All tests pass on B200:
  Schema: CSA tail_a=4 tail_b=4, HCA tail_a=128 tail_b=0
  State: tail_zb initialized to -1e9, reset_slot preserves it
  prepare_forward: correct block allocation for position transitions
  HCA flush write: RoPE exact, FP8 <3.6% error, invalid mask no-op
  CSA flush write: RoPE exact, indexer FP4 keys written
  CSA state rotation: kb<-ka, zb<-za, ka/za zeroed, tail_len=0
  HCA state reset: ka/za zeroed, tail_len=0
84 lines
3.5 KiB
Python
"""Pre-forward block allocation.

Runs between captured graphs. Computes how many new compressed entries
will be produced by this forward (deterministic from positions), allocates
the required physical blocks, and updates block tables.

After this runs, the captured graph can perform flushes by writing to
already-resolved (request, layer, logical_block) -> physical_block
mappings. No allocation inside the graph.
"""

from __future__ import annotations

from typing import List

import torch

from dsv4.model.layer_schedule import LayerSpec, AttentionType
from dsv4.cache.manager import KVCacheManager


def prepare_forward(
    manager: KVCacheManager,
    request_slots: torch.Tensor,     # [B] state cache slots
    positions_before: torch.Tensor,  # [B] absolute position BEFORE this forward
    positions_after: torch.Tensor,   # [B] absolute position AFTER this forward
) -> None:
    """Pre-allocate any blocks that will be needed by flushes in this forward.

    Pure CPU/GPU bookkeeping: runs between captures, not in the hot path.
    For each compressed layer, works out how many flushes happen per
    request and allocates blocks to cover them.
    """
    for layer_idx, spec in enumerate(manager.schedule):
        if spec.attn == AttentionType.SWA:
            continue  # No classical pool, no flushes.

        schema = manager.schemas[layer_idx]
        alloc = manager.allocators[layer_idx]
        if alloc is None:
            continue

        m = (manager.config.csa_compression_ratio
             if spec.attn == AttentionType.CSA
             else manager.config.hca_compression_ratio)
        epb = schema.entries_per_block

        # How many compressed entries are NEWLY produced per request?
        #   = floor(positions_after / m) - floor(positions_before / m)
        entries_after = (positions_after // m).to(torch.int64)
        entries_before = (positions_before // m).to(torch.int64)
        new_entries = entries_after - entries_before  # [B] int64

        # For each request, figure out how many new blocks are needed.
        # A block holds `epb` entries. If there are already some entries
        # in the current (open) block, they take some slots.
        for b in range(request_slots.numel()):
            n_new = int(new_entries[b])
            if n_new == 0:
                continue
            req_slot = int(request_slots[b])

            # How many entries are already in the current open block?
            existing_blocks = int(manager.block_lens[layer_idx][req_slot])
            entries_in_open_block = int(entries_before[b]) % epb if existing_blocks > 0 else 0
            slots_remaining_in_open = epb - entries_in_open_block if entries_in_open_block > 0 else 0

            # How many new blocks do we need?
            if entries_in_open_block == 0 and existing_blocks == 0:
                # Fresh request: no open block yet
                blocks_needed = (n_new + epb - 1) // epb
            elif slots_remaining_in_open >= n_new:
                # Fits in the current open block
                blocks_needed = 0
            else:
                # Need additional blocks beyond the current open one
                overflow = n_new - slots_remaining_in_open
                blocks_needed = (overflow + epb - 1) // epb

            if blocks_needed == 0:
                continue

            ids = alloc.acquire(blocks_needed)
            existing = int(manager.block_lens[layer_idx][req_slot])
            manager.block_tables[layer_idx][req_slot, existing:existing + blocks_needed] = ids
            manager.block_lens[layer_idx][req_slot] = existing + blocks_needed