Flush compressor: schema fix, prepare_forward, flush_write kernels, state rotation
Schema fix (paper eq.11-12):
CSA needs m entries for current a-stream AND m entries for previous
b-stream (tail_buffer_size_a=4, tail_buffer_size_b=4). After flush,
current a-stream becomes next flush b-stream input.
HCA: tail_buffer_size_a=128, tail_buffer_size_b=0 (no b-stream).
tail_zb initialized to -1e9 so softmax naturally masks b-stream on
first flush (paper: Z^b padded with -inf, C^b with zeros).
prepare_forward.py:
Runs between captured graphs. Computes new compressed entries from
position delta, pre-allocates blocks before the graph runs.
Deterministic: entries_after - entries_before, ceil to block boundary.
No allocation inside the captured graph.
flush_write.cu — 4 kernels:
flush_write_csa_kernel: BF16 -> FP8 E4M3 quantize + scatter compressed
entry + FP4 NVFP4 indexer key write (16-element groups, E4M3 scale).
One block per request, 128 threads. Amax reduction -> inv_scale.
flush_write_hca_kernel: same minus indexer (no FP4 write).
csa_rotate_state_kernel: after CSA flush, rotate a->b stream,
clear a-stream, reset tail_len.
hca_reset_state_kernel: after HCA flush, clear a-stream, reset tail_len.
flush.py: Python orchestration.
maybe_flush_csa/hca: always runs, kernels gate via valid_mask.
Compressor produces entry, flush kernel quantize-scatters, state
kernel rotates/resets. No host-side branching for cudagraph.
All tests pass on B200:
Schema: CSA tail_a=4 tail_b=4, HCA tail_a=128 tail_b=0
State: tail_zb initialized to -1e9, reset_slot re-fills it with -1e9
prepare_forward: correct block allocation for position transitions
HCA flush write: RoPE exact, FP8 <3.6% error, invalid mask no-op
CSA flush write: RoPE exact, indexer FP4 keys written
CSA state rotation: kb<-ka, zb<-za, ka/za zeroed, tail_len=0
HCA state reset: ka/za zeroed, tail_len=0
This commit is contained in:
162
dsv4/cache/flush.py
vendored
Normal file
162
dsv4/cache/flush.py
vendored
Normal file
@@ -0,0 +1,162 @@
|
||||
"""In-graph flush orchestration.
|
||||
|
||||
Called when tail_len crosses the compression threshold. The actual
|
||||
compression math is in the csa_hca_compressor kernel; this module
|
||||
handles the quantize-scatter-write step and the state rotation.
|
||||
|
||||
The maybe_flush_* functions always run when their attention type
|
||||
matches — no host-side `if tail_full` check. The kernels gate
|
||||
internally via `valid_mask` computed from `tail_len`. This keeps
|
||||
the call sequence identical across forward passes for cudagraph.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
from dsv4.cache.schema import LayerCacheSchema, AttentionType
|
||||
|
||||
|
||||
_flush_mod = None
|
||||
|
||||
|
||||
def _get_flush_module():
    """Return the JIT-built ``flush_write`` CUDA extension, compiling it on first use.

    The loaded module is memoised in the module-level ``_flush_mod`` so that
    ``torch.utils.cpp_extension.load`` is invoked at most once per process.
    """
    global _flush_mod
    if _flush_mod is None:
        # Kernel sources are resolved relative to this file: ../kernels/cuda.
        cuda_dir = os.path.join(os.path.dirname(__file__), "..", "kernels", "cuda")
        source_path = os.path.join(cuda_dir, "flush_write.cu")
        _flush_mod = load(
            name="flush_write",
            sources=[source_path],
            # The arch flag below restricts code generation to sm_100a.
            extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
            verbose=False,
        )
    return _flush_mod
|
||||
|
||||
|
||||
def maybe_flush_csa(
    handle,
    schema: LayerCacheSchema,
    m: int,
) -> None:
    """For CSA: emit compressed entries for requests whose tail is full.

    The call sequence is unconditional: there is no host-side check of
    ``valid_mask`` (no ``.item()`` sync, no early return), so the same kernels
    are launched on every forward pass. Requests whose tail is not full are
    skipped inside the kernels, which return early when ``valid_mask[b]`` is
    false.

    Steps:
      1. Build ``valid_mask`` = ``tail_len >= m`` for the batch's request slots.
      2. Run the CSA compressor on the tail buffers.
      3. Quantize + scatter compressed entry and indexer key into the paged pool.
      4. Rotate a-stream -> b-stream, clear a-stream, reset ``tail_len``.

    Args:
        handle: Per-forward cache handle. This function reads ``state``,
            ``paged``, ``request_slots``, ``positions`` and ``block_table``.
        schema: Layer cache schema supplying ``entries_per_block``,
            ``rope_dim``, ``entry_head_dim`` and ``indexer_head_dim``.
        m: CSA compression ratio (tail length that triggers a flush).
    """
    from dsv4.kernels.compressor import csa_compress_tail

    state = handle.state
    paged = handle.paged
    mod = _get_flush_module()
    slots = handle.request_slots

    # Step 1: valid_mask — which requests have a full tail buffer.
    # tail_len is [max_requests], request_slots is [B]. The mask stays on
    # device; reading it on the host would force a sync and make the launch
    # sequence data-dependent.
    valid_mask = state.tail_len[slots] >= m  # [B] bool

    # Step 2: compress the tail. One compressed entry per request; rows for
    # requests with valid_mask=False are ignored by the write kernels.
    entry, indexer_key = csa_compress_tail(
        tail_ka=state.tail_ka,
        tail_za=state.tail_za,
        tail_kb=state.tail_kb,
        tail_zb=state.tail_zb,
        tail_len=state.tail_len,
        request_slots=slots,
        m=m,
    )
    # entry:       [B, head_dim] BF16
    # indexer_key: [B, indexer_head_dim] BF16

    # Step 3: scatter into the paged pool. The kernel derives the compressed
    # entry index (and from it the block / slot-in-block) as position // m.
    # NOTE(review): handle.positions is per-token; slicing the first B values
    # yields one position per request only for single-token decode. Confirm
    # before using this path with multi-token forwards.
    per_request_positions = handle.positions[:slots.shape[0]]
    mod.flush_write_csa(
        entry, indexer_key, valid_mask, slots,
        per_request_positions,
        handle.block_table,
        paged.entries_fp8, paged.entries_rope, paged.inv_scale,
        paged.indexer_keys_fp4, paged.indexer_scale,
        schema.entries_per_block, m, schema.rope_dim,
        schema.entry_head_dim, schema.indexer_head_dim,
    )

    # Step 4: rotate state — current a-stream becomes next b-stream.
    mod.csa_rotate_state(
        valid_mask, slots,
        state.tail_ka, state.tail_za, state.tail_kb, state.tail_zb,
        state.tail_len, m, schema.entry_head_dim,
    )
|
||||
|
||||
|
||||
def maybe_flush_hca(
    handle,
    schema: LayerCacheSchema,
    m_prime: int,
) -> None:
    """For HCA: emit one entry per request whose tail_len >= m'.

    Like the CSA path, the launch sequence is unconditional: ``valid_mask`` is
    never read on the host (no ``.item()`` sync, no early return). The write
    and reset kernels return early for requests where ``valid_mask[b]`` is
    false.

    Args:
        handle: Per-forward cache handle. This function reads ``state``,
            ``paged``, ``request_slots``, ``positions`` and ``block_table``.
        schema: Layer cache schema supplying ``entries_per_block``,
            ``rope_dim`` and ``entry_head_dim``.
        m_prime: HCA compression ratio (tail length that triggers a flush).
    """
    from dsv4.kernels.compressor import hca_compress_tail

    state = handle.state
    paged = handle.paged
    mod = _get_flush_module()
    slots = handle.request_slots

    # Device-side mask of requests whose tail buffer is full.
    valid_mask = state.tail_len[slots] >= m_prime  # [B] bool

    entry = hca_compress_tail(
        tail_ka=state.tail_ka,
        tail_za=state.tail_za,
        tail_len=state.tail_len,
        request_slots=slots,
        m=m_prime,
    )
    # entry: [B, head_dim] BF16

    # NOTE(review): handle.positions is sliced to the first B values, which is
    # one position per request only for single-token decode — confirm.
    per_request_positions = handle.positions[:slots.shape[0]]
    mod.flush_write_hca(
        entry, valid_mask, slots,
        per_request_positions,
        handle.block_table,
        paged.entries_fp8, paged.entries_rope, paged.inv_scale,
        schema.entries_per_block, m_prime, schema.rope_dim,
        schema.entry_head_dim,
    )

    # Reset tail — no b-stream rotation for HCA.
    mod.hca_reset_state(
        valid_mask, slots,
        state.tail_ka, state.tail_za, state.tail_len,
        m_prime, schema.entry_head_dim,
    )
|
||||
83
dsv4/cache/prepare_forward.py
vendored
Normal file
83
dsv4/cache/prepare_forward.py
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Pre-forward block allocation.
|
||||
|
||||
Runs between captured graphs. Computes how many new compressed entries
|
||||
will be produced by this forward (deterministic from positions), allocates
|
||||
the required physical blocks, and updates block tables.
|
||||
|
||||
After this runs, the captured graph can perform flushes by writing to
|
||||
already-resolved (request, layer, logical_block) -> physical_block
|
||||
mappings. No allocation inside the graph.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import List
|
||||
import torch
|
||||
|
||||
from dsv4.model.layer_schedule import LayerSpec, AttentionType
|
||||
from dsv4.cache.manager import KVCacheManager
|
||||
|
||||
|
||||
def prepare_forward(
    manager: KVCacheManager,
    request_slots: torch.Tensor,      # [B] state cache slots
    positions_before: torch.Tensor,   # [B] absolute position BEFORE this forward
    positions_after: torch.Tensor,    # [B] absolute position AFTER this forward
) -> None:
    """Pre-allocate any blocks that will be needed by flushes in this forward.

    For every non-SWA layer that has an allocator, the number of compressed
    entries each request newly produces is
    ``positions_after // m - positions_before // m``. Enough physical blocks
    are acquired to hold those entries, appended to the request's row of the
    layer's block table, and the layer's ``block_lens`` entry is advanced.

    Args:
        manager: Cache manager; this function reads ``schedule``, ``schemas``,
            ``allocators``, ``config`` and mutates ``block_tables`` /
            ``block_lens``.
        request_slots: Slot id of each request in the batch.
        positions_before: Per-request position before this forward.
        positions_after: Per-request position after this forward.
    """
    # Pull slot ids to the host once instead of one scalar read per request
    # per layer.
    slot_ids = request_slots.reshape(-1).tolist()

    for layer_idx, spec in enumerate(manager.schedule):
        if spec.attn == AttentionType.SWA:
            continue  # No classical pool, no flushes.

        schema = manager.schemas[layer_idx]
        alloc = manager.allocators[layer_idx]
        if alloc is None:
            continue

        m = (manager.config.csa_compression_ratio
             if spec.attn == AttentionType.CSA
             else manager.config.hca_compression_ratio)
        epb = schema.entries_per_block

        # Newly produced compressed entries per request:
        #   floor(positions_after / m) - floor(positions_before / m)
        entries_after = (positions_after // m).to(torch.int64)
        entries_before = (positions_before // m).to(torch.int64)
        before_counts = entries_before.reshape(-1).tolist()
        new_counts = (entries_after - entries_before).reshape(-1).tolist()

        layer_block_lens = manager.block_lens[layer_idx]
        layer_block_table = manager.block_tables[layer_idx]

        for req_slot, n_before, n_new in zip(slot_ids, before_counts, new_counts):
            # Nothing to flush. A negative delta (position moved backwards)
            # must never reach the allocator as a negative block count.
            if n_new <= 0:
                continue

            existing = int(layer_block_lens[req_slot])

            # Entries already sitting in the request's open (partially filled)
            # block, and the free slots left in it. With no blocks allocated,
            # or with the last block exactly full, there is no open block.
            used_in_open = n_before % epb if existing > 0 else 0
            free_in_open = epb - used_in_open if used_in_open > 0 else 0

            # Entries that do not fit into the open block need fresh blocks.
            overflow = n_new - free_in_open
            if overflow <= 0:
                continue
            blocks_needed = (overflow + epb - 1) // epb

            ids = alloc.acquire(blocks_needed)
            layer_block_table[req_slot, existing:existing + blocks_needed] = ids
            layer_block_lens[req_slot] = existing + blocks_needed
|
||||
50
dsv4/cache/schema.py
vendored
50
dsv4/cache/schema.py
vendored
@@ -33,32 +33,32 @@ class LayerCacheSchema:
|
||||
attn_type: AttentionType
|
||||
|
||||
# ---- Classical paged cache (compressed entries) ----
|
||||
# Number of compressed entries in one block of BLOCK_SIZE_ORIGINAL_TOKENS
|
||||
# original tokens. For HCA m'=128 this is 1; for CSA m=4 this is 32.
|
||||
# SWA-only layers have no classical pool — entries_per_block = 0.
|
||||
entries_per_block: int
|
||||
# Width of one entry (head_dim).
|
||||
entry_head_dim: int
|
||||
# RoPE-applied dimensions (BF16). Others FP8.
|
||||
rope_dim: int
|
||||
|
||||
# ---- Indexer pool (CSA only) ----
|
||||
# Compressed indexer keys, one per compressed entry.
|
||||
indexer_entries_per_block: int # 32 for CSA, 0 for HCA/SWA
|
||||
indexer_head_dim: int # 128 for CSA, 0 for others
|
||||
indexer_entries_per_block: int
|
||||
indexer_head_dim: int
|
||||
|
||||
# ---- State cache (SWA window + uncompressed tail) ----
|
||||
swa_window_size: int # 128 for all layer types
|
||||
# Uncompressed tail buffer — needed only for compressed layers.
|
||||
# CSA: up to m-1 = 3 pending tokens before flushing compression.
|
||||
# HCA: up to m'-1 = 127 pending tokens.
|
||||
# SWA-only: no tail (no compression branch).
|
||||
tail_buffer_size: int
|
||||
swa_window_size: int
|
||||
|
||||
# Per-token inverse scale storage (for FP8 dequant). One FP32 scalar
|
||||
# per stored entry/window-slot.
|
||||
# CSA: paper eq.11-12, the i-th flush uses Ca[m*i:m*(i+1)] and
|
||||
# Cb[m*(i-1):m*i]. After flush, current a-stream becomes next b-stream.
|
||||
# So we need m entries for current a-stream AND m entries for previous
|
||||
# b-stream. Total tail = 2*m for CSA.
|
||||
tail_buffer_size_a: int # m (CSA) or m' (HCA) — current tokens
|
||||
tail_buffer_size_b: int # m (CSA only) — previous block's a-stream kept as b-input
|
||||
|
||||
# Per-token inverse scale storage (for FP8 dequant).
|
||||
needs_inv_scale: bool = True
|
||||
|
||||
@property
|
||||
def tail_buffer_size(self) -> int:
|
||||
"""Total tail entries (for backward compat with schema consumers)."""
|
||||
return self.tail_buffer_size_a + self.tail_buffer_size_b
|
||||
|
||||
|
||||
def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
|
||||
"""Derive cache schema for a single layer from architectural config."""
|
||||
@@ -72,7 +72,8 @@ def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
|
||||
indexer_entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // config.csa_compression_ratio,
|
||||
indexer_head_dim=config.indexer_head_dim,
|
||||
swa_window_size=config.sliding_window,
|
||||
tail_buffer_size=config.csa_compression_ratio - 1,
|
||||
tail_buffer_size_a=config.csa_compression_ratio, # m=4 current
|
||||
tail_buffer_size_b=config.csa_compression_ratio, # m=4 previous (b-stream)
|
||||
)
|
||||
elif spec.attn == AttentionType.HCA:
|
||||
return LayerCacheSchema(
|
||||
@@ -84,7 +85,8 @@ def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
|
||||
indexer_entries_per_block=0,
|
||||
indexer_head_dim=0,
|
||||
swa_window_size=config.sliding_window,
|
||||
tail_buffer_size=config.hca_compression_ratio - 1,
|
||||
tail_buffer_size_a=config.hca_compression_ratio, # m'=128 current
|
||||
tail_buffer_size_b=0, # HCA has no b-stream
|
||||
)
|
||||
else: # SWA-only
|
||||
return LayerCacheSchema(
|
||||
@@ -96,7 +98,8 @@ def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
|
||||
indexer_entries_per_block=0,
|
||||
indexer_head_dim=0,
|
||||
swa_window_size=config.sliding_window,
|
||||
tail_buffer_size=0,
|
||||
tail_buffer_size_a=0,
|
||||
tail_buffer_size_b=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -106,14 +109,7 @@ def compute_block_budget(
|
||||
max_context_tokens: int,
|
||||
max_concurrent_requests: int,
|
||||
) -> dict[str, int]:
|
||||
"""Compute per-layer-type block counts for the allocator.
|
||||
|
||||
Returns {layer_type: num_blocks} where layer_type is 'csa' or 'hca'.
|
||||
SWA-only layers need no classical blocks.
|
||||
|
||||
Block budget = max_concurrent_requests * (max_context / BLOCK_SIZE).
|
||||
Add 10% headroom for fragmentation.
|
||||
"""
|
||||
"""Compute per-layer-type block counts for the allocator."""
|
||||
blocks_per_request = max_context_tokens // BLOCK_SIZE_ORIGINAL_TOKENS
|
||||
headroom = 1.10
|
||||
result = {}
|
||||
|
||||
56
dsv4/cache/state_cache.py
vendored
56
dsv4/cache/state_cache.py
vendored
@@ -7,6 +7,13 @@ and reclaims them at completion.
|
||||
Per paper §3.5.1: SWA and tail tokens are state-space-like — they
|
||||
depend only on the current position, not on a paged history. No
|
||||
block table; a flat [max_requests, ...] tensor.
|
||||
|
||||
CSA b-stream lifecycle (paper eq.11-12):
|
||||
After a CSA flush, the current a-stream (tail_ka/tail_za) becomes
|
||||
the next flush's b-stream input (tail_kb/tail_zb). Both are sized
|
||||
at m entries, not m-1. On first flush, tail_zb is filled with -1e9
|
||||
so the softmax in the compressor naturally masks out the b-stream
|
||||
(exp(-inf) = 0).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
@@ -22,15 +29,13 @@ class StateCachePool:
|
||||
swa_rope: [n_win, rope_dim] BF16 RoPE'd half
|
||||
swa_inv: [n_win] FP32 per-token inv scale
|
||||
swa_pos: [n_win] int32 — absolute position
|
||||
of each window slot (-1 if invalid)
|
||||
swa_head: scalar int32 — ring buffer write head
|
||||
|
||||
tail_ka: [tail_size, head_dim] BF16 raw — pending tokens
|
||||
not yet compressed
|
||||
tail_za: [tail_size, head_dim] BF16 — compression weights
|
||||
(Z stream for CSA, single Z for HCA)
|
||||
tail_kb: [tail_size, head_dim] BF16 — second stream (CSA only)
|
||||
tail_zb: [tail_size, head_dim] BF16 — second Z stream (CSA only)
|
||||
tail_len: scalar int32 — how many tail entries are valid
|
||||
tail_ka: [m_a, head_dim] BF16 — current a-stream tokens
|
||||
tail_za: [m_a, head_dim] BF16 — current a-stream Z weights
|
||||
tail_kb: [m_b, head_dim] BF16 — previous a-stream kept as b-input (CSA only)
|
||||
tail_zb: [m_b, head_dim] BF16 — previous Z b-stream (CSA only, init to -1e9)
|
||||
tail_len: scalar int32 — how many entries in a-stream are valid
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -49,33 +54,31 @@ class StateCachePool:
|
||||
rd = schema.rope_dim
|
||||
fp8 = hd - rd
|
||||
|
||||
# SWA window — circular within each slot. Layer's attention
|
||||
# kernel uses swa_pos to mask invalid entries.
|
||||
# SWA window — circular within each slot.
|
||||
self.swa_fp8 = torch.zeros((mr, nw, fp8), dtype=torch.uint8, device=device)
|
||||
self.swa_rope = torch.zeros((mr, nw, rd), dtype=torch.bfloat16, device=device)
|
||||
self.swa_inv = torch.ones((mr, nw), dtype=torch.float32, device=device)
|
||||
self.swa_pos = torch.full((mr, nw), -1, dtype=torch.int32, device=device)
|
||||
# Next write position within each slot's ring buffer.
|
||||
self.swa_head = torch.zeros((mr,), dtype=torch.int32, device=device)
|
||||
|
||||
# Tail buffer — only non-empty for compressed layers.
|
||||
tail = schema.tail_buffer_size
|
||||
if tail > 0:
|
||||
# For CSA we need two streams (Ca/Cb, Za/Zb) since the
|
||||
# compressor uses overlapping pairs. HCA only needs one
|
||||
# stream. Store both; HCA leaves the b-channel zero.
|
||||
self.tail_ka = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
|
||||
self.tail_za = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
|
||||
if schema.attn_type == AttentionType.CSA:
|
||||
self.tail_kb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
|
||||
self.tail_zb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
|
||||
# Tail buffer — only for compressed layers.
|
||||
m_a = schema.tail_buffer_size_a # m (CSA) or m' (HCA)
|
||||
m_b = schema.tail_buffer_size_b # m (CSA only)
|
||||
if m_a > 0:
|
||||
self.tail_ka = torch.zeros((mr, m_a, hd), dtype=torch.bfloat16, device=device)
|
||||
self.tail_za = torch.zeros((mr, m_a, hd), dtype=torch.bfloat16, device=device)
|
||||
self.tail_len = torch.zeros((mr,), dtype=torch.int32, device=device)
|
||||
if m_b > 0: # CSA: need b-stream
|
||||
self.tail_kb = torch.zeros((mr, m_b, hd), dtype=torch.bfloat16, device=device)
|
||||
# Paper §3.5.1: Z^b padded with -inf at first flush.
|
||||
# Init to -1e9 so softmax naturally masks b-stream on first flush.
|
||||
self.tail_zb = torch.full((mr, m_b, hd), -1e9, dtype=torch.bfloat16, device=device)
|
||||
else:
|
||||
self.tail_kb = None
|
||||
self.tail_zb = None
|
||||
self.tail_len = torch.zeros((mr,), dtype=torch.int32, device=device)
|
||||
else:
|
||||
self.tail_ka = self.tail_kb = None
|
||||
self.tail_za = self.tail_zb = None
|
||||
self.tail_ka = self.tail_za = None
|
||||
self.tail_kb = self.tail_zb = None
|
||||
self.tail_len = None
|
||||
|
||||
def reset_slot(self, slot: int) -> None:
|
||||
@@ -84,6 +87,9 @@ class StateCachePool:
|
||||
self.swa_head[slot] = 0
|
||||
if self.tail_len is not None:
|
||||
self.tail_len[slot] = 0
|
||||
# Re-init tail_zb to -1e9 for CSA (paper §3.5.1 first-flush mask)
|
||||
if self.tail_zb is not None:
|
||||
self.tail_zb[slot].fill_(-1e9)
|
||||
|
||||
def memory_bytes(self) -> int:
|
||||
"""Total GPU memory used by this pool."""
|
||||
|
||||
450
dsv4/kernels/cuda/flush_write.cu
Normal file
450
dsv4/kernels/cuda/flush_write.cu
Normal file
@@ -0,0 +1,450 @@
|
||||
// flush_write.cu — Quantize and scatter compressed entries into paged KV pool.
|
||||
//
|
||||
// Two kernel variants:
|
||||
// flush_write_csa_kernel: writes compressed entry + FP4 indexer key
|
||||
// flush_write_hca_kernel: writes compressed entry only (no indexer)
|
||||
//
|
||||
// Both do BF16 → FP8 (E4M3) quantization with per-token amax for the
|
||||
// non-RoPE half, and write the RoPE half as-is BF16.
|
||||
//
|
||||
// One block per request. Each block handles writing ONE compressed entry
|
||||
// per flush. At decode (B small, 1 entry/flush) this is 1-16 CTAs.
|
||||
// At prefill (B up to 128), this is up to 128 CTAs — good occupancy.
|
||||
//
|
||||
// Blackwell SM100: 128 threads per block for the FP8 quantize loop
|
||||
// covers head_dim=512 with 4 elements per thread. The FP4 indexer
|
||||
// quantize assigns one thread per 16-element group (indexer_head_dim=128 -> 8 active threads).
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
// ---- Warp-level reductions ----
|
||||
|
||||
// Shuffle-down max reduction across the 32 lanes of a warp.
// Each step folds in the magnitude (fabsf) of the value held `step` lanes
// above, so callers are expected to pass non-negative inputs. Because the
// exchange is shuffle-down, only lane 0 ends up with the full-warp result.
// Every lane of the warp must call this (full 0xffffffff mask).
__device__ __forceinline__ float warp_reduce_max(float val) {
    int step = 16;
    while (step > 0) {
        const float peer = __shfl_down_sync(0xffffffff, val, step);
        val = fmaxf(val, fabsf(peer));
        step >>= 1;
    }
    return val;
}
|
||||
|
||||
// Shuffle-down sum reduction across the 32 lanes of a warp.
// Only lane 0 ends up with the sum of all lanes. Every lane of the warp must
// call this (full 0xffffffff mask).
__device__ __forceinline__ float warp_reduce_sum(float val) {
    int step = 16;
    while (step > 0) {
        const float peer = __shfl_down_sync(0xffffffff, val, step);
        val += peer;
        step >>= 1;
    }
    return val;
}
|
||||
|
||||
// ---- Block-level amax (128 threads = 4 warps) ----
|
||||
|
||||
// Block-wide amax built from per-warp reductions.
//
// val      per-thread partial amax (non-negative).
// n_warps  number of warps in the block (blockDim.x / 32).
//
// Returns the block-wide maximum in thread 0 ONLY: the final step is a
// shuffle-down reduction inside warp 0, so other lanes of warp 0 hold partial
// results and all other warps return 0. Must be called by every thread of the
// block, since it contains __syncthreads().
__device__ __forceinline__ float block_reduce_amax(float val, int n_warps) {
    // One slot per warp. 32 slots cover the hardware maximum of 1024 threads
    // per block, so a launch wider than 128 threads cannot write out of bounds.
    __shared__ float smem[32];

    const int lane = threadIdx.x % 32;
    const int warp = threadIdx.x / 32;

    const float warp_amax = warp_reduce_max(val);
    if (lane == 0) {
        smem[warp] = warp_amax;
    }
    __syncthreads();

    float result = 0.0f;
    if (warp == 0) {
        // Lanes beyond n_warps contribute the identity (0) to the max.
        const float v = (lane < n_warps) ? smem[lane] : 0.0f;
        result = warp_reduce_max(v);
    }
    // Keep smem stable until warp 0 has read it, so a subsequent call cannot
    // overwrite the slots early.
    __syncthreads();
    return result;
}
|
||||
|
||||
// ---- NVFP4 quantization for indexer keys ----
|
||||
// 16-element groups, one E4M3 scale per group.
|
||||
// FP4 E2M1 has 6 possible values: 0, 2, 4, 6, 8, 10, 12, 14 (shifted).
|
||||
// We use a simplified approach: group amax / 6.0 -> scale,
|
||||
// quantize each element to nearest of {0,1,2,3,4,5,6} * scale.
|
||||
|
||||
__device__ __forceinline__ void quantize_fp4_group(
|
||||
const __nv_bfloat16* __restrict__ input, // 16 elements
|
||||
uint8_t* __restrict__ output, // 8 bytes (2 FP4 per byte)
|
||||
uint8_t* __restrict__ scale_out // 1 FP8 E4M3 scale
|
||||
) {
|
||||
// Compute group amax
|
||||
float amax = 0.0f;
|
||||
for (int i = 0; i < 16; i++) {
|
||||
amax = fmaxf(amax, fabsf(__bfloat162float(input[i])));
|
||||
}
|
||||
// FP4 E2M1 has max representable = 6.0 (before scaling)
|
||||
float scale = amax / 6.0f;
|
||||
if (scale < 1e-12f) scale = 1e-12f;
|
||||
float inv_scale = scale;
|
||||
|
||||
// Write scale as FP8 E4M3
|
||||
__nv_fp8_e4m3 fp8_scale;
|
||||
fp8_scale = __nv_fp8_e4m3(scale);
|
||||
*scale_out = fp8_scale.__x;
|
||||
|
||||
// Quantize 16 elements to FP4 E2M1, pack 2 per byte
|
||||
for (int i = 0; i < 8; i++) {
|
||||
float v0 = __bfloat162float(input[2 * i]) / inv_scale;
|
||||
float v1 = __bfloat162float(input[2 * i + 1]) / inv_scale;
|
||||
// Clamp to [0, 6] and round to nearest int
|
||||
int q0 = (int)roundf(fmaxf(0.0f, fminf(6.0f, v0)));
|
||||
int q1 = (int)roundf(fmaxf(0.0f, fminf(6.0f, v1)));
|
||||
// Pack: low nibble = element 0, high nibble = element 1
|
||||
output[i] = (uint8_t)((q1 << 4) | q0);
|
||||
}
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// CSA flush write kernel
|
||||
// ===========================================================================
|
||||
|
||||
// Quantize one compressed CSA entry per request and scatter it, together with
// its FP4 indexer key, into the paged pool.
//
// Launch: one block per request (blockIdx.x = batch index b). Blocks whose
// valid_mask[b] is false return immediately without writing anything.
//
// Destination: entry_idx = positions[b] / m selects the logical block
// (entry_idx / entries_per_block) and the slot inside it
// (entry_idx % entries_per_block); block_table maps the logical block to a
// physical block. The first head_dim - rope_dim elements are stored as FP8
// E4M3 with one per-entry scale (amax / 448); the trailing rope_dim elements
// are copied as BF16.
__global__ void flush_write_csa_kernel(
    // Inputs
    const __nv_bfloat16* __restrict__ entry,        // [B, head_dim] BF16
    const __nv_bfloat16* __restrict__ indexer_key,  // [B, indexer_head_dim] BF16
    const bool*          __restrict__ valid_mask,   // [B]
    const int32_t*       __restrict__ request_slots,  // [B]
    const int32_t*       __restrict__ positions,    // [B]
    const int32_t*       __restrict__ block_table,  // [B, max_logical_blocks]
    // Outputs — paged pool tensors, mutated in place
    uint8_t*       __restrict__ entries_fp8,        // [num_blocks, epb, fp8_dim]
    __nv_bfloat16* __restrict__ entries_rope,       // [num_blocks, epb, rope_dim]
    float*         __restrict__ inv_scale,          // [num_blocks, epb]
    uint8_t*       __restrict__ indexer_keys_fp4,   // [num_blocks, epb, ihd/2]
    uint8_t*       __restrict__ indexer_scale,      // [num_blocks, epb, ihd/16]
    // Geometry
    int entries_per_block, int m, int rope_dim,
    int head_dim, int indexer_head_dim, int max_logical_blocks
) {
    const int b = blockIdx.x;
    // valid_mask[b] is uniform across the block, so this early exit cannot
    // strand threads at a later __syncthreads().
    if (!valid_mask[b]) return;

    // Resolve destination slot in the paged pool.
    const int pos = positions[b];
    const int entry_idx = pos / m;  // which compressed entry index
    const int logical_block = entry_idx / entries_per_block;
    const int slot_in_block = entry_idx % entries_per_block;
    const int phys_block =
        block_table[(int64_t)b * max_logical_blocks + logical_block];

    // Flat row index into the [num_blocks, epb, ...] pools. Kept in 64 bits:
    // row * fp8_dim can exceed INT32_MAX on large pools.
    const int64_t row = (int64_t)phys_block * entries_per_block + slot_in_block;

    const int fp8_dim = head_dim - rope_dim;
    const int tid = threadIdx.x;
    const int n_threads = blockDim.x;
    const int n_warps = n_threads / 32;

    const __nv_bfloat16* src = entry + (int64_t)b * head_dim;

    // ---- Step 1: Compute amax over non-RoPE half ----
    float local_amax = 0.0f;
    for (int i = tid; i < fp8_dim; i += n_threads) {
        local_amax = fmaxf(local_amax, fabsf(__bfloat162float(src[i])));
    }
    // Only thread 0 receives the block-wide value.
    const float block_amax = block_reduce_amax(local_amax, n_warps);

    // ---- Step 2: Write the per-entry scale ----
    // The value stored in inv_scale is the dequantization multiplier
    // amax / 448 (448 = E4M3 max finite), floored at 1e-12 for all-zero rows.
    __shared__ float s_inv_scale;
    if (tid == 0) {
        const float scale = (block_amax > 1e-12f) ? (block_amax / 448.0f) : 1e-12f;
        s_inv_scale = scale;
        inv_scale[row] = scale;
    }
    __syncthreads();

    // ---- Step 3: Quantize and write FP8 half ----
    const float inv_s = s_inv_scale;
    uint8_t* dst_fp8 = entries_fp8 + row * fp8_dim;
    for (int i = tid; i < fp8_dim; i += n_threads) {
        float quantized = __bfloat162float(src[i]) / inv_s;
        quantized = fmaxf(-448.0f, fminf(448.0f, quantized));
        const __nv_fp8_e4m3 fp8_val = __nv_fp8_e4m3(quantized);
        dst_fp8[i] = fp8_val.__x;
    }

    // ---- Step 4: Write BF16 RoPE half ----
    __nv_bfloat16* dst_rope = entries_rope + row * rope_dim;
    for (int i = tid; i < rope_dim; i += n_threads) {
        dst_rope[i] = src[fp8_dim + i];
    }

    // ---- Step 5: FP4 quantize and write indexer key ----
    // 16 elements per group, one FP8 E4M3 scale per group; one thread handles
    // one whole group.
    const int n_groups = indexer_head_dim / 16;
    const int n_bytes = indexer_head_dim / 2;  // 2 FP4 per byte
    const int n_scales = n_groups;

    const __nv_bfloat16* key_src = indexer_key + (int64_t)b * indexer_head_dim;
    uint8_t* dst_fp4 = indexer_keys_fp4 + row * n_bytes;
    uint8_t* dst_scale = indexer_scale + row * n_scales;

    for (int g = tid; g < n_groups; g += n_threads) {
        // Each group writes 8 packed bytes and 1 scale byte, directly into
        // the pool.
        quantize_fp4_group(key_src + g * 16, dst_fp4 + g * 8, dst_scale + g);
    }
}
|
||||
|
||||
// ===========================================================================
|
||||
// HCA flush write kernel (no indexer)
|
||||
// ===========================================================================
|
||||
|
||||
// Quantize one compressed HCA entry per request and scatter it into the paged
// pool. Identical to the CSA variant minus the FP4 indexer key.
//
// Launch: one block per request (blockIdx.x = batch index b). Blocks whose
// valid_mask[b] is false return immediately without writing anything.
__global__ void flush_write_hca_kernel(
    const __nv_bfloat16* __restrict__ entry,
    const bool*          __restrict__ valid_mask,
    const int32_t*       __restrict__ request_slots,
    const int32_t*       __restrict__ positions,
    const int32_t*       __restrict__ block_table,
    uint8_t*       __restrict__ entries_fp8,
    __nv_bfloat16* __restrict__ entries_rope,
    float*         __restrict__ inv_scale,
    int entries_per_block, int m, int rope_dim,
    int head_dim, int max_logical_blocks
) {
    const int b = blockIdx.x;
    // Uniform across the block, so safe ahead of the __syncthreads() below.
    if (!valid_mask[b]) return;

    // Resolve the destination: compressed entry index -> (block, slot).
    const int pos = positions[b];
    const int entry_idx = pos / m;
    const int logical_block = entry_idx / entries_per_block;
    const int slot_in_block = entry_idx % entries_per_block;
    const int phys_block =
        block_table[(int64_t)b * max_logical_blocks + logical_block];

    // 64-bit flat row index: row * fp8_dim can exceed INT32_MAX on large pools.
    const int64_t row = (int64_t)phys_block * entries_per_block + slot_in_block;

    const int fp8_dim = head_dim - rope_dim;
    const int tid = threadIdx.x;
    const int n_threads = blockDim.x;
    const int n_warps = n_threads / 32;

    const __nv_bfloat16* src = entry + (int64_t)b * head_dim;

    // Amax reduction over the non-RoPE half (result valid in thread 0 only).
    float local_amax = 0.0f;
    for (int i = tid; i < fp8_dim; i += n_threads) {
        local_amax = fmaxf(local_amax, fabsf(__bfloat162float(src[i])));
    }
    const float block_amax = block_reduce_amax(local_amax, n_warps);

    // Per-entry dequantization scale: amax / 448 (E4M3 max finite), floored
    // at 1e-12 for all-zero rows; broadcast to the block via shared memory.
    __shared__ float s_inv_scale;
    if (tid == 0) {
        const float scale = (block_amax > 1e-12f) ? (block_amax / 448.0f) : 1e-12f;
        s_inv_scale = scale;
        inv_scale[row] = scale;
    }
    __syncthreads();

    // FP8 quantize + write
    const float inv_s = s_inv_scale;
    uint8_t* dst_fp8 = entries_fp8 + row * fp8_dim;
    for (int i = tid; i < fp8_dim; i += n_threads) {
        float quantized = __bfloat162float(src[i]) / inv_s;
        quantized = fmaxf(-448.0f, fminf(448.0f, quantized));
        const __nv_fp8_e4m3 fp8_val = __nv_fp8_e4m3(quantized);
        dst_fp8[i] = fp8_val.__x;
    }

    // BF16 RoPE half, copied unchanged
    __nv_bfloat16* dst_rope = entries_rope + row * rope_dim;
    for (int i = tid; i < rope_dim; i += n_threads) {
        dst_rope[i] = src[fp8_dim + i];
    }
}
|
||||
|
||||
// ===========================================================================
|
||||
// State rotation kernels (in-place, single-kernel launches)
|
||||
// ===========================================================================
|
||||
|
||||
// CSA: after flush, rotate a-stream -> b-stream, clear a-stream.
//
// Launch layout: blockIdx.x selects the batch entry `b`; the threads of the
// block stride over the m * head_dim elements belonging to that entry's slot.
//
// Parameters:
//   valid_mask     [B] blocks whose flag is false return without touching state.
//   request_slots  [B] index of the state slot owned by batch entry b.
//   tail_ka/za     a-stream state; copied into the b-stream, then zeroed.
//   tail_kb/zb     b-stream state; overwritten with the a-stream values.
//   tail_len       per-slot tail length; reset to 0 for the slot.
//   m, head_dim    each slot spans m * head_dim elements in every state tensor.
//   max_requests   not read by this kernel.
__global__ void csa_rotate_state_kernel(
    const bool* __restrict__ valid_mask,        // [B]
    const int32_t* __restrict__ request_slots,  // [B]
    // State cache tensors — mutated in place
    __nv_bfloat16* __restrict__ tail_ka,        // [max_req, m, head_dim]
    __nv_bfloat16* __restrict__ tail_za,
    __nv_bfloat16* __restrict__ tail_kb,
    __nv_bfloat16* __restrict__ tail_zb,
    int32_t* __restrict__ tail_len,             // [max_req]
    int m, int head_dim, int max_requests
) {
    const int b = blockIdx.x;
    if (!valid_mask[b]) return;

    const int slot = request_slots[b];

    // Offsets are computed in 64-bit: slot * m * head_dim can exceed INT_MAX
    // for large state caches, and signed 32-bit overflow would silently
    // address the wrong memory.
    const int64_t per_slot = static_cast<int64_t>(m) * head_dim;
    const int64_t base = static_cast<int64_t>(slot) * per_slot;

    if (threadIdx.x == 0) {
        tail_len[slot] = 0;
    }

    // Each thread owns a disjoint set of indices, so the rotate
    // (kb <- ka, zb <- za) and the clear of the a-stream can be done in one
    // pass without any inter-thread synchronisation.
    const __nv_bfloat16 zero = __float2bfloat16(0.0f);
    for (int64_t i = threadIdx.x; i < per_slot; i += blockDim.x) {
        const int64_t idx = base + i;
        tail_kb[idx] = tail_ka[idx];
        tail_zb[idx] = tail_za[idx];
        tail_ka[idx] = zero;
        tail_za[idx] = zero;
    }
}
|
||||
|
||||
// HCA: after flush, just clear a-stream and reset tail_len.
//
// Launch layout: blockIdx.x selects the batch entry `b`; the threads of the
// block stride over the m * head_dim elements belonging to that entry's slot.
//
// Parameters:
//   valid_mask     blocks whose flag is false return without touching state.
//   request_slots  index of the state slot owned by batch entry b.
//   tail_ka/za     a-stream state; zeroed for the slot.
//   tail_len       per-slot tail length; reset to 0 for the slot.
//   m, head_dim    each slot spans m * head_dim elements in both state tensors.
//   max_requests   not read by this kernel.
__global__ void hca_reset_state_kernel(
    const bool* __restrict__ valid_mask,
    const int32_t* __restrict__ request_slots,
    __nv_bfloat16* __restrict__ tail_ka,
    __nv_bfloat16* __restrict__ tail_za,
    int32_t* __restrict__ tail_len,
    int m, int head_dim, int max_requests
) {
    const int b = blockIdx.x;
    if (!valid_mask[b]) return;

    const int slot = request_slots[b];

    // Offsets are computed in 64-bit: slot * m * head_dim can exceed INT_MAX
    // for large state caches, and signed 32-bit overflow would silently
    // address the wrong memory.
    const int64_t per_slot = static_cast<int64_t>(m) * head_dim;
    const int64_t base = static_cast<int64_t>(slot) * per_slot;

    if (threadIdx.x == 0) {
        tail_len[slot] = 0;
    }

    const __nv_bfloat16 zero = __float2bfloat16(0.0f);
    for (int64_t i = threadIdx.x; i < per_slot; i += blockDim.x) {
        tail_ka[base + i] = zero;
        tail_za[base + i] = zero;
    }
}
|
||||
|
||||
|
||||
// ===========================================================================
|
||||
// PyTorch bindings
|
||||
// ===========================================================================
|
||||
|
||||
// Host launcher for flush_write_csa_kernel.
//
// Launches one 128-thread block per row of `entry` (B = entry.size(0)) and
// forwards every tensor as a raw device pointer. `entry` and `indexer_key`
// and `entries_rope` must be BFloat16, `valid_mask` bool, `request_slots`,
// `positions` and `block_table` int32, `entries_fp8`, `indexer_keys_fp4` and
// `indexer_scale` uint8, and `inv_scale` float; data_ptr<T>() throws on any
// other dtype. `block_table` must have at least two dimensions; its second
// extent is passed to the kernel as max_logical_blocks.
//
// Raises (via TORCH_CHECK / C10_CUDA_CHECK) on invalid arguments or on a
// launch error reported by cudaGetLastError().
//
// NOTE(review): the launch has no stream argument, so it goes to the default
// stream rather than PyTorch's current stream — confirm this is what is
// wanted when the call is made under CUDA graph capture.
void flush_write_csa_cuda(
    torch::Tensor entry,
    torch::Tensor indexer_key,
    torch::Tensor valid_mask,
    torch::Tensor request_slots,
    torch::Tensor positions,
    torch::Tensor block_table,
    torch::Tensor entries_fp8,
    torch::Tensor entries_rope,
    torch::Tensor inv_scale,
    torch::Tensor indexer_keys_fp4,
    torch::Tensor indexer_scale,
    int64_t entries_per_block, int64_t m, int64_t rope_dim,
    int64_t head_dim, int64_t indexer_head_dim
) {
    const int B = static_cast<int>(entry.size(0));

    // A grid with zero blocks is an invalid launch configuration, so an empty
    // batch would otherwise surface as a CUDA error instead of a no-op.
    if (B == 0) return;

    // Presumably used as divisors when locating the destination entry, as in
    // the HCA flush kernel; device-side integer division by zero does not
    // trap, so reject it here.
    TORCH_CHECK(entries_per_block > 0 && m > 0,
                "flush_write_csa: entries_per_block and m must be positive, got ",
                entries_per_block, " and ", m);
    // The kernel runs one block per row of `entry` and gates on valid_mask.
    TORCH_CHECK(valid_mask.numel() >= B,
                "flush_write_csa: valid_mask has ", valid_mask.numel(),
                " elements but entry has ", B, " rows");

    const int max_logical_blocks = static_cast<int>(block_table.size(1));
    constexpr int threads = 128;

    flush_write_csa_kernel<<<B, threads>>>(
        reinterpret_cast<const __nv_bfloat16*>(entry.data_ptr<at::BFloat16>()),
        reinterpret_cast<const __nv_bfloat16*>(indexer_key.data_ptr<at::BFloat16>()),
        valid_mask.data_ptr<bool>(),
        request_slots.data_ptr<int32_t>(),
        positions.data_ptr<int32_t>(),
        block_table.data_ptr<int32_t>(),
        entries_fp8.data_ptr<uint8_t>(),
        reinterpret_cast<__nv_bfloat16*>(entries_rope.data_ptr<at::BFloat16>()),
        inv_scale.data_ptr<float>(),
        indexer_keys_fp4.data_ptr<uint8_t>(),
        indexer_scale.data_ptr<uint8_t>(),
        (int)entries_per_block, (int)m, (int)rope_dim,
        (int)head_dim, (int)indexer_head_dim, max_logical_blocks
    );
    C10_CUDA_CHECK(cudaGetLastError());
}
|
||||
|
||||
// Host launcher for flush_write_hca_kernel.
//
// Launches one 128-thread block per row of `entry` (B = entry.size(0)) and
// forwards every tensor as a raw device pointer. `entry` and `entries_rope`
// must be BFloat16, `valid_mask` bool, `request_slots`, `positions` and
// `block_table` int32, `entries_fp8` uint8 and `inv_scale` float;
// data_ptr<T>() throws on any other dtype. `block_table` must have at least
// two dimensions; its second extent is passed to the kernel as
// max_logical_blocks.
//
// Raises (via TORCH_CHECK / C10_CUDA_CHECK) on invalid arguments or on a
// launch error reported by cudaGetLastError().
//
// NOTE(review): the launch has no stream argument, so it goes to the default
// stream rather than PyTorch's current stream — confirm this is what is
// wanted when the call is made under CUDA graph capture.
void flush_write_hca_cuda(
    torch::Tensor entry,
    torch::Tensor valid_mask,
    torch::Tensor request_slots,
    torch::Tensor positions,
    torch::Tensor block_table,
    torch::Tensor entries_fp8,
    torch::Tensor entries_rope,
    torch::Tensor inv_scale,
    int64_t entries_per_block, int64_t m, int64_t rope_dim,
    int64_t head_dim
) {
    const int B = static_cast<int>(entry.size(0));

    // A grid with zero blocks is an invalid launch configuration, so an empty
    // batch would otherwise surface as a CUDA error instead of a no-op.
    if (B == 0) return;

    // The kernel computes pos / m and entry_idx / entries_per_block; integer
    // division by zero on the device does not trap, so reject it here.
    TORCH_CHECK(entries_per_block > 0 && m > 0,
                "flush_write_hca: entries_per_block and m must be positive, got ",
                entries_per_block, " and ", m);
    // The kernel reads valid_mask[b], positions[b] and row b of block_table
    // for every block index b in [0, B).
    TORCH_CHECK(valid_mask.numel() >= B,
                "flush_write_hca: valid_mask has ", valid_mask.numel(),
                " elements but entry has ", B, " rows");
    TORCH_CHECK(positions.numel() >= B,
                "flush_write_hca: positions has ", positions.numel(),
                " elements but entry has ", B, " rows");
    TORCH_CHECK(block_table.size(0) >= B,
                "flush_write_hca: block_table has ", block_table.size(0),
                " rows but entry has ", B, " rows");

    const int max_logical_blocks = static_cast<int>(block_table.size(1));
    constexpr int threads = 128;

    flush_write_hca_kernel<<<B, threads>>>(
        reinterpret_cast<const __nv_bfloat16*>(entry.data_ptr<at::BFloat16>()),
        valid_mask.data_ptr<bool>(),
        request_slots.data_ptr<int32_t>(),
        positions.data_ptr<int32_t>(),
        block_table.data_ptr<int32_t>(),
        entries_fp8.data_ptr<uint8_t>(),
        reinterpret_cast<__nv_bfloat16*>(entries_rope.data_ptr<at::BFloat16>()),
        inv_scale.data_ptr<float>(),
        (int)entries_per_block, (int)m, (int)rope_dim,
        (int)head_dim, max_logical_blocks
    );
    C10_CUDA_CHECK(cudaGetLastError());
}
|
||||
|
||||
// Host launcher for csa_rotate_state_kernel.
//
// Launches one 128-thread block per element of `valid_mask`
// (B = valid_mask.size(0)). The four tail_* state tensors must be BFloat16,
// `valid_mask` bool, and `request_slots` / `tail_len` int32; data_ptr<T>()
// throws on any other dtype. The state tensors are mutated in place.
//
// Raises (via TORCH_CHECK / C10_CUDA_CHECK) on invalid arguments or on a
// launch error reported by cudaGetLastError().
//
// NOTE(review): the launch has no stream argument, so it goes to the default
// stream rather than PyTorch's current stream — confirm this is what is
// wanted when the call is made under CUDA graph capture.
void csa_rotate_state_cuda(
    torch::Tensor valid_mask,
    torch::Tensor request_slots,
    torch::Tensor tail_ka,
    torch::Tensor tail_za,
    torch::Tensor tail_kb,
    torch::Tensor tail_zb,
    torch::Tensor tail_len,
    int64_t m, int64_t head_dim
) {
    const int B = static_cast<int>(valid_mask.size(0));

    // A grid with zero blocks is an invalid launch configuration, so an empty
    // batch would otherwise surface as a CUDA error instead of a no-op.
    if (B == 0) return;

    // Every block with a set mask bit reads request_slots[blockIdx.x].
    TORCH_CHECK(request_slots.numel() >= B,
                "csa_rotate_state: request_slots has ", request_slots.numel(),
                " elements but valid_mask has ", B);

    constexpr int threads = 128;

    csa_rotate_state_kernel<<<B, threads>>>(
        valid_mask.data_ptr<bool>(),
        request_slots.data_ptr<int32_t>(),
        reinterpret_cast<__nv_bfloat16*>(tail_ka.data_ptr<at::BFloat16>()),
        reinterpret_cast<__nv_bfloat16*>(tail_za.data_ptr<at::BFloat16>()),
        reinterpret_cast<__nv_bfloat16*>(tail_kb.data_ptr<at::BFloat16>()),
        reinterpret_cast<__nv_bfloat16*>(tail_zb.data_ptr<at::BFloat16>()),
        tail_len.data_ptr<int32_t>(),
        (int)m, (int)head_dim, 0  // max_requests unused in kernel
    );
    C10_CUDA_CHECK(cudaGetLastError());
}
|
||||
|
||||
// Host launcher for hca_reset_state_kernel.
//
// Launches one 128-thread block per element of `valid_mask`
// (B = valid_mask.size(0)). `tail_ka` / `tail_za` must be BFloat16,
// `valid_mask` bool, and `request_slots` / `tail_len` int32; data_ptr<T>()
// throws on any other dtype. The state tensors are mutated in place.
//
// Raises (via TORCH_CHECK / C10_CUDA_CHECK) on invalid arguments or on a
// launch error reported by cudaGetLastError().
//
// NOTE(review): the launch has no stream argument, so it goes to the default
// stream rather than PyTorch's current stream — confirm this is what is
// wanted when the call is made under CUDA graph capture.
void hca_reset_state_cuda(
    torch::Tensor valid_mask,
    torch::Tensor request_slots,
    torch::Tensor tail_ka,
    torch::Tensor tail_za,
    torch::Tensor tail_len,
    int64_t m, int64_t head_dim
) {
    const int B = static_cast<int>(valid_mask.size(0));

    // A grid with zero blocks is an invalid launch configuration, so an empty
    // batch would otherwise surface as a CUDA error instead of a no-op.
    if (B == 0) return;

    // Every block with a set mask bit reads request_slots[blockIdx.x].
    TORCH_CHECK(request_slots.numel() >= B,
                "hca_reset_state: request_slots has ", request_slots.numel(),
                " elements but valid_mask has ", B);

    constexpr int threads = 128;

    hca_reset_state_kernel<<<B, threads>>>(
        valid_mask.data_ptr<bool>(),
        request_slots.data_ptr<int32_t>(),
        reinterpret_cast<__nv_bfloat16*>(tail_ka.data_ptr<at::BFloat16>()),
        reinterpret_cast<__nv_bfloat16*>(tail_za.data_ptr<at::BFloat16>()),
        tail_len.data_ptr<int32_t>(),
        (int)m, (int)head_dim, 0  // max_requests is not read by the kernel
    );
    C10_CUDA_CHECK(cudaGetLastError());
}
|
||||
|
||||
// Python bindings: register the four host launchers on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    // Flush-time writes of a compressed entry.
    mod.def("flush_write_csa",
            &flush_write_csa_cuda,
            "CSA flush write kernel");
    mod.def("flush_write_hca",
            &flush_write_hca_cuda,
            "HCA flush write kernel");

    // Post-flush maintenance of the per-request tail state.
    mod.def("csa_rotate_state",
            &csa_rotate_state_cuda,
            "CSA state rotation kernel");
    mod.def("hca_reset_state",
            &hca_reset_state_cuda,
            "HCA state reset kernel");
}
|
||||
Reference in New Issue
Block a user