nvfp4-megamoe-kernel/dsv4/cache/prepare_forward.py
biondizzle 0f539e4855 Flush compressor: schema fix, prepare_forward, flush_write kernels, state rotation
Schema fix (paper eq.11-12):
  CSA needs m entries for current a-stream AND m entries for previous
  b-stream (tail_buffer_size_a=4, tail_buffer_size_b=4). After flush,
  current a-stream becomes next flush b-stream input.
  HCA: tail_buffer_size_a=128, tail_buffer_size_b=0 (no b-stream).
  tail_zb initialized to -1e9 so softmax naturally masks b-stream on
  first flush (paper: Z^b padded with -inf, C^b with zeros).
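
The masking effect of the -1e9 initial value can be illustrated with a minimal sketch in plain Python. This is not the kernel code: the `softmax` helper and the sample logits are hypothetical. A logit of -1e9 underflows to a weight of exactly zero after the max-subtracted exponential, so the b-stream contributes nothing on the first flush:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a flat list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical first flush with m = 4: four real a-stream logits and
# four b-stream logits still at their initial value of -1e9.
z_a = [0.3, -1.2, 2.0, 0.7]
z_b = [-1e9] * 4
weights = softmax(z_a + z_b)

a_mass = sum(weights[:4])  # softmax weight carried by the a-stream
b_mass = sum(weights[4:])  # softmax weight carried by the masked b-stream
```

Using a large finite value rather than a true -inf keeps the arithmetic free of NaNs even if every logit in a row happens to be masked.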

prepare_forward.py:
  Runs between captured graphs. Computes new compressed entries from
  position delta, pre-allocates blocks before the graph runs.
  Deterministic: entries_after - entries_before, ceil to block boundary.
  No allocation inside the captured graph.
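
The position-delta arithmetic can be sketched in isolation. The helper names `ceil_div` and `new_entries` are hypothetical, `m = 4` is taken from the CSA tail size above, and `epb = 2` is an arbitrary block capacity chosen to keep the numbers small:

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)

def new_entries(pos_before: int, pos_after: int, m: int) -> int:
    """Compressed entries produced when a request advances between two positions."""
    return pos_after // m - pos_before // m

m = 4    # compression ratio (assumed equal to the CSA tail size)
epb = 2  # entries per block (hypothetical)

# (position before, position after) for three requests in a batch
batch = [(0, 3), (6, 17), (127, 128)]

deltas = [new_entries(before, after, m) for before, after in batch]
blocks_after = [ceil_div(after // m, epb) for _, after in batch]
# deltas       -> [0, 3, 1]
# blocks_after -> [0, 2, 16]
```

The first request never crosses a multiple of `m`, so it produces no entry and needs no block; the third crosses exactly one boundary with a single token.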

flush_write.cu — 4 kernels:
  flush_write_csa_kernel: BF16 -> FP8 E4M3 quantize + scatter compressed
    entry + FP4 NVFP4 indexer key write (16-element groups, E4M3 scale).
    One block per request, 128 threads. Amax reduction -> inv_scale.
  flush_write_hca_kernel: same minus indexer (no FP4 write).
  csa_rotate_state_kernel: after CSA flush, rotate a->b stream,
    clear a-stream, reset tail_len.
  hca_reset_state_kernel: after HCA flush, clear a-stream, reset tail_len.
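
The kernel source is not shown here, so the following is only a CPU-side sketch of what an amax-based FP8 E4M3 quantization step could look like. It assumes the OCP E4M3 variant (largest finite value 448, 3 mantissa bits, smallest normal exponent -6); `round_to_e4m3` and `quantize_entry` are hypothetical names, and the real kernel may scale or round differently:

```python
import math

E4M3_MAX = 448.0  # largest finite value of OCP FP8 E4M3 (assumed variant)

def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3-representable value, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    # frexp returns mag = frac * 2**exp with frac in [0.5, 1),
    # so floor(log2(mag)) == exp - 1. Clamp to the smallest normal exponent.
    exponent = max(math.frexp(mag)[1] - 1, -6)
    step = 2.0 ** (exponent - 3)  # grid spacing with 3 mantissa bits
    rounded = min(round(mag / step) * step, E4M3_MAX)
    return math.copysign(rounded, x)

def quantize_entry(values):
    """Scale one entry so its amax maps to 448, then round onto the E4M3 grid."""
    amax = max(abs(v) for v in values)
    inv_scale = E4M3_MAX / amax if amax > 0.0 else 1.0
    codes = [round_to_e4m3(v * inv_scale) for v in values]
    scale = amax / E4M3_MAX if amax > 0.0 else 1.0  # multiply codes by this to dequantize
    return codes, scale

entry = [0.5, -1.25, 3.0, 0.02]
codes, scale = quantize_entry(entry)
dequant = [c * scale for c in codes]
max_rel_err = max(abs(d - v) / abs(v) for d, v in zip(dequant, entry))
```

With 3 mantissa bits the relative rounding error of a normal value is bounded by 1/16, which is the same order as the sub-4% error reported in the tests below.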

flush.py: Python orchestration.
  maybe_flush_csa/hca: always runs, kernels gate via valid_mask.
  Compressor produces entry, flush kernel quantize-scatters, state
  kernel rotates/resets. No host-side branching for cudagraph.
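
The gating pattern can be sketched on the host as follows. This is an illustration only: `CsaTailState` and `rotate_where_valid` are hypothetical, with field names borrowed from the test summary below. The caller invokes the routine unconditionally and the per-request mask decides whether anything changes, which is what lets the real kernels sit inside a captured graph:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class CsaTailState:
    ka: List[List[float]]  # [slots][m] a-stream values
    za: List[List[float]]  # [slots][m] a-stream logits
    kb: List[List[float]]  # [slots][m] b-stream values
    zb: List[List[float]]  # [slots][m] b-stream logits
    tail_len: List[int]    # [slots] tokens buffered since the last flush

def rotate_where_valid(state: CsaTailState, valid_mask: List[bool]) -> None:
    """Rotate a->b and clear the a-stream, but only for slots whose mask is set."""
    for slot, valid in enumerate(valid_mask):
        if not valid:
            continue  # stands in for the kernel's per-request no-op
        state.kb[slot] = list(state.ka[slot])
        state.zb[slot] = list(state.za[slot])
        state.ka[slot] = [0.0] * len(state.ka[slot])
        state.za[slot] = [0.0] * len(state.za[slot])
        state.tail_len[slot] = 0

state = CsaTailState(
    ka=[[1.0, 2.0], [3.0, 4.0]],
    za=[[0.1, 0.2], [0.3, 0.4]],
    kb=[[0.0, 0.0], [9.0, 9.0]],
    zb=[[-1e9, -1e9], [0.5, 0.5]],
    tail_len=[2, 1],
)
# Always called; only slot 0 is due for a flush in this step.
rotate_where_valid(state, [True, False])
```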

All tests pass on B200:
  Schema: CSA tail_a=4 tail_b=4, HCA tail_a=128 tail_b=0
  State: tail_zb initialized to -1e9, reset_slot preserves it
  prepare_forward: correct block allocation for position transitions
  HCA flush write: RoPE exact, FP8 <3.6% error, invalid mask no-op
  CSA flush write: RoPE exact, indexer FP4 keys written
  CSA state rotation: kb<-ka, zb<-za, ka/za zeroed, tail_len=0
  HCA state reset: ka/za zeroed, tail_len=0
2026-05-22 00:25:47 +00:00

"""Pre-forward block allocation.
Runs between captured graphs. Computes how many new compressed entries
will be produced by this forward (deterministic from positions), allocates
the required physical blocks, and updates block tables.
After this runs, the captured graph can perform flushes by writing to
already-resolved (request, layer, logical_block) -> physical_block
mappings. No allocation inside the graph.
"""
from __future__ import annotations

from typing import List

import torch

from dsv4.model.layer_schedule import LayerSpec, AttentionType
from dsv4.cache.manager import KVCacheManager


def prepare_forward(
    manager: KVCacheManager,
    request_slots: torch.Tensor,     # [B] state cache slots
    positions_before: torch.Tensor,  # [B] absolute position BEFORE this forward
    positions_after: torch.Tensor,   # [B] absolute position AFTER this forward
) -> None:
    """Pre-allocate any blocks that will be needed by flushes in this forward.

    Pure CPU/GPU bookkeeping: runs between captures, not in the hot path.

    For each compressed layer, works out how many flushes happen per
    request and allocates blocks to cover them.
    """
    for layer_idx, spec in enumerate(manager.schedule):
        if spec.attn == AttentionType.SWA:
            continue  # No classical pool, no flushes.

        schema = manager.schemas[layer_idx]
        alloc = manager.allocators[layer_idx]
        if alloc is None:
            continue

        m = (manager.config.csa_compression_ratio
             if spec.attn == AttentionType.CSA
             else manager.config.hca_compression_ratio)
        epb = schema.entries_per_block

        # How many compressed entries are NEWLY produced per request?
        # = floor(positions_after / m) - floor(positions_before / m)
        entries_after = (positions_after // m).to(torch.int64)
        entries_before = (positions_before // m).to(torch.int64)
        new_entries = entries_after - entries_before  # [B] int64

        # For each request, figure out how many new blocks are needed.
        # A block holds `epb` entries. If there are already some entries
        # in the current (open) block, they take some slots.
        for b in range(request_slots.numel()):
            n_new = int(new_entries[b])
            if n_new == 0:
                continue
            req_slot = int(request_slots[b])

            # How many entries are already in the current open block?
            existing_blocks = int(manager.block_lens[layer_idx][req_slot])
            entries_in_open_block = int(entries_before[b]) % epb if existing_blocks > 0 else 0
            slots_remaining_in_open = epb - entries_in_open_block if entries_in_open_block > 0 else 0

            # How many new blocks do we need?
            if entries_in_open_block == 0 and existing_blocks == 0:
                # Fresh request: no open block yet
                blocks_needed = (n_new + epb - 1) // epb
            elif slots_remaining_in_open >= n_new:
                # Fits in the current open block
                blocks_needed = 0
            else:
                # Need additional blocks beyond the current open one
                overflow = n_new - slots_remaining_in_open
                blocks_needed = (overflow + epb - 1) // epb

            if blocks_needed == 0:
                continue

            ids = alloc.acquire(blocks_needed)
            existing = int(manager.block_lens[layer_idx][req_slot])
            manager.block_tables[layer_idx][req_slot, existing:existing + blocks_needed] = ids
            manager.block_lens[layer_idx][req_slot] = existing + blocks_needed
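
The three-way branch above can be cross-checked against a simpler closed form. The sketch below is a standalone test harness, not part of the module: both function names are hypothetical, and it assumes the invariant that a request already owns exactly ceil(entries_before / epb) blocks, which the file does not state explicitly. Under that assumption the branch logic and `max(0, ceil(entries_after / epb) - existing_blocks)` agree on every small case tried:

```python
def blocks_needed_branching(entries_before: int, n_new: int,
                            epb: int, existing_blocks: int) -> int:
    """Per-request block count, mirroring the three-way branch in prepare_forward."""
    in_open = entries_before % epb if existing_blocks > 0 else 0
    remaining = epb - in_open if in_open > 0 else 0
    if in_open == 0 and existing_blocks == 0:
        return (n_new + epb - 1) // epb      # fresh request
    if remaining >= n_new:
        return 0                             # fits in the open block
    return (n_new - remaining + epb - 1) // epb  # spills into new blocks

def blocks_needed_closed_form(entries_before: int, n_new: int,
                              epb: int, existing_blocks: int) -> int:
    """Blocks required after the forward, minus blocks already owned."""
    total_after = (entries_before + n_new + epb - 1) // epb
    return max(0, total_after - existing_blocks)

mismatches = []
for epb in range(1, 6):
    for entries_before in range(0, 20):
        # Assumed invariant: blocks owned so far exactly cover the entries so far.
        existing = (entries_before + epb - 1) // epb
        for n_new in range(1, 12):  # prepare_forward skips n_new == 0
            a = blocks_needed_branching(entries_before, n_new, epb, existing)
            b = blocks_needed_closed_form(entries_before, n_new, epb, existing)
            if a != b:
                mismatches.append((epb, entries_before, n_new, a, b))
```

If the invariant can be violated (for example, blocks allocated ahead of time by another path), the two forms may diverge, so the closed form is a possible simplification rather than a drop-in replacement.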