perf: C1/C2/C3 — per-layer max_comp, pre-allocated gather_buf, SWA views

C1: --max-context CLI flag (default 8192). KVCache.max_comp is computed per
    layer as ceil(max_context / compress_ratio), i.e.
    (max_context + compress_ratio - 1) // compress_ratio.
    CSA (ratio 4) at 8192 context → 2048 entries. HCA (ratio 128) at 8192 → 64 entries.
    No more hardcoded 65536 that wastes memory on HCA layers.

C2: Pre-allocated gather_buf (indexer_top_k + window_size, hd) in KVCache.
    Gather writes compressed+SWA into this buffer via slice assignment.
    Zero torch.cat allocations on the hot decode path.

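C3: get_swa no longer calls .clone(). Before the ring buffer wraps it returns
    slice views; after wrap it gathers rows with an index tensor, which in
    PyTorch produces a new tensor rather than a view. Caller copies the result
    into gather_buf, so there is no aliasing risk.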
This commit is contained in:
2026-06-02 06:18:06 +00:00
parent 617da29a5b
commit 4e6c014a7f

View File

@@ -34,6 +34,7 @@ def parse_args():
p.add_argument('--prefill-tokens', type=str, default=None,
help='Override prompt tokens as comma-separated IDs (e.g. "1,128803,313,128804")')
p.add_argument('--cuda-graph', action='store_true', help='Capture CUDA graph per layer for decode (eliminates Python dispatch overhead)')
p.add_argument('--max-context', type=int, default=8192, help='Target max context length (determines KV cache pre-allocation)')
return p.parse_args()
_args = parse_args()
@@ -436,9 +437,11 @@ class Indexer:
# KV Cache
# =====================================================================
class KVCache:
def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0', indexer_key_dim=128):
def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0',
indexer_key_dim=128, compress_ratio=4, indexer_top_k=1024):
self.hd, self.ws, self.dev = head_dim, window_size, device
self.idx_key_dim = indexer_key_dim
self.ratio = compress_ratio
self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device)
self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device)
self.swa_len, self.swa_head = 0, 0
@@ -447,6 +450,8 @@ class KVCache:
self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device)
# Indexer compressed keys: width = ihd (c_I in the paper), NOT head_dim
self.comp_idx_buf = torch.zeros(max_comp, indexer_key_dim, dtype=torch.bfloat16, device=device)
# Pre-allocated gather buffer — top_k compressed + SWA window, zero torch.cat on hot path
self.gather_buf = torch.zeros(indexer_top_k + window_size, head_dim, dtype=torch.bfloat16, device=device)
self.n_comp = 0
self._has_idx = False
@@ -484,11 +489,14 @@ class KVCache:
return self.comp_idx_buf[:self.n_comp] if self._has_idx and self.n_comp > 0 else None
def get_swa(self):
"""Return SWA KV and positions as views (no clone). Caller copies into gather_buf."""
if self.swa_len == 0:
return torch.zeros(0, self.hd, device=self.dev, dtype=torch.bfloat16), torch.zeros(0, device=self.dev, dtype=torch.long)
if self.swa_len < self.ws: return self.swa[:self.swa_len].clone(), self.swa_pos[:self.swa_len].clone()
idx = torch.arange(self.swa_head, self.swa_head + self.ws) % self.ws
return self.swa[idx].clone(), self.swa_pos[idx].clone()
return self.swa[:0], self.swa_pos[:0]
if self.swa_len < self.ws:
return self.swa[:self.swa_len], self.swa_pos[:self.swa_len]
# Ring buffer wrap — gather non-contiguous rows
idx = torch.arange(self.swa_head, self.swa_head + self.ws, device=self.dev) % self.ws
return self.swa[idx], self.swa_pos[idx]
# =====================================================================
# HcHead
@@ -590,17 +598,30 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
if indexer is not None and ratio == 4:
topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li)
# 5. Gather KV
# 5. Gather KV — pre-allocated buffer, zero torch.cat on hot path
_pt('gather_start')
swa_kv, swa_pos = kv_cache.get_swa()
swa_kv, _swa_pos = kv_cache.get_swa()
swa_len = swa_kv.shape[0]
gbuf = kv_cache.gather_buf # (indexer_top_k + window_size, hd) pre-allocated
if kv_cache.comp_kv is not None and kv_cache.n_comp > 0:
if ratio == 4:
assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k — indexer is broken"
tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1)
all_kv = torch.cat([kv_cache.comp_kv[tk], swa_kv], dim=0)
elif ratio > 4: all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0)
else: all_kv = swa_kv
else: all_kv = swa_kv
n_tk = tk.shape[0]
gbuf[:n_tk] = kv_cache.comp_kv[tk]
gbuf[n_tk:n_tk + swa_len] = swa_kv
all_kv = gbuf[:n_tk + swa_len]
elif ratio > 4:
n_comp = kv_cache.n_comp
gbuf[:n_comp] = kv_cache.comp_kv
gbuf[n_comp:n_comp + swa_len] = swa_kv
all_kv = gbuf[:n_comp + swa_len]
else:
gbuf[:swa_len] = swa_kv
all_kv = gbuf[:swa_len]
else:
gbuf[:swa_len] = swa_kv
all_kv = gbuf[:swa_len]
seq_len = all_kv.shape[0]
if seq_len == 0: return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a
@@ -1059,9 +1080,14 @@ def main():
# KV caches, compressors, indexers
kv_caches, compressors, indexers = {}, {}, {}
n_ih = cfg.get("index_n_heads", 64); ihd = cfg.get("index_head_dim", 128); itk = cfg.get("index_topk", 1024)
max_ctx = _args.max_context
print(f" Max context: {max_ctx} tokens (governs KV cache pre-allocation)")
for li in range(n_layers):
dev = f"cuda:{li % NUM_GPUS}"; ratio = cr[li] if li < len(cr) else 128
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), device=dev, indexer_key_dim=ihd)
# C1: max_comp derived from target context and compress ratio
max_comp = (max_ctx + ratio - 1) // ratio if ratio > 0 else 0
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp, device=dev,
indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk)
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)