nvfp4-megamoe-kernel/dsv4/cache/state_cache.py
biondizzle 0f539e4855 Flush compressor: schema fix, prepare_forward, flush_write kernels, state rotation
Schema fix (paper eq.11-12):
  CSA needs m entries for the current a-stream AND m entries for the
  previous b-stream (tail_buffer_size_a=4, tail_buffer_size_b=4). After a
  flush, the current a-stream becomes the b-stream input of the next flush.
  HCA: tail_buffer_size_a=128, tail_buffer_size_b=0 (no b-stream).
  tail_zb initialized to -1e9 so softmax naturally masks b-stream on
  first flush (paper: Z^b padded with -inf, C^b with zeros).
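
A minimal sketch of why a -1e9 initial value behaves like the paper's -inf padding. It assumes scalar logits and a plain softmax over the concatenated b-stream and a-stream; the ordering, the helper name `softmax`, and the sample a-stream logits are illustrative assumptions, not taken from state_cache.py.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

M = 4               # CSA tail size per stream (tail_buffer_size_a = tail_buffer_size_b = 4)
MASK_LOGIT = -1e9   # initial value of tail_zb per the commit message

z_a = [0.3, -1.2, 0.8, 0.1]   # hypothetical a-stream logits
z_b = [MASK_LOGIT] * M        # b-stream as it looks on the first flush

weights = softmax(z_b + z_a)  # assumed layout: b-stream first, then a-stream
b_mass = sum(weights[:M])     # total weight given to the empty b-stream
a_mass = sum(weights[M:])     # total weight given to the a-stream
```

Because exp(-1e9 - max) underflows to zero in floating point, the b-stream contributes nothing on the first flush, so the zero-filled C^b entries never affect the result and no explicit mask is needed.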

prepare_forward.py:
  Runs between captured graphs. Computes new compressed entries from
  position delta, pre-allocates blocks before the graph runs.
  Deterministic: new entries = entries_after - entries_before, rounded
  up to a block boundary.
  No allocation inside the captured graph.
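
The allocation rule above can be sketched as follows. This is an interpretation, not the code in prepare_forward.py: it assumes one compressed entry per `ratio` tokens and reads "ceil to block boundary" as the number of blocks covering `entries_after` minus the blocks already covering `entries_before`. The function names, `ratio`, and `block_size` are hypothetical.

```python
def compressed_entries(position, ratio):
    """Compressed entries that exist after `position` tokens (assumed: one per `ratio` tokens)."""
    return position // ratio

def blocks_needed(num_entries, block_size):
    """Ceil division: blocks required to hold `num_entries` entries."""
    return -(-num_entries // block_size)

def blocks_to_allocate(pos_before, pos_after, ratio, block_size):
    """Return (new_entries, new_blocks) for a position transition."""
    entries_before = compressed_entries(pos_before, ratio)
    entries_after = compressed_entries(pos_after, ratio)
    new_entries = entries_after - entries_before
    new_blocks = (blocks_needed(entries_after, block_size)
                  - blocks_needed(entries_before, block_size))
    return new_entries, new_blocks

# Hypothetical numbers: ratio 4, 16 entries per block.
crossing = blocks_to_allocate(60, 68, ratio=4, block_size=16)   # 15 -> 17 entries
no_flush = blocks_to_allocate(64, 65, ratio=4, block_size=16)   # 16 -> 16 entries
```

Since the result depends only on the two positions, it can be computed on the host before graph replay, which is what keeps allocation out of the captured graph.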

flush_write.cu — 4 kernels:
  flush_write_csa_kernel: BF16 -> FP8 E4M3 quantize + scatter compressed
    entry + NVFP4 (FP4) indexer key write (16-element groups, E4M3 scale).
    One block per request, 128 threads. Amax reduction -> inv_scale.
  flush_write_hca_kernel: same minus indexer (no FP4 write).
  csa_rotate_state_kernel: after CSA flush, rotate a->b stream,
    clear a-stream, reset tail_len.
  hca_reset_state_kernel: after HCA flush, clear a-stream, reset tail_len.
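
A host-side Python sketch of the amax-based FP8 E4M3 quantization step, for illustration only. It assumes a single per-entry scale and takes `inv_scale` to mean 448 / amax (the multiplier applied before rounding); the real kernel in flush_write.cu may define and store its scale differently. `round_to_e4m3` and `quantize_fp8` are hypothetical helpers that emulate E4M3 rounding in ordinary floats.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def round_to_e4m3(x):
    """Round a float to the nearest value representable in FP8 E4M3 (saturating)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), E4M3_MAX)
    _, e = math.frexp(mag)            # mag = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)              # clamp to the minimum normal exponent
    step = 2.0 ** (exp - 3)           # 3 mantissa bits: 8 steps per binade
    return sign * min(round(mag / step) * step, E4M3_MAX)

def quantize_fp8(values):
    """Return (quantized values, scale) using one amax-derived scale."""
    amax = max(abs(v) for v in values)        # the kernel does this as a block reduction
    scale = amax / E4M3_MAX if amax > 0 else 1.0
    inv_scale = 1.0 / scale
    return [round_to_e4m3(v * inv_scale) for v in values], scale

entry = [0.5, -1.75, 3.0, 0.02]               # hypothetical BF16 inputs
quantized, scale = quantize_fp8(entry)
dequantized = [q * scale for q in quantized]
```

Scaling by amax maps the largest element onto 448, so the full E4M3 range is used and no element saturates.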

flush.py: Python orchestration.
  maybe_flush_csa/hca: always run; the kernels gate via valid_mask.
  The compressor produces an entry, the flush kernel quantizes and
  scatters it, and the state kernel rotates or resets. No host-side
  branching, so the sequence stays cudagraph-capturable.
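
The mask-gated state update can be modelled in plain Python as below. The field names (ka, za, kb, zb, tail_len) and the post-flush values follow the test summary in this commit message; the dict layout, `make_slot`, and the function names are hypothetical stand-ins for the CUDA kernels.

```python
def make_slot(ka, za, tail_len, m=4):
    """Hypothetical per-request state; zb starts at -1e9 as described above."""
    return {"ka": list(ka), "za": list(za),
            "kb": [0.0] * m, "zb": [-1e9] * m, "tail_len": tail_len}

def csa_rotate_state(slots, valid_mask):
    """Always 'launched' for every slot; invalid slots exit early and are untouched."""
    for slot, valid in zip(slots, valid_mask):   # one block per request in the kernel
        if not valid:
            continue                             # gate lives inside the kernel
        slot["kb"] = list(slot["ka"])            # kb <- ka
        slot["zb"] = list(slot["za"])            # zb <- za
        slot["ka"] = [0.0] * len(slot["ka"])     # clear a-stream
        slot["za"] = [0.0] * len(slot["za"])
        slot["tail_len"] = 0

def hca_reset_state(slots, valid_mask):
    """HCA has no b-stream, so a valid flush only clears the a-stream."""
    for slot, valid in zip(slots, valid_mask):
        if not valid:
            continue
        slot["ka"] = [0.0] * len(slot["ka"])
        slot["za"] = [0.0] * len(slot["za"])
        slot["tail_len"] = 0

slots = [make_slot([1.0, 2.0, 3.0, 4.0], [0.1, 0.2, 0.3, 0.4], tail_len=4),
         make_slot([5.0, 0.0, 0.0, 0.0], [0.5, 0.0, 0.0, 0.0], tail_len=1)]
csa_rotate_state(slots, valid_mask=[True, False])  # only slot 0 flushed this step
```

The host issues the same call on every step regardless of which requests flushed, so the launch sequence recorded in the graph never changes; only the mask contents do.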

All tests pass on B200:
  Schema: CSA tail_a=4 tail_b=4, HCA tail_a=128 tail_b=0
  State: tail_zb initialized to -1e9, reset_slot preserves it
  prepare_forward: correct block allocation for position transitions
  HCA flush write: RoPE exact, FP8 <3.6% error, invalid mask is a no-op
  CSA flush write: RoPE exact, indexer FP4 keys written
  CSA state rotation: kb<-ka, zb<-za, ka/za zeroed, tail_len=0
  HCA state reset: ka/za zeroed, tail_len=0
2026-05-22 00:25:47 +00:00
