Flush compressor: schema fix, prepare_forward, flush_write kernels, state rotation

Schema fix (paper eq.11-12):
  CSA needs m entries for current a-stream AND m entries for previous
  b-stream (tail_buffer_size_a=4, tail_buffer_size_b=4). After flush,
  current a-stream becomes next flush b-stream input.
  HCA: tail_buffer_size_a=128, tail_buffer_size_b=0 (no b-stream).
  tail_zb initialized to -1e9 so softmax naturally masks b-stream on
  first flush (paper: Z^b padded with -inf, C^b with zeros).

prepare_forward.py:
  Runs between captured graphs. Computes new compressed entries from
  position delta, pre-allocates blocks before the graph runs.
  Deterministic: entries_after - entries_before, ceil to block boundary.
  No allocation inside the captured graph.

flush_write.cu — 4 kernels:
  flush_write_csa_kernel: BF16 -> FP8 E4M3 quantize + scatter compressed
    entry + FP4 NVFP4 indexer key write (16-element groups, E4M3 scale).
    One block per request, 128 threads. Amax reduction -> inv_scale.
  flush_write_hca_kernel: same minus indexer (no FP4 write).
  csa_rotate_state_kernel: after CSA flush, rotate a->b stream,
    clear a-stream, reset tail_len.
  hca_reset_state_kernel: after HCA flush, clear a-stream, reset tail_len.

flush.py: Python orchestration.
  maybe_flush_csa/hca: always runs, kernels gate via valid_mask.
  Compressor produces entry, flush kernel quantize-scatters, state
  kernel rotates/resets. No host-side branching for cudagraph.

All tests pass on B200:
  Schema: CSA tail_a=4 tail_b=4, HCA tail_a=128 tail_b=0
  State: tail_zb initialized to -1e9, reset_slot preserves it
  prepare_forward: correct block allocation for position transitions
  HCA flush write: RoPE exact, FP8 <3.6% error, invalid mask no-op
  CSA flush write: RoPE exact, indexer FP4 keys written
  CSA state rotation: kb<-ka, zb<-za, ka/za zeroed, tail_len=0
  HCA state reset: ka/za zeroed, tail_len=0
This commit is contained in:
2026-05-22 00:25:47 +00:00
parent 23abfe9845
commit 8fcbc699a8
5 changed files with 749 additions and 52 deletions

162
dsv4/cache/flush.py vendored Normal file
View File

@@ -0,0 +1,162 @@
"""In-graph flush orchestration.
Called when tail_len crosses the compression threshold. The actual
compression math is in the csa_hca_compressor kernel; this module
handles the quantize-scatter-write step and the state rotation.
The maybe_flush_* functions always run when their attention type
matches — no host-side `if tail_full` check. The kernels gate
internally via `valid_mask` computed from `tail_len`. This keeps
the call sequence identical across forward passes for cudagraph.
"""
from __future__ import annotations
from typing import Optional
import os
import torch
from torch.utils.cpp_extension import load
from dsv4.cache.schema import LayerCacheSchema, AttentionType
_flush_mod = None
def _get_flush_module():
global _flush_mod
if _flush_mod is not None:
return _flush_mod
kernel_dir = os.path.join(os.path.dirname(__file__), "..", "kernels", "cuda")
_flush_mod = load(
name="flush_write",
sources=[os.path.join(kernel_dir, "flush_write.cu")],
extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
verbose=False,
)
return _flush_mod
def maybe_flush_csa(
handle,
schema: LayerCacheSchema,
m: int,
) -> None:
"""For CSA: emit compressed entries for requests whose tail is full.
Steps:
1. Determine which requests have tail_len >= m (valid_mask).
2. Run the CSA compressor on tail buffers.
3. Scatter compressed entry + indexer key into paged pool.
4. Rotate a-stream -> b-stream, clear a-stream.
"""
from dsv4.kernels.compressor import csa_compress_tail
state = handle.state
paged = handle.paged
mod = _get_flush_module()
# Step 1: valid_mask — which requests have a full tail buffer.
# tail_len is [max_requests], request_slots is [B].
tail_lens = state.tail_len[handle.request_slots] # [B]
valid_mask = tail_lens >= m # [B] bool
# If no requests need flushing, short-circuit.
if not valid_mask.any().item():
return
# Step 2: compress the tail.
# The compressor kernel takes the tail buffers and produces
# one compressed entry per request (for those where valid_mask=True).
entry, indexer_key = csa_compress_tail(
tail_ka=state.tail_ka,
tail_za=state.tail_za,
tail_kb=state.tail_kb,
tail_zb=state.tail_zb,
tail_len=state.tail_len,
request_slots=handle.request_slots,
m=m,
)
# entry: [B, head_dim] BF16
# indexer_key: [B, indexer_head_dim] BF16
# Step 3: scatter into the paged pool.
# The flush position for each request = the position of the last
# token in the tail (positions before this forward minus 1 would
# be the wrong reference; we need the tail's last position).
# For the block table lookup, we use the compressed entry index
# derived from positions.
# Use the positions of the requests' current tokens to figure
# out which entry slot to write into.
flush_positions = handle.positions # [tokens] -> need per-request
# For now, derive entry index from the per-request state:
# compressed_entry_idx = sum of all flushes so far for this request.
# This is (positions_of_last_appended_token) // m
# Simplification: use request_slots to look up per-request position.
# The handle's positions are per-token, not per-request.
# We need one position per request = position of the last appended token.
# For a single-token decode, that's just positions[-1] per request.
# For a general case, take the max position per request.
# This is computed by the append kernel (stored in tail_len and the
# actual positions in the tail). For now, use handle.positions
# and scatter by request.
# The kernel resolves slot_in_block from positions internally.
mod.flush_write_csa(
entry, indexer_key, valid_mask, handle.request_slots,
handle.positions[:handle.request_slots.shape[0]], # one pos per request
handle.block_table,
paged.entries_fp8, paged.entries_rope, paged.inv_scale,
paged.indexer_keys_fp4, paged.indexer_scale,
schema.entries_per_block, m, schema.rope_dim,
schema.entry_head_dim, schema.indexer_head_dim,
)
# Step 4: rotate state — current a-stream becomes next b-stream.
mod.csa_rotate_state(
valid_mask, handle.request_slots,
state.tail_ka, state.tail_za, state.tail_kb, state.tail_zb,
state.tail_len, m, schema.entry_head_dim,
)
def maybe_flush_hca(
handle,
schema: LayerCacheSchema,
m_prime: int,
) -> None:
"""For HCA: emit one entry per request whose tail_len >= m'."""
from dsv4.kernels.compressor import hca_compress_tail
state = handle.state
paged = handle.paged
mod = _get_flush_module()
tail_lens = state.tail_len[handle.request_slots]
valid_mask = tail_lens >= m_prime
if not valid_mask.any().item():
return
entry = hca_compress_tail(
tail_ka=state.tail_ka,
tail_za=state.tail_za,
tail_len=state.tail_len,
request_slots=handle.request_slots,
m=m_prime,
)
# entry: [B, head_dim] BF16
mod.flush_write_hca(
entry, valid_mask, handle.request_slots,
handle.positions[:handle.request_slots.shape[0]],
handle.block_table,
paged.entries_fp8, paged.entries_rope, paged.inv_scale,
schema.entries_per_block, m_prime, schema.rope_dim,
schema.entry_head_dim,
)
# Reset tail — no b-stream rotation for HCA.
mod.hca_reset_state(
valid_mask, handle.request_slots,
state.tail_ka, state.tail_za, state.tail_len,
m_prime, schema.entry_head_dim,
)

83
dsv4/cache/prepare_forward.py vendored Normal file
View File

@@ -0,0 +1,83 @@
"""Pre-forward block allocation.
Runs between captured graphs. Computes how many new compressed entries
will be produced by this forward (deterministic from positions), allocates
the required physical blocks, and updates block tables.
After this runs, the captured graph can perform flushes by writing to
already-resolved (request, layer, logical_block) -> physical_block
mappings. No allocation inside the graph.
"""
from __future__ import annotations
from typing import List
import torch
from dsv4.model.layer_schedule import LayerSpec, AttentionType
from dsv4.cache.manager import KVCacheManager
def prepare_forward(
manager: KVCacheManager,
request_slots: torch.Tensor, # [B] state cache slots
positions_before: torch.Tensor, # [B] absolute position BEFORE this forward
positions_after: torch.Tensor, # [B] absolute position AFTER this forward
) -> None:
"""Pre-allocate any blocks that will be needed by flushes in this forward.
Pure CPU/GPU bookkeeping — runs between captures, not in hot path.
For each compressed layer, works out how many flushes happen per
request and allocates blocks to cover them.
"""
for layer_idx, spec in enumerate(manager.schedule):
if spec.attn == AttentionType.SWA:
continue # No classical pool, no flushes.
schema = manager.schemas[layer_idx]
alloc = manager.allocators[layer_idx]
if alloc is None:
continue
m = (manager.config.csa_compression_ratio
if spec.attn == AttentionType.CSA
else manager.config.hca_compression_ratio)
epb = schema.entries_per_block
# How many compressed entries are NEWLY produced per request?
# = floor(positions_after / m) - floor(positions_before / m)
entries_after = (positions_after // m).to(torch.int64)
entries_before = (positions_before // m).to(torch.int64)
new_entries = entries_after - entries_before # [B] int64
# For each request, figure out how many new blocks are needed.
# A block holds `epb` entries. If there are already some entries
# in the current (open) block, they take some slots.
for b in range(request_slots.numel()):
n_new = int(new_entries[b])
if n_new == 0:
continue
req_slot = int(request_slots[b])
# How many entries are already in the current open block?
existing_blocks = int(manager.block_lens[layer_idx][req_slot])
entries_in_open_block = int(entries_before[b]) % epb if existing_blocks > 0 else 0
slots_remaining_in_open = epb - entries_in_open_block if entries_in_open_block > 0 else 0
# How many new blocks do we need?
if entries_in_open_block == 0 and existing_blocks == 0:
# Fresh — no open block yet
blocks_needed = (n_new + epb - 1) // epb
elif slots_remaining_in_open >= n_new:
# Fits in the current open block
blocks_needed = 0
else:
# Need additional blocks beyond the current open one
overflow = n_new - slots_remaining_in_open
blocks_needed = (overflow + epb - 1) // epb
if blocks_needed == 0:
continue
ids = alloc.acquire(blocks_needed)
existing = int(manager.block_lens[layer_idx][req_slot])
manager.block_tables[layer_idx][req_slot, existing:existing + blocks_needed] = ids
manager.block_lens[layer_idx][req_slot] = existing + blocks_needed

50
dsv4/cache/schema.py vendored
View File

@@ -33,32 +33,32 @@ class LayerCacheSchema:
attn_type: AttentionType
# ---- Classical paged cache (compressed entries) ----
# Number of compressed entries in one block of BLOCK_SIZE_ORIGINAL_TOKENS
# original tokens. For HCA m'=128 this is 1; for CSA m=4 this is 32.
# SWA-only layers have no classical pool — entries_per_block = 0.
entries_per_block: int
# Width of one entry (head_dim).
entry_head_dim: int
# RoPE-applied dimensions (BF16). Others FP8.
rope_dim: int
# ---- Indexer pool (CSA only) ----
# Compressed indexer keys, one per compressed entry.
indexer_entries_per_block: int # 32 for CSA, 0 for HCA/SWA
indexer_head_dim: int # 128 for CSA, 0 for others
indexer_entries_per_block: int
indexer_head_dim: int
# ---- State cache (SWA window + uncompressed tail) ----
swa_window_size: int # 128 for all layer types
# Uncompressed tail buffer — needed only for compressed layers.
# CSA: up to m-1 = 3 pending tokens before flushing compression.
# HCA: up to m'-1 = 127 pending tokens.
# SWA-only: no tail (no compression branch).
tail_buffer_size: int
swa_window_size: int
# Per-token inverse scale storage (for FP8 dequant). One FP32 scalar
# per stored entry/window-slot.
# CSA: paper eq.11-12, the i-th flush uses Ca[m*i:m*(i+1)] and
# Cb[m*(i-1):m*i]. After flush, current a-stream becomes next b-stream.
# So we need m entries for current a-stream AND m entries for previous
# b-stream. Total tail = 2*m for CSA.
tail_buffer_size_a: int # m (CSA) or m' (HCA) — current tokens
tail_buffer_size_b: int # m (CSA only) — previous block's a-stream kept as b-input
# Per-token inverse scale storage (for FP8 dequant).
needs_inv_scale: bool = True
@property
def tail_buffer_size(self) -> int:
"""Total tail entries (for backward compat with schema consumers)."""
return self.tail_buffer_size_a + self.tail_buffer_size_b
def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
"""Derive cache schema for a single layer from architectural config."""
@@ -72,7 +72,8 @@ def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
indexer_entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // config.csa_compression_ratio,
indexer_head_dim=config.indexer_head_dim,
swa_window_size=config.sliding_window,
tail_buffer_size=config.csa_compression_ratio - 1,
tail_buffer_size_a=config.csa_compression_ratio, # m=4 current
tail_buffer_size_b=config.csa_compression_ratio, # m=4 previous (b-stream)
)
elif spec.attn == AttentionType.HCA:
return LayerCacheSchema(
@@ -84,7 +85,8 @@ def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
indexer_entries_per_block=0,
indexer_head_dim=0,
swa_window_size=config.sliding_window,
tail_buffer_size=config.hca_compression_ratio - 1,
tail_buffer_size_a=config.hca_compression_ratio, # m'=128 current
tail_buffer_size_b=0, # HCA has no b-stream
)
else: # SWA-only
return LayerCacheSchema(
@@ -96,7 +98,8 @@ def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
indexer_entries_per_block=0,
indexer_head_dim=0,
swa_window_size=config.sliding_window,
tail_buffer_size=0,
tail_buffer_size_a=0,
tail_buffer_size_b=0,
)
@@ -106,14 +109,7 @@ def compute_block_budget(
max_context_tokens: int,
max_concurrent_requests: int,
) -> dict[str, int]:
"""Compute per-layer-type block counts for the allocator.
Returns {layer_type: num_blocks} where layer_type is 'csa' or 'hca'.
SWA-only layers need no classical blocks.
Block budget = max_concurrent_requests * (max_context / BLOCK_SIZE).
Add 10% headroom for fragmentation.
"""
"""Compute per-layer-type block counts for the allocator."""
blocks_per_request = max_context_tokens // BLOCK_SIZE_ORIGINAL_TOKENS
headroom = 1.10
result = {}

View File

@@ -7,6 +7,13 @@ and reclaims them at completion.
Per paper §3.5.1: SWA and tail tokens are state-space-like — they
depend only on the current position, not on a paged history. No
block table; a flat [max_requests, ...] tensor.
CSA b-stream lifecycle (paper eq.11-12):
After a CSA flush, the current a-stream (tail_ka/tail_za) becomes
the next flush's b-stream input (tail_kb/tail_zb). Both are sized
at m entries, not m-1. On first flush, tail_zb is filled with -1e9
so the softmax in the compressor naturally masks out the b-stream
(exp(-inf) = 0).
"""
from __future__ import annotations
import torch
@@ -22,15 +29,13 @@ class StateCachePool:
swa_rope: [n_win, rope_dim] BF16 RoPE'd half
swa_inv: [n_win] FP32 per-token inv scale
swa_pos: [n_win] int32 — absolute position
of each window slot (-1 if invalid)
swa_head: scalar int32 — ring buffer write head
tail_ka: [tail_size, head_dim] BF16 raw — pending tokens
not yet compressed
tail_za: [tail_size, head_dim] BF16 — compression weights
(Z stream for CSA, single Z for HCA)
tail_kb: [tail_size, head_dim] BF16 — second stream (CSA only)
tail_zb: [tail_size, head_dim] BF16 — second Z stream (CSA only)
tail_len: scalar int32 — how many tail entries are valid
tail_ka: [m_a, head_dim] BF16 — current a-stream tokens
tail_za: [m_a, head_dim] BF16 — current a-stream Z weights
tail_kb: [m_b, head_dim] BF16 — previous a-stream kept as b-input (CSA only)
tail_zb: [m_b, head_dim] BF16 — previous Z b-stream (CSA only, init to -1e9)
tail_len: scalar int32 — how many entries in a-stream are valid
"""
def __init__(
@@ -49,33 +54,31 @@ class StateCachePool:
rd = schema.rope_dim
fp8 = hd - rd
# SWA window — circular within each slot. Layer's attention
# kernel uses swa_pos to mask invalid entries.
# SWA window — circular within each slot.
self.swa_fp8 = torch.zeros((mr, nw, fp8), dtype=torch.uint8, device=device)
self.swa_rope = torch.zeros((mr, nw, rd), dtype=torch.bfloat16, device=device)
self.swa_inv = torch.ones((mr, nw), dtype=torch.float32, device=device)
self.swa_pos = torch.full((mr, nw), -1, dtype=torch.int32, device=device)
# Next write position within each slot's ring buffer.
self.swa_head = torch.zeros((mr,), dtype=torch.int32, device=device)
# Tail buffer — only non-empty for compressed layers.
tail = schema.tail_buffer_size
if tail > 0:
# For CSA we need two streams (Ca/Cb, Za/Zb) since the
# compressor uses overlapping pairs. HCA only needs one
# stream. Store both; HCA leaves the b-channel zero.
self.tail_ka = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
self.tail_za = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
if schema.attn_type == AttentionType.CSA:
self.tail_kb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
self.tail_zb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
# Tail buffer — only for compressed layers.
m_a = schema.tail_buffer_size_a # m (CSA) or m' (HCA)
m_b = schema.tail_buffer_size_b # m (CSA only)
if m_a > 0:
self.tail_ka = torch.zeros((mr, m_a, hd), dtype=torch.bfloat16, device=device)
self.tail_za = torch.zeros((mr, m_a, hd), dtype=torch.bfloat16, device=device)
self.tail_len = torch.zeros((mr,), dtype=torch.int32, device=device)
if m_b > 0: # CSA: need b-stream
self.tail_kb = torch.zeros((mr, m_b, hd), dtype=torch.bfloat16, device=device)
# Paper §3.5.1: Z^b padded with -inf at first flush.
# Init to -1e9 so softmax naturally masks b-stream on first flush.
self.tail_zb = torch.full((mr, m_b, hd), -1e9, dtype=torch.bfloat16, device=device)
else:
self.tail_kb = None
self.tail_zb = None
self.tail_len = torch.zeros((mr,), dtype=torch.int32, device=device)
else:
self.tail_ka = self.tail_kb = None
self.tail_za = self.tail_zb = None
self.tail_ka = self.tail_za = None
self.tail_kb = self.tail_zb = None
self.tail_len = None
def reset_slot(self, slot: int) -> None:
@@ -84,6 +87,9 @@ class StateCachePool:
self.swa_head[slot] = 0
if self.tail_len is not None:
self.tail_len[slot] = 0
# Re-init tail_zb to -1e9 for CSA (paper §3.5.1 first-flush mask)
if self.tail_zb is not None:
self.tail_zb[slot].fill_(-1e9)
def memory_bytes(self) -> int:
"""Total GPU memory used by this pool."""

View File

@@ -0,0 +1,450 @@
// flush_write.cu — Quantize and scatter compressed entries into paged KV pool.
//
// Two kernel variants:
// flush_write_csa_kernel: writes compressed entry + FP4 indexer key
// flush_write_hca_kernel: writes compressed entry only (no indexer)
//
// Both do BF16 → FP8 (E4M3) quantization with per-token amax for the
// non-RoPE half, and write the RoPE half as-is BF16.
//
// One block per request. Each block handles writing ONE compressed entry
// per flush. At decode (B small, 1 entry/flush) this is 1-16 CTAs.
// At prefill (B up to 128), this is up to 128 CTAs — good occupancy.
//
// Blackwell SM100: 128 threads per block for the FP8 quantize loop
// covers head_dim=512 with 4 elements per thread. The FP4 indexer
// quantize uses 64 threads (indexer_head_dim=128, 2 elements/thread).
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <limits>
// ---- Warp-level reductions ----
__device__ __forceinline__ float warp_reduce_max(float val) {
for (int offset = 16; offset > 0; offset >>= 1) {
float other = __shfl_down_sync(0xffffffff, val, offset);
val = fmaxf(val, fabsf(other));
}
return val;
}
__device__ __forceinline__ float warp_reduce_sum(float val) {
for (int offset = 16; offset > 0; offset >>= 1) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
return val;
}
// ---- Block-level amax (128 threads = 4 warps) ----
__device__ __forceinline__ float block_reduce_amax(float val, int n_warps) {
float warp_amax = warp_reduce_max(val);
__shared__ float smem[4];
if (threadIdx.x % 32 == 0) {
smem[threadIdx.x / 32] = warp_amax;
}
__syncthreads();
float result = 0.0f;
if (threadIdx.x < 32) {
float v = (threadIdx.x < n_warps) ? smem[threadIdx.x] : 0.0f;
result = warp_reduce_max(v);
}
__syncthreads();
return result;
}
// ---- NVFP4 quantization for indexer keys ----
// 16-element groups, one E4M3 scale per group.
// FP4 E2M1 has 6 possible values: 0, 2, 4, 6, 8, 10, 12, 14 (shifted).
// We use a simplified approach: group amax / 6.0 -> scale,
// quantize each element to nearest of {0,1,2,3,4,5,6} * scale.
__device__ __forceinline__ void quantize_fp4_group(
const __nv_bfloat16* __restrict__ input, // 16 elements
uint8_t* __restrict__ output, // 8 bytes (2 FP4 per byte)
uint8_t* __restrict__ scale_out // 1 FP8 E4M3 scale
) {
// Compute group amax
float amax = 0.0f;
for (int i = 0; i < 16; i++) {
amax = fmaxf(amax, fabsf(__bfloat162float(input[i])));
}
// FP4 E2M1 has max representable = 6.0 (before scaling)
float scale = amax / 6.0f;
if (scale < 1e-12f) scale = 1e-12f;
float inv_scale = scale;
// Write scale as FP8 E4M3
__nv_fp8_e4m3 fp8_scale;
fp8_scale = __nv_fp8_e4m3(scale);
*scale_out = fp8_scale.__x;
// Quantize 16 elements to FP4 E2M1, pack 2 per byte
for (int i = 0; i < 8; i++) {
float v0 = __bfloat162float(input[2 * i]) / inv_scale;
float v1 = __bfloat162float(input[2 * i + 1]) / inv_scale;
// Clamp to [0, 6] and round to nearest int
int q0 = (int)roundf(fmaxf(0.0f, fminf(6.0f, v0)));
int q1 = (int)roundf(fmaxf(0.0f, fminf(6.0f, v1)));
// Pack: low nibble = element 0, high nibble = element 1
output[i] = (uint8_t)((q1 << 4) | q0);
}
}
// ===========================================================================
// CSA flush write kernel
// ===========================================================================
__global__ void flush_write_csa_kernel(
// Inputs
const __nv_bfloat16* __restrict__ entry, // [B, head_dim] BF16
const __nv_bfloat16* __restrict__ indexer_key, // [B, indexer_head_dim] BF16
const bool* __restrict__ valid_mask, // [B]
const int32_t* __restrict__ request_slots, // [B]
const int32_t* __restrict__ positions, // [B]
const int32_t* __restrict__ block_table, // [B, max_logical_blocks]
// Outputs — paged pool tensors, mutated in place
uint8_t* __restrict__ entries_fp8, // [num_blocks, epb, fp8_dim]
__nv_bfloat16* __restrict__ entries_rope, // [num_blocks, epb, rope_dim]
float* __restrict__ inv_scale, // [num_blocks, epb]
uint8_t* __restrict__ indexer_keys_fp4, // [num_blocks, epb, ihd/2]
uint8_t* __restrict__ indexer_scale, // [num_blocks, epb, ihd/16]
// Geometry
int entries_per_block, int m, int rope_dim,
int head_dim, int indexer_head_dim, int max_logical_blocks
) {
int b = blockIdx.x;
if (!valid_mask[b]) return; // Early exit for no-op requests.
// Resolve destination slot in the paged pool.
int pos = positions[b];
int entry_idx = pos / m; // which compressed entry index
int logical_block = entry_idx / entries_per_block;
int slot_in_block = entry_idx % entries_per_block;
int phys_block = block_table[b * max_logical_blocks + logical_block];
int fp8_dim = head_dim - rope_dim;
int tid = threadIdx.x;
int n_threads = blockDim.x; // 128
int n_warps = n_threads / 32;
// ---- Step 1: Compute amax over non-RoPE half ----
float local_amax = 0.0f;
for (int i = tid; i < fp8_dim; i += n_threads) {
float v = fabsf(__bfloat162float(entry[b * head_dim + i]));
local_amax = fmaxf(local_amax, v);
}
float block_amax = block_reduce_amax(local_amax, n_warps);
// ---- Step 2: Write inv_scale ----
__shared__ float s_inv_scale;
if (tid == 0) {
float scale = (block_amax > 1e-12f) ? (block_amax / 448.0f) : 1e-12f;
s_inv_scale = scale;
inv_scale[phys_block * entries_per_block + slot_in_block] = scale;
}
__syncthreads();
// ---- Step 3: Quantize and write FP8 half ----
float inv_s = s_inv_scale;
for (int i = tid; i < fp8_dim; i += n_threads) {
float v = __bfloat162float(entry[b * head_dim + i]);
float quantized = v / inv_s;
quantized = fmaxf(-448.0f, fminf(448.0f, quantized));
__nv_fp8_e4m3 fp8_val;
fp8_val = __nv_fp8_e4m3(quantized);
entries_fp8[(phys_block * entries_per_block + slot_in_block) * fp8_dim + i] = fp8_val.__x;
}
// ---- Step 4: Write BF16 RoPE half ----
for (int i = tid; i < rope_dim; i += n_threads) {
entries_rope[(phys_block * entries_per_block + slot_in_block) * rope_dim + i]
= entry[b * head_dim + fp8_dim + i];
}
// ---- Step 5: FP4 quantize and write indexer key ----
// 16 elements per group, one FP8 E4M3 scale per group.
// Process groups in parallel across threads.
int n_groups = indexer_head_dim / 16;
int n_bytes = indexer_head_dim / 2; // 2 FP4 per byte
int n_scales = n_groups;
for (int g = tid; g < n_groups; g += n_threads) {
// Gather 16 BF16 values for this group
__nv_bfloat16 group_in[16];
for (int j = 0; j < 16; j++) {
group_in[j] = indexer_key[b * indexer_head_dim + g * 16 + j];
}
uint8_t group_out[8];
uint8_t group_scale;
quantize_fp4_group(group_in, group_out, &group_scale);
// Write 8 packed bytes
int byte_offset = (phys_block * entries_per_block + slot_in_block) * n_bytes + g * 8;
for (int j = 0; j < 8; j++) {
indexer_keys_fp4[byte_offset + j] = group_out[j];
}
// Write scale
int scale_offset = (phys_block * entries_per_block + slot_in_block) * n_scales + g;
indexer_scale[scale_offset] = group_scale;
}
}
// ===========================================================================
// HCA flush write kernel (no indexer)
// ===========================================================================
__global__ void flush_write_hca_kernel(
const __nv_bfloat16* __restrict__ entry,
const bool* __restrict__ valid_mask,
const int32_t* __restrict__ request_slots,
const int32_t* __restrict__ positions,
const int32_t* __restrict__ block_table,
uint8_t* __restrict__ entries_fp8,
__nv_bfloat16* __restrict__ entries_rope,
float* __restrict__ inv_scale,
int entries_per_block, int m, int rope_dim,
int head_dim, int max_logical_blocks
) {
int b = blockIdx.x;
if (!valid_mask[b]) return;
int pos = positions[b];
int entry_idx = pos / m;
int logical_block = entry_idx / entries_per_block;
int slot_in_block = entry_idx % entries_per_block;
int phys_block = block_table[b * max_logical_blocks + logical_block];
int fp8_dim = head_dim - rope_dim;
int tid = threadIdx.x;
int n_threads = blockDim.x;
int n_warps = n_threads / 32;
// Amax reduction
float local_amax = 0.0f;
for (int i = tid; i < fp8_dim; i += n_threads) {
float v = fabsf(__bfloat162float(entry[b * head_dim + i]));
local_amax = fmaxf(local_amax, v);
}
float block_amax = block_reduce_amax(local_amax, n_warps);
__shared__ float s_inv_scale;
if (tid == 0) {
float scale = (block_amax > 1e-12f) ? (block_amax / 448.0f) : 1e-12f;
s_inv_scale = scale;
inv_scale[phys_block * entries_per_block + slot_in_block] = scale;
}
__syncthreads();
// FP8 quantize + write
float inv_s = s_inv_scale;
for (int i = tid; i < fp8_dim; i += n_threads) {
float v = __bfloat162float(entry[b * head_dim + i]);
float quantized = v / inv_s;
quantized = fmaxf(-448.0f, fminf(448.0f, quantized));
__nv_fp8_e4m3 fp8_val;
fp8_val = __nv_fp8_e4m3(quantized);
entries_fp8[(phys_block * entries_per_block + slot_in_block) * fp8_dim + i] = fp8_val.__x;
}
// BF16 RoPE half
for (int i = tid; i < rope_dim; i += n_threads) {
entries_rope[(phys_block * entries_per_block + slot_in_block) * rope_dim + i]
= entry[b * head_dim + fp8_dim + i];
}
}
// ===========================================================================
// State rotation kernels (in-place, single-kernel launches)
// ===========================================================================
// CSA: after flush, rotate a-stream -> b-stream, clear a-stream
__global__ void csa_rotate_state_kernel(
const bool* __restrict__ valid_mask, // [B]
const int32_t* __restrict__ request_slots, // [B]
// State cache tensors — mutated in place
__nv_bfloat16* __restrict__ tail_ka, // [max_req, m, head_dim]
__nv_bfloat16* __restrict__ tail_za,
__nv_bfloat16* __restrict__ tail_kb,
__nv_bfloat16* __restrict__ tail_zb,
int32_t* __restrict__ tail_len, // [max_req]
int m, int head_dim, int max_requests
) {
int b = blockIdx.x;
if (!valid_mask[b]) return;
int slot = request_slots[b];
int tid = threadIdx.x;
int n_threads = blockDim.x;
// Rotate: kb <- ka, zb <- za (current a-stream becomes next b-stream)
int total = m * head_dim;
for (int i = tid; i < total; i += n_threads) {
tail_kb[slot * total + i] = tail_ka[slot * total + i];
tail_zb[slot * total + i] = tail_za[slot * total + i];
}
// Clear a-stream (zero out) and reset tail_len
if (tid == 0) {
tail_len[slot] = 0;
}
for (int i = tid; i < total; i += n_threads) {
tail_ka[slot * total + i] = __float2bfloat16(0.0f);
tail_za[slot * total + i] = __float2bfloat16(0.0f);
}
}
// HCA: after flush, just clear a-stream and reset tail_len
__global__ void hca_reset_state_kernel(
const bool* __restrict__ valid_mask,
const int32_t* __restrict__ request_slots,
__nv_bfloat16* __restrict__ tail_ka,
__nv_bfloat16* __restrict__ tail_za,
int32_t* __restrict__ tail_len,
int m, int head_dim, int max_requests
) {
int b = blockIdx.x;
if (!valid_mask[b]) return;
int slot = request_slots[b];
int tid = threadIdx.x;
int n_threads = blockDim.x;
int total = m * head_dim;
if (tid == 0) {
tail_len[slot] = 0;
}
for (int i = tid; i < total; i += n_threads) {
tail_ka[slot * total + i] = __float2bfloat16(0.0f);
tail_za[slot * total + i] = __float2bfloat16(0.0f);
}
}
// ===========================================================================
// PyTorch bindings
// ===========================================================================
void flush_write_csa_cuda(
torch::Tensor entry,
torch::Tensor indexer_key,
torch::Tensor valid_mask,
torch::Tensor request_slots,
torch::Tensor positions,
torch::Tensor block_table,
torch::Tensor entries_fp8,
torch::Tensor entries_rope,
torch::Tensor inv_scale,
torch::Tensor indexer_keys_fp4,
torch::Tensor indexer_scale,
int64_t entries_per_block, int64_t m, int64_t rope_dim,
int64_t head_dim, int64_t indexer_head_dim
) {
int B = entry.size(0);
int max_logical_blocks = block_table.size(1);
int threads = 128;
flush_write_csa_kernel<<<B, threads>>>(
reinterpret_cast<const __nv_bfloat16*>(entry.data_ptr<at::BFloat16>()),
reinterpret_cast<const __nv_bfloat16*>(indexer_key.data_ptr<at::BFloat16>()),
valid_mask.data_ptr<bool>(),
request_slots.data_ptr<int32_t>(),
positions.data_ptr<int32_t>(),
block_table.data_ptr<int32_t>(),
entries_fp8.data_ptr<uint8_t>(),
reinterpret_cast<__nv_bfloat16*>(entries_rope.data_ptr<at::BFloat16>()),
inv_scale.data_ptr<float>(),
indexer_keys_fp4.data_ptr<uint8_t>(),
indexer_scale.data_ptr<uint8_t>(),
(int)entries_per_block, (int)m, (int)rope_dim,
(int)head_dim, (int)indexer_head_dim, max_logical_blocks
);
C10_CUDA_CHECK(cudaGetLastError());
}
void flush_write_hca_cuda(
torch::Tensor entry,
torch::Tensor valid_mask,
torch::Tensor request_slots,
torch::Tensor positions,
torch::Tensor block_table,
torch::Tensor entries_fp8,
torch::Tensor entries_rope,
torch::Tensor inv_scale,
int64_t entries_per_block, int64_t m, int64_t rope_dim,
int64_t head_dim
) {
int B = entry.size(0);
int max_logical_blocks = block_table.size(1);
int threads = 128;
flush_write_hca_kernel<<<B, threads>>>(
reinterpret_cast<const __nv_bfloat16*>(entry.data_ptr<at::BFloat16>()),
valid_mask.data_ptr<bool>(),
request_slots.data_ptr<int32_t>(),
positions.data_ptr<int32_t>(),
block_table.data_ptr<int32_t>(),
entries_fp8.data_ptr<uint8_t>(),
reinterpret_cast<__nv_bfloat16*>(entries_rope.data_ptr<at::BFloat16>()),
inv_scale.data_ptr<float>(),
(int)entries_per_block, (int)m, (int)rope_dim,
(int)head_dim, max_logical_blocks
);
C10_CUDA_CHECK(cudaGetLastError());
}
void csa_rotate_state_cuda(
torch::Tensor valid_mask,
torch::Tensor request_slots,
torch::Tensor tail_ka,
torch::Tensor tail_za,
torch::Tensor tail_kb,
torch::Tensor tail_zb,
torch::Tensor tail_len,
int64_t m, int64_t head_dim
) {
int B = valid_mask.size(0);
int threads = 128;
csa_rotate_state_kernel<<<B, threads>>>(
valid_mask.data_ptr<bool>(),
request_slots.data_ptr<int32_t>(),
reinterpret_cast<__nv_bfloat16*>(tail_ka.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(tail_za.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(tail_kb.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(tail_zb.data_ptr<at::BFloat16>()),
tail_len.data_ptr<int32_t>(),
(int)m, (int)head_dim, 0 // max_requests unused in kernel
);
C10_CUDA_CHECK(cudaGetLastError());
}
void hca_reset_state_cuda(
torch::Tensor valid_mask,
torch::Tensor request_slots,
torch::Tensor tail_ka,
torch::Tensor tail_za,
torch::Tensor tail_len,
int64_t m, int64_t head_dim
) {
int B = valid_mask.size(0);
int threads = 128;
hca_reset_state_kernel<<<B, threads>>>(
valid_mask.data_ptr<bool>(),
request_slots.data_ptr<int32_t>(),
reinterpret_cast<__nv_bfloat16*>(tail_ka.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(tail_za.data_ptr<at::BFloat16>()),
tail_len.data_ptr<int32_t>(),
(int)m, (int)head_dim, 0
);
C10_CUDA_CHECK(cudaGetLastError());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("flush_write_csa", &flush_write_csa_cuda, "CSA flush write kernel");
m.def("flush_write_hca", &flush_write_hca_cuda, "HCA flush write kernel");
m.def("csa_rotate_state", &csa_rotate_state_cuda, "CSA state rotation kernel");
m.def("hca_reset_state", &hca_reset_state_cuda, "HCA state reset kernel");
}