E1: Wire LayerCacheHandle gather methods + CUDA gather kernels

- gather_compressed_kv: CSA top-k gather via existing gather_kv.cu
- gather_all_compressed_kv: HCA dense gather via new gather_all_compressed_kernel
- gather_swa_kv: SWA ring buffer gather via new gather_swa_kernel
- Added gather_swa.cu with both SWA + all-compressed gather kernels
- Added gather.py Python wrapper (torch.utils.cpp_extension JIT)
- Updated handle.py: added schema field, num_query_heads/head_dim properties
- Updated manager.py: passes schema + num_query_heads to handle

All gather kernels: FP8→BF16 dequant + BF16 RoPE concat in single launch.
Output: dense BF16 tensors ready for FMHA consumption.
This commit is contained in:
2026-05-30 21:09:21 +00:00
parent 4b9eed02e1
commit faf92b30ad
4 changed files with 441 additions and 2 deletions

146
dsv4/cache/handle.py vendored
View File

@@ -13,6 +13,7 @@ import torch
if TYPE_CHECKING:
from dsv4.cache.paged_cache import PagedKVPool
from dsv4.cache.state_cache import StateCachePool
from dsv4.cache.schema import LayerCacheSchema
@dataclass
@@ -25,6 +26,7 @@ class LayerCacheHandle:
# Pool references (shared across handles — never mutated).
paged: Optional["PagedKVPool"]
state: "StateCachePool"
schema: "LayerCacheSchema"
# Per-call indices.
request_slots: torch.Tensor # [batch] int32 — state cache slot per request
@@ -37,9 +39,30 @@ class LayerCacheHandle:
# Number of valid blocks per request (excludes padding).
block_lens: Optional[torch.Tensor]
# ------------------------------------------------------------------
# Properties called by AttentionSubBlock
# ------------------------------------------------------------------
@property
def num_query_heads(self) -> int:
"""Number of query heads (from schema)."""
# The schema doesn't store n_q directly — derive from the config.
# For now, store on the handle at construction.
return self._num_query_heads
@num_query_heads.setter
def num_query_heads(self, value: int):
self._num_query_heads = value
@property
def head_dim(self) -> int:
"""Head dimension (from schema)."""
return self.schema.entry_head_dim
# ------------------------------------------------------------------
# Methods called by AttentionSubBlock
# ------------------------------------------------------------------
def write_swa(
self,
raw_kv: torch.Tensor, # (T, head_dim) BF16
@@ -59,7 +82,7 @@ class LayerCacheHandle:
swa_inv=self.state.swa_inv,
swa_pos=self.state.swa_pos,
swa_head=self.state.swa_head,
rope_dim=self.state.schema.rope_dim,
rope_dim=self.schema.rope_dim,
)
def flush_compression(
@@ -78,6 +101,122 @@ class LayerCacheHandle:
"""
raise NotImplementedError("see kernels/cache/flush_compression.py")
def gather_compressed_kv(
self,
selected_indices: torch.Tensor, # (T, top_k) int64 — from indexer
) -> tuple[torch.Tensor, torch.Tensor]:
"""CSA: gather top-k compressed KV entries into dense BF16 tensors.
Returns:
(k_compressed, v_compressed) each of shape (1, n_comp, head_dim) BF16.
The leading dim=1 is for the single KV head (MQA in DSV4).
"""
assert self.paged is not None, "CSA gather requires paged pool"
from dsv4.kernels.cache.gather import gather_compressed_kv
hd = self.head_dim
rd = self.schema.rope_dim
epb = self.schema.entries_per_block
# selected_indices is int64, gather kernel needs int32
indices_i32 = selected_indices.to(torch.int32)
# block_table for CSA: [batch, max_logical_blocks]
# For per-request gather, use the first request's block_table
# (decode: batch=1, so this is trivial)
if self.block_table.dim() == 1:
bt = self.block_table.unsqueeze(0)
else:
bt = self.block_table
k_out = gather_compressed_kv(
entries_fp8=self.paged.entries_fp8,
entries_rope=self.paged.entries_rope,
inv_scale=self.paged.inv_scale,
topk_indices=indices_i32,
block_table=bt,
entries_per_block=epb,
head_dim=hd,
rope_dim=rd,
)
# k_out: (T, top_k, hd) — for FMHA we need (1, n_comp, hd)
# At decode T=1: squeeze to (top_k, hd) then unsqueeze for KV head dim
n_comp = k_out.shape[1]
k_compressed = k_out.squeeze(0).unsqueeze(0) # (1, n_comp, hd)
# V shares the same storage but is transposed — DSV4 uses K=V for
# the compressed KV (same entries, different projection weights applied
# before compression). For now, return the same gathered tensor.
# TODO: verify if K and V are stored separately or shared.
v_compressed = k_compressed.clone()
return k_compressed, v_compressed
def gather_all_compressed_kv(self) -> tuple[torch.Tensor, torch.Tensor]:
"""HCA: gather ALL compressed KV entries into dense BF16 tensors.
No indexer — dense attention over the short compressed sequence.
Returns:
(k_compressed, v_compressed) each of shape (1, n_comp, head_dim) BF16.
"""
assert self.paged is not None, "HCA gather requires paged pool"
from dsv4.kernels.cache.gather import gather_all_compressed_kv
hd = self.head_dim
rd = self.schema.rope_dim
epb = self.schema.entries_per_block
if self.block_table.dim() == 1:
bt = self.block_table.unsqueeze(0)
bl = self.block_lens.unsqueeze(0) if self.block_lens is not None else None
else:
bt = self.block_table
bl = self.block_lens
if bl is None:
# Default: all blocks valid
bl = torch.full((bt.shape[0],), bt.shape[1], dtype=torch.int32, device=bt.device)
k_out = gather_all_compressed_kv(
entries_fp8=self.paged.entries_fp8,
entries_rope=self.paged.entries_rope,
inv_scale=self.paged.inv_scale,
block_table=bt,
block_lens=bl,
entries_per_block=epb,
head_dim=hd,
rope_dim=rd,
)
# k_out: (batch, total_entries, hd) — for FMHA we need (1, n_comp, hd)
n_comp = k_out.shape[1]
k_compressed = k_out.squeeze(0).unsqueeze(0) # (1, n_comp, hd)
v_compressed = k_compressed.clone()
return k_compressed, v_compressed
def gather_swa_kv(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Gather SWA window entries into dense BF16 tensors.
Returns:
(k_swa, v_swa) each of shape (1, swa_len, head_dim) BF16.
"""
from dsv4.kernels.cache.gather import gather_swa_kv
hd = self.head_dim
rd = self.schema.rope_dim
k_out = gather_swa_kv(
swa_fp8=self.state.swa_fp8,
swa_rope=self.state.swa_rope,
swa_inv=self.state.swa_inv,
swa_pos=self.state.swa_pos,
request_slots=self.request_slots,
head_dim=hd,
rope_dim=rd,
)
# k_out: (batch, n_win, hd) — for FMHA we need (1, swa_len, hd)
k_swa = k_out.squeeze(0).unsqueeze(0) # (1, swa_len, hd)
v_swa = k_swa.clone()
return k_swa, v_swa
def read_swa_view(self) -> "SWAView":
"""Return a typed view of the SWA window for this batch."""
return SWAView(
@@ -111,6 +250,11 @@ class LayerCacheHandle:
block_lens=self.block_lens,
)
def __post_init__(self):
# Initialize _num_query_heads (must be set by the manager at construction)
if not hasattr(self, '_num_query_heads'):
self._num_query_heads = 0
# Typed views — simple dataclasses, no logic. The FMHA / indexer / SWA
# kernels accept these to keep their signatures clean.

View File

@@ -173,15 +173,18 @@ class KVCacheManager:
block_table = None
block_lens = None
return LayerCacheHandle(
handle = LayerCacheHandle(
paged=paged,
state=state,
schema=self.schemas[layer_idx],
request_slots=request_slots,
positions=positions,
request_ids=request_ids,
block_table=block_table,
block_lens=block_lens,
)
handle.num_query_heads = self.config.num_query_heads
return handle
# ------------------------------------------------------------------
# Diagnostics

115
dsv4/kernels/cache/gather.py vendored Normal file
View File

@@ -0,0 +1,115 @@
"""Python wrappers for cache gather kernels.
Provides the bridge between LayerCacheHandle and the raw CUDA gather ops.
Loaded via torch.utils.cpp_extension (JIT, sm_100a).
Three gather paths:
1. gather_compressed_kv: CSA — top-k selected compressed entries
2. gather_all_compressed_kv: HCA — all compressed entries (no indexer)
3. gather_swa_kv: SWA window entries from state cache
All outputs are dense BF16 tensors ready for FMHA consumption.
"""
import os
import torch
from torch.utils.cpp_extension import load
_kernel_module = None
_kernel_dir = os.path.join(os.path.dirname(__file__), "cuda")
def _get_kernel_module():
global _kernel_module
if _kernel_module is not None:
return _kernel_module
_kernel_module = load(
name="cache_gather",
sources=[
os.path.join(_kernel_dir, "gather_kv.cu"),
os.path.join(_kernel_dir, "gather_swa.cu"),
],
extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
verbose=False,
)
return _kernel_module
def gather_compressed_kv(
entries_fp8: torch.Tensor, # [num_blocks, epb, fp8_dim] uint8
entries_rope: torch.Tensor, # [num_blocks, epb, rope_dim] BF16
inv_scale: torch.Tensor, # [num_blocks, epb] FP32
topk_indices: torch.Tensor, # [T, top_k] int32
block_table: torch.Tensor, # [batch, max_logical_blocks] int32
entries_per_block: int,
head_dim: int,
rope_dim: int,
) -> torch.Tensor:
"""Gather top-k compressed KV entries into a dense BF16 tile.
Returns: (T, top_k, head_dim) BF16
"""
T = topk_indices.size(0)
top_k = topk_indices.size(1)
output = torch.zeros(T, top_k, head_dim, dtype=torch.bfloat16, device=entries_fp8.device)
mod = _get_kernel_module()
mod.gather_kv(
entries_fp8, entries_rope, inv_scale,
topk_indices, block_table, output,
entries_per_block, rope_dim,
)
return output
def gather_all_compressed_kv(
entries_fp8: torch.Tensor, # [num_blocks, epb, fp8_dim] uint8
entries_rope: torch.Tensor, # [num_blocks, epb, rope_dim] BF16
inv_scale: torch.Tensor, # [num_blocks, epb] FP32
block_table: torch.Tensor, # [batch, max_logical_blocks] int32
block_lens: torch.Tensor, # [batch] int32
entries_per_block: int,
head_dim: int,
rope_dim: int,
) -> torch.Tensor:
"""Gather ALL compressed KV entries for HCA (dense attention).
Returns: (batch, total_entries, head_dim) BF16 where
total_entries = block_lens.sum() * entries_per_block (padded to max)
"""
batch = block_table.size(0)
max_blocks = block_lens.max().item()
total_entries = max_blocks * entries_per_block
output = torch.zeros(batch, total_entries, head_dim, dtype=torch.bfloat16, device=entries_fp8.device)
mod = _get_kernel_module()
mod.gather_all_compressed(
entries_fp8, entries_rope, inv_scale,
block_table, block_lens, output,
entries_per_block, rope_dim,
)
return output
def gather_swa_kv(
swa_fp8: torch.Tensor, # [max_req, n_win, fp8_dim] uint8
swa_rope: torch.Tensor, # [max_req, n_win, rope_dim] BF16
swa_inv: torch.Tensor, # [max_req, n_win] FP32
swa_pos: torch.Tensor, # [max_req, n_win] int32
request_slots: torch.Tensor, # [batch] int32
head_dim: int,
rope_dim: int,
) -> torch.Tensor:
"""Gather SWA window entries into a dense BF16 tile.
Returns: (batch, n_win, head_dim) BF16
"""
batch = request_slots.size(0)
n_win = swa_fp8.size(1)
output = torch.zeros(batch, n_win, head_dim, dtype=torch.bfloat16, device=swa_fp8.device)
mod = _get_kernel_module()
mod.gather_swa(
swa_fp8, swa_rope, swa_inv, swa_pos, request_slots,
output, head_dim, rope_dim,
)
return output

View File

@@ -0,0 +1,177 @@
// gather_swa.cu — Gather SWA window entries into a dense BF16 tile.
//
// Reads from the state cache's SWA ring buffer (FP8 + BF16 split layout).
// One CTA per request. Each CTA iterates over the n_win positions in the
// ring buffer, dequantizes FP8 → BF16, concatenates the RoPE half, and
// writes to the dense output tensor.
//
// Output shape: [batch, n_win, head_dim] BF16 — consumed by FMHA.
// Positions with swa_pos == -1 (unused slots) are zero-filled.
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
__global__ void gather_swa_kernel(
// Inputs
const uint8_t* __restrict__ swa_fp8, // [max_req, n_win, fp8_dim]
const __nv_bfloat16* __restrict__ swa_rope, // [max_req, n_win, rope_dim]
const float* __restrict__ swa_inv, // [max_req, n_win]
const int32_t* __restrict__ swa_pos, // [max_req, n_win]
const int32_t* __restrict__ request_slots, // [batch] — state cache slot per request
// Output
__nv_bfloat16* __restrict__ output, // [batch, n_win, head_dim] BF16
// Geometry
int batch, int n_win, int head_dim, int rope_dim, int max_requests
) {
int fp8_dim = head_dim - rope_dim;
int b = blockIdx.x;
if (b >= batch) return;
int slot = request_slots[b];
for (int w = 0; w < n_win; w++) {
int pos = swa_pos[slot * n_win + w];
int out_row = b * n_win * head_dim + w * head_dim;
if (pos < 0) {
// Unused slot — zero fill
for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
output[out_row + d] = __float2bfloat16(0.0f);
}
continue;
}
float s = swa_inv[slot * n_win + w];
// Dequantize FP8 half
for (int d = threadIdx.x; d < fp8_dim; d += blockDim.x) {
uint8_t raw = swa_fp8[(slot * n_win + w) * fp8_dim + d];
__nv_fp8_e4m3 fp8_val;
fp8_val.__x = raw;
float dequant = (float)fp8_val * s;
output[out_row + d] = __float2bfloat16(dequant);
}
// Copy BF16 RoPE half
for (int d = threadIdx.x; d < rope_dim; d += blockDim.x) {
output[out_row + fp8_dim + d] = swa_rope[(slot * n_win + w) * rope_dim + d];
}
}
}
void gather_swa_cuda(
torch::Tensor swa_fp8,
torch::Tensor swa_rope,
torch::Tensor swa_inv,
torch::Tensor swa_pos,
torch::Tensor request_slots,
torch::Tensor output,
int64_t head_dim, int64_t rope_dim
) {
int batch = request_slots.size(0);
int n_win = swa_fp8.size(1);
int max_requests = swa_fp8.size(0);
int threads = 128;
gather_swa_kernel<<<batch, threads>>>(
swa_fp8.data_ptr<uint8_t>(),
reinterpret_cast<const __nv_bfloat16*>(swa_rope.data_ptr<at::BFloat16>()),
swa_inv.data_ptr<float>(),
swa_pos.data_ptr<int32_t>(),
request_slots.data_ptr<int32_t>(),
reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>()),
batch, n_win, (int)head_dim, (int)rope_dim, max_requests
);
C10_CUDA_CHECK(cudaGetLastError());
}
// gather_all_compressed_kv_kernel — Gather ALL compressed entries for HCA
// (no top-k, dense attention over the short compressed sequence).
//
// One CTA per request. Iterates over all valid blocks in the block_table,
// dequantizes all entries, writes to a dense output.
// Output: [batch, total_entries, head_dim] BF16
__global__ void gather_all_compressed_kernel(
const uint8_t* __restrict__ entries_fp8,
const __nv_bfloat16* __restrict__ entries_rope,
const float* __restrict__ inv_scale,
const int32_t* __restrict__ block_table, // [batch, max_logical_blocks]
const int32_t* __restrict__ block_lens, // [batch] — valid blocks per request
__nv_bfloat16* __restrict__ output, // [batch, total_entries, head_dim]
int batch, int entries_per_block, int head_dim,
int rope_dim, int max_logical_blocks
) {
int fp8_dim = head_dim - rope_dim;
int b = blockIdx.x;
if (b >= batch) return;
int n_blocks = block_lens[b];
int out_idx = 0;
for (int lb = 0; lb < n_blocks; lb++) {
int phys_block = block_table[b * max_logical_blocks + lb];
if (phys_block < 0) continue;
for (int epb = 0; epb < entries_per_block; epb++) {
int block_entry = phys_block * entries_per_block + epb;
float s = inv_scale[block_entry];
int out_row = (b * n_blocks * entries_per_block + out_idx) * head_dim;
// Dequantize FP8 half
for (int d = threadIdx.x; d < fp8_dim; d += blockDim.x) {
uint8_t raw = entries_fp8[block_entry * fp8_dim + d];
__nv_fp8_e4m3 fp8_val;
fp8_val.__x = raw;
float dequant = (float)fp8_val * s;
output[out_row + d] = __float2bfloat16(dequant);
}
// Copy BF16 RoPE half
for (int d = threadIdx.x; d < rope_dim; d += blockDim.x) {
output[out_row + fp8_dim + d] = entries_rope[block_entry * rope_dim + d];
}
out_idx++;
}
}
}
void gather_all_compressed_cuda(
torch::Tensor entries_fp8,
torch::Tensor entries_rope,
torch::Tensor inv_scale,
torch::Tensor block_table,
torch::Tensor block_lens,
torch::Tensor output,
int64_t entries_per_block, int64_t rope_dim
) {
int batch = block_table.size(0);
int head_dim = entries_fp8.size(2) + entries_rope.size(2);
int max_logical_blocks = block_table.size(1);
int threads = 128;
gather_all_compressed_kernel<<<batch, threads>>>(
entries_fp8.data_ptr<uint8_t>(),
reinterpret_cast<const __nv_bfloat16*>(entries_rope.data_ptr<at::BFloat16>()),
inv_scale.data_ptr<float>(),
block_table.data_ptr<int32_t>(),
block_lens.data_ptr<int32_t>(),
reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>()),
batch, (int)entries_per_block, (int)head_dim,
(int)rope_dim, max_logical_blocks
);
C10_CUDA_CHECK(cudaGetLastError());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gather_swa", &gather_swa_cuda, "Gather SWA window into dense BF16 tile");
m.def("gather_all_compressed", &gather_all_compressed_cuda, "Gather all compressed KV for HCA");
}