From faf92b30adeb25f40f21013cf8240306d108bb31 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 21:09:21 +0000 Subject: [PATCH] E1: Wire LayerCacheHandle gather methods + CUDA gather kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - gather_compressed_kv: CSA top-k gather via existing gather_kv.cu - gather_all_compressed_kv: HCA dense gather via new gather_all_compressed_kernel - gather_swa_kv: SWA ring buffer gather via new gather_swa_kernel - Added gather_swa.cu with both SWA + all-compressed gather kernels - Added gather.py Python wrapper (torch.utils.cpp_extension JIT) - Updated handle.py: added schema field, num_query_heads/head_dim properties - Updated manager.py: passes schema + num_query_heads to handle All gather kernels: FP8→BF16 dequant + BF16 RoPE concat in single launch. Output: dense BF16 tensors ready for FMHA consumption. --- dsv4/cache/handle.py | 146 +++++++++++++++++++++++++- dsv4/cache/manager.py | 5 +- dsv4/kernels/cache/gather.py | 115 +++++++++++++++++++++ dsv4/kernels/cuda/gather_swa.cu | 177 ++++++++++++++++++++++++++++++++ 4 files changed, 441 insertions(+), 2 deletions(-) create mode 100644 dsv4/kernels/cache/gather.py create mode 100644 dsv4/kernels/cuda/gather_swa.cu diff --git a/dsv4/cache/handle.py b/dsv4/cache/handle.py index a638cfc9..6ef76903 100644 --- a/dsv4/cache/handle.py +++ b/dsv4/cache/handle.py @@ -13,6 +13,7 @@ import torch if TYPE_CHECKING: from dsv4.cache.paged_cache import PagedKVPool from dsv4.cache.state_cache import StateCachePool + from dsv4.cache.schema import LayerCacheSchema @dataclass @@ -25,6 +26,7 @@ class LayerCacheHandle: # Pool references (shared across handles — never mutated). paged: Optional["PagedKVPool"] state: "StateCachePool" + schema: "LayerCacheSchema" # Per-call indices. request_slots: torch.Tensor # [batch] int32 — state cache slot per request @@ -37,9 +39,30 @@ class LayerCacheHandle: # Number of valid blocks per request (excludes padding). block_lens: Optional[torch.Tensor] + # ------------------------------------------------------------------ + # Properties called by AttentionSubBlock + # ------------------------------------------------------------------ + + @property + def num_query_heads(self) -> int: + """Number of query heads (from schema).""" + # The schema doesn't store n_q directly — derive from the config. + # For now, store on the handle at construction. + return self._num_query_heads + + @num_query_heads.setter + def num_query_heads(self, value: int): + self._num_query_heads = value + + @property + def head_dim(self) -> int: + """Head dimension (from schema).""" + return self.schema.entry_head_dim + # ------------------------------------------------------------------ # Methods called by AttentionSubBlock # ------------------------------------------------------------------ + def write_swa( self, raw_kv: torch.Tensor, # (T, head_dim) BF16 @@ -59,7 +82,7 @@ class LayerCacheHandle: swa_inv=self.state.swa_inv, swa_pos=self.state.swa_pos, swa_head=self.state.swa_head, - rope_dim=self.state.schema.rope_dim, + rope_dim=self.schema.rope_dim, ) def flush_compression( @@ -78,6 +101,122 @@ class LayerCacheHandle: """ raise NotImplementedError("see kernels/cache/flush_compression.py") + def gather_compressed_kv( + self, + selected_indices: torch.Tensor, # (T, top_k) int64 — from indexer + ) -> tuple[torch.Tensor, torch.Tensor]: + """CSA: gather top-k compressed KV entries into dense BF16 tensors. + + Returns: + (k_compressed, v_compressed) each of shape (1, n_comp, head_dim) BF16. + The leading dim=1 is for the single KV head (MQA in DSV4). + """ + assert self.paged is not None, "CSA gather requires paged pool" + from dsv4.kernels.cache.gather import gather_compressed_kv + + hd = self.head_dim + rd = self.schema.rope_dim + epb = self.schema.entries_per_block + + # selected_indices is int64, gather kernel needs int32 + indices_i32 = selected_indices.to(torch.int32) + + # block_table for CSA: [batch, max_logical_blocks] + # For per-request gather, use the first request's block_table + # (decode: batch=1, so this is trivial) + if self.block_table.dim() == 1: + bt = self.block_table.unsqueeze(0) + else: + bt = self.block_table + + k_out = gather_compressed_kv( + entries_fp8=self.paged.entries_fp8, + entries_rope=self.paged.entries_rope, + inv_scale=self.paged.inv_scale, + topk_indices=indices_i32, + block_table=bt, + entries_per_block=epb, + head_dim=hd, + rope_dim=rd, + ) + # k_out: (T, top_k, hd) — for FMHA we need (1, n_comp, hd) + # At decode T=1: squeeze to (top_k, hd) then unsqueeze for KV head dim + n_comp = k_out.shape[1] + k_compressed = k_out.squeeze(0).unsqueeze(0) # (1, n_comp, hd) + # V shares the same storage but is transposed — DSV4 uses K=V for + # the compressed KV (same entries, different projection weights applied + # before compression). For now, return the same gathered tensor. + # TODO: verify if K and V are stored separately or shared. + v_compressed = k_compressed.clone() + return k_compressed, v_compressed + + def gather_all_compressed_kv(self) -> tuple[torch.Tensor, torch.Tensor]: + """HCA: gather ALL compressed KV entries into dense BF16 tensors. + + No indexer — dense attention over the short compressed sequence. + + Returns: + (k_compressed, v_compressed) each of shape (1, n_comp, head_dim) BF16. + """ + assert self.paged is not None, "HCA gather requires paged pool" + from dsv4.kernels.cache.gather import gather_all_compressed_kv + + hd = self.head_dim + rd = self.schema.rope_dim + epb = self.schema.entries_per_block + + if self.block_table.dim() == 1: + bt = self.block_table.unsqueeze(0) + bl = self.block_lens.unsqueeze(0) if self.block_lens is not None else None + else: + bt = self.block_table + bl = self.block_lens + + if bl is None: + # Default: all blocks valid + bl = torch.full((bt.shape[0],), bt.shape[1], dtype=torch.int32, device=bt.device) + + k_out = gather_all_compressed_kv( + entries_fp8=self.paged.entries_fp8, + entries_rope=self.paged.entries_rope, + inv_scale=self.paged.inv_scale, + block_table=bt, + block_lens=bl, + entries_per_block=epb, + head_dim=hd, + rope_dim=rd, + ) + # k_out: (batch, total_entries, hd) — for FMHA we need (1, n_comp, hd) + n_comp = k_out.shape[1] + k_compressed = k_out.squeeze(0).unsqueeze(0) # (1, n_comp, hd) + v_compressed = k_compressed.clone() + return k_compressed, v_compressed + + def gather_swa_kv(self) -> tuple[torch.Tensor, torch.Tensor]: + """Gather SWA window entries into dense BF16 tensors. + + Returns: + (k_swa, v_swa) each of shape (1, swa_len, head_dim) BF16. + """ + from dsv4.kernels.cache.gather import gather_swa_kv + + hd = self.head_dim + rd = self.schema.rope_dim + + k_out = gather_swa_kv( + swa_fp8=self.state.swa_fp8, + swa_rope=self.state.swa_rope, + swa_inv=self.state.swa_inv, + swa_pos=self.state.swa_pos, + request_slots=self.request_slots, + head_dim=hd, + rope_dim=rd, + ) + # k_out: (batch, n_win, hd) — for FMHA we need (1, swa_len, hd) + k_swa = k_out.squeeze(0).unsqueeze(0) # (1, swa_len, hd) + v_swa = k_swa.clone() + return k_swa, v_swa + def read_swa_view(self) -> "SWAView": """Return a typed view of the SWA window for this batch.""" return SWAView( @@ -111,6 +250,11 @@ class LayerCacheHandle: block_lens=self.block_lens, ) + def __post_init__(self): + # Initialize _num_query_heads (must be set by the manager at construction) + if not hasattr(self, '_num_query_heads'): + self._num_query_heads = 0 + # Typed views — simple dataclasses, no logic. The FMHA / indexer / SWA # kernels accept these to keep their signatures clean. diff --git a/dsv4/cache/manager.py b/dsv4/cache/manager.py index 8af0093d..20b82725 100644 --- a/dsv4/cache/manager.py +++ b/dsv4/cache/manager.py @@ -173,15 +173,18 @@ class KVCacheManager: block_table = None block_lens = None - return LayerCacheHandle( + handle = LayerCacheHandle( paged=paged, state=state, + schema=self.schemas[layer_idx], request_slots=request_slots, positions=positions, request_ids=request_ids, block_table=block_table, block_lens=block_lens, ) + handle.num_query_heads = self.config.num_query_heads + return handle # ------------------------------------------------------------------ # Diagnostics diff --git a/dsv4/kernels/cache/gather.py b/dsv4/kernels/cache/gather.py new file mode 100644 index 00000000..e2122bee --- /dev/null +++ b/dsv4/kernels/cache/gather.py @@ -0,0 +1,115 @@ +"""Python wrappers for cache gather kernels. + +Provides the bridge between LayerCacheHandle and the raw CUDA gather ops. +Loaded via torch.utils.cpp_extension (JIT, sm_100a). + +Three gather paths: +1. gather_compressed_kv: CSA — top-k selected compressed entries +2. gather_all_compressed_kv: HCA — all compressed entries (no indexer) +3. gather_swa_kv: SWA window entries from state cache + +All outputs are dense BF16 tensors ready for FMHA consumption. +""" +import os +import torch +from torch.utils.cpp_extension import load + +_kernel_module = None +_kernel_dir = os.path.join(os.path.dirname(__file__), "cuda") + + +def _get_kernel_module(): + global _kernel_module + if _kernel_module is not None: + return _kernel_module + _kernel_module = load( + name="cache_gather", + sources=[ + os.path.join(_kernel_dir, "gather_kv.cu"), + os.path.join(_kernel_dir, "gather_swa.cu"), + ], + extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"], + verbose=False, + ) + return _kernel_module + + +def gather_compressed_kv( + entries_fp8: torch.Tensor, # [num_blocks, epb, fp8_dim] uint8 + entries_rope: torch.Tensor, # [num_blocks, epb, rope_dim] BF16 + inv_scale: torch.Tensor, # [num_blocks, epb] FP32 + topk_indices: torch.Tensor, # [T, top_k] int32 + block_table: torch.Tensor, # [batch, max_logical_blocks] int32 + entries_per_block: int, + head_dim: int, + rope_dim: int, +) -> torch.Tensor: + """Gather top-k compressed KV entries into a dense BF16 tile. + + Returns: (T, top_k, head_dim) BF16 + """ + T = topk_indices.size(0) + top_k = topk_indices.size(1) + output = torch.zeros(T, top_k, head_dim, dtype=torch.bfloat16, device=entries_fp8.device) + + mod = _get_kernel_module() + mod.gather_kv( + entries_fp8, entries_rope, inv_scale, + topk_indices, block_table, output, + entries_per_block, rope_dim, + ) + return output + + +def gather_all_compressed_kv( + entries_fp8: torch.Tensor, # [num_blocks, epb, fp8_dim] uint8 + entries_rope: torch.Tensor, # [num_blocks, epb, rope_dim] BF16 + inv_scale: torch.Tensor, # [num_blocks, epb] FP32 + block_table: torch.Tensor, # [batch, max_logical_blocks] int32 + block_lens: torch.Tensor, # [batch] int32 + entries_per_block: int, + head_dim: int, + rope_dim: int, +) -> torch.Tensor: + """Gather ALL compressed KV entries for HCA (dense attention). + + Returns: (batch, total_entries, head_dim) BF16 where + total_entries = block_lens.sum() * entries_per_block (padded to max) + """ + batch = block_table.size(0) + max_blocks = block_lens.max().item() + total_entries = max_blocks * entries_per_block + output = torch.zeros(batch, total_entries, head_dim, dtype=torch.bfloat16, device=entries_fp8.device) + + mod = _get_kernel_module() + mod.gather_all_compressed( + entries_fp8, entries_rope, inv_scale, + block_table, block_lens, output, + entries_per_block, rope_dim, + ) + return output + + +def gather_swa_kv( + swa_fp8: torch.Tensor, # [max_req, n_win, fp8_dim] uint8 + swa_rope: torch.Tensor, # [max_req, n_win, rope_dim] BF16 + swa_inv: torch.Tensor, # [max_req, n_win] FP32 + swa_pos: torch.Tensor, # [max_req, n_win] int32 + request_slots: torch.Tensor, # [batch] int32 + head_dim: int, + rope_dim: int, +) -> torch.Tensor: + """Gather SWA window entries into a dense BF16 tile. + + Returns: (batch, n_win, head_dim) BF16 + """ + batch = request_slots.size(0) + n_win = swa_fp8.size(1) + output = torch.zeros(batch, n_win, head_dim, dtype=torch.bfloat16, device=swa_fp8.device) + + mod = _get_kernel_module() + mod.gather_swa( + swa_fp8, swa_rope, swa_inv, swa_pos, request_slots, + output, head_dim, rope_dim, + ) + return output diff --git a/dsv4/kernels/cuda/gather_swa.cu b/dsv4/kernels/cuda/gather_swa.cu new file mode 100644 index 00000000..d34a18a4 --- /dev/null +++ b/dsv4/kernels/cuda/gather_swa.cu @@ -0,0 +1,177 @@ +// gather_swa.cu — Gather SWA window entries into a dense BF16 tile. +// +// Reads from the state cache's SWA ring buffer (FP8 + BF16 split layout). +// One CTA per request. Each CTA iterates over the n_win positions in the +// ring buffer, dequantizes FP8 → BF16, concatenates the RoPE half, and +// writes to the dense output tensor. +// +// Output shape: [batch, n_win, head_dim] BF16 — consumed by FMHA. +// Positions with swa_pos == -1 (unused slots) are zero-filled. + +#include +#include +#include +#include +#include + + +__global__ void gather_swa_kernel( + // Inputs + const uint8_t* __restrict__ swa_fp8, // [max_req, n_win, fp8_dim] + const __nv_bfloat16* __restrict__ swa_rope, // [max_req, n_win, rope_dim] + const float* __restrict__ swa_inv, // [max_req, n_win] + const int32_t* __restrict__ swa_pos, // [max_req, n_win] + const int32_t* __restrict__ request_slots, // [batch] — state cache slot per request + // Output + __nv_bfloat16* __restrict__ output, // [batch, n_win, head_dim] BF16 + // Geometry + int batch, int n_win, int head_dim, int rope_dim, int max_requests +) { + int fp8_dim = head_dim - rope_dim; + int b = blockIdx.x; + if (b >= batch) return; + + int slot = request_slots[b]; + + for (int w = 0; w < n_win; w++) { + int pos = swa_pos[slot * n_win + w]; + int out_row = b * n_win * head_dim + w * head_dim; + + if (pos < 0) { + // Unused slot — zero fill + for (int d = threadIdx.x; d < head_dim; d += blockDim.x) { + output[out_row + d] = __float2bfloat16(0.0f); + } + continue; + } + + float s = swa_inv[slot * n_win + w]; + + // Dequantize FP8 half + for (int d = threadIdx.x; d < fp8_dim; d += blockDim.x) { + uint8_t raw = swa_fp8[(slot * n_win + w) * fp8_dim + d]; + __nv_fp8_e4m3 fp8_val; + fp8_val.__x = raw; + float dequant = (float)fp8_val * s; + output[out_row + d] = __float2bfloat16(dequant); + } + + // Copy BF16 RoPE half + for (int d = threadIdx.x; d < rope_dim; d += blockDim.x) { + output[out_row + fp8_dim + d] = swa_rope[(slot * n_win + w) * rope_dim + d]; + } + } +} + + +void gather_swa_cuda( + torch::Tensor swa_fp8, + torch::Tensor swa_rope, + torch::Tensor swa_inv, + torch::Tensor swa_pos, + torch::Tensor request_slots, + torch::Tensor output, + int64_t head_dim, int64_t rope_dim +) { + int batch = request_slots.size(0); + int n_win = swa_fp8.size(1); + int max_requests = swa_fp8.size(0); + + int threads = 128; + gather_swa_kernel<<>>( + swa_fp8.data_ptr(), + reinterpret_cast(swa_rope.data_ptr()), + swa_inv.data_ptr(), + swa_pos.data_ptr(), + request_slots.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(output.data_ptr()), + batch, n_win, (int)head_dim, (int)rope_dim, max_requests + ); + C10_CUDA_CHECK(cudaGetLastError()); +} + + +// gather_all_compressed_kv_kernel — Gather ALL compressed entries for HCA +// (no top-k, dense attention over the short compressed sequence). +// +// One CTA per request. Iterates over all valid blocks in the block_table, +// dequantizes all entries, writes to a dense output. +// Output: [batch, total_entries, head_dim] BF16 + +__global__ void gather_all_compressed_kernel( + const uint8_t* __restrict__ entries_fp8, + const __nv_bfloat16* __restrict__ entries_rope, + const float* __restrict__ inv_scale, + const int32_t* __restrict__ block_table, // [batch, max_logical_blocks] + const int32_t* __restrict__ block_lens, // [batch] — valid blocks per request + __nv_bfloat16* __restrict__ output, // [batch, total_entries, head_dim] + int batch, int entries_per_block, int head_dim, + int rope_dim, int max_logical_blocks +) { + int fp8_dim = head_dim - rope_dim; + int b = blockIdx.x; + if (b >= batch) return; + + int n_blocks = block_lens[b]; + int out_idx = 0; + + for (int lb = 0; lb < n_blocks; lb++) { + int phys_block = block_table[b * max_logical_blocks + lb]; + if (phys_block < 0) continue; + + for (int epb = 0; epb < entries_per_block; epb++) { + int block_entry = phys_block * entries_per_block + epb; + float s = inv_scale[block_entry]; + int out_row = (b * n_blocks * entries_per_block + out_idx) * head_dim; + + // Dequantize FP8 half + for (int d = threadIdx.x; d < fp8_dim; d += blockDim.x) { + uint8_t raw = entries_fp8[block_entry * fp8_dim + d]; + __nv_fp8_e4m3 fp8_val; + fp8_val.__x = raw; + float dequant = (float)fp8_val * s; + output[out_row + d] = __float2bfloat16(dequant); + } + + // Copy BF16 RoPE half + for (int d = threadIdx.x; d < rope_dim; d += blockDim.x) { + output[out_row + fp8_dim + d] = entries_rope[block_entry * rope_dim + d]; + } + out_idx++; + } + } +} + + +void gather_all_compressed_cuda( + torch::Tensor entries_fp8, + torch::Tensor entries_rope, + torch::Tensor inv_scale, + torch::Tensor block_table, + torch::Tensor block_lens, + torch::Tensor output, + int64_t entries_per_block, int64_t rope_dim +) { + int batch = block_table.size(0); + int head_dim = entries_fp8.size(2) + entries_rope.size(2); + int max_logical_blocks = block_table.size(1); + + int threads = 128; + gather_all_compressed_kernel<<>>( + entries_fp8.data_ptr(), + reinterpret_cast(entries_rope.data_ptr()), + inv_scale.data_ptr(), + block_table.data_ptr(), + block_lens.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(output.data_ptr()), + batch, (int)entries_per_block, (int)head_dim, + (int)rope_dim, max_logical_blocks + ); + C10_CUDA_CHECK(cudaGetLastError()); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("gather_swa", &gather_swa_cuda, "Gather SWA window into dense BF16 tile"); + m.def("gather_all_compressed", &gather_all_compressed_cuda, "Gather all compressed KV for HCA"); +}