diff --git a/single_shot_inference.py b/single_shot_inference.py index 37e0b499..795c7c9d 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -34,6 +34,7 @@ def parse_args(): p.add_argument('--prefill-tokens', type=str, default=None, help='Override prompt tokens as comma-separated IDs (e.g. "1,128803,313,128804")') p.add_argument('--cuda-graph', action='store_true', help='Capture CUDA graph per layer for decode (eliminates Python dispatch overhead)') + p.add_argument('--max-context', type=int, default=8192, help='Target max context length (determines KV cache pre-allocation)') return p.parse_args() _args = parse_args() @@ -436,9 +437,11 @@ class Indexer: # KV Cache # ===================================================================== class KVCache: - def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0', indexer_key_dim=128): + def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0', + indexer_key_dim=128, compress_ratio=4, indexer_top_k=1024): self.hd, self.ws, self.dev = head_dim, window_size, device self.idx_key_dim = indexer_key_dim + self.ratio = compress_ratio self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device) self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device) self.swa_len, self.swa_head = 0, 0 @@ -447,6 +450,8 @@ class KVCache: self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device) # Indexer compressed keys: width = ihd (c_I in the paper), NOT head_dim self.comp_idx_buf = torch.zeros(max_comp, indexer_key_dim, dtype=torch.bfloat16, device=device) + # Pre-allocated gather buffer — top_k compressed + SWA window, zero torch.cat on hot path + self.gather_buf = torch.zeros(indexer_top_k + window_size, head_dim, dtype=torch.bfloat16, device=device) self.n_comp = 0 self._has_idx = False @@ -484,11 +489,14 @@ class KVCache: return self.comp_idx_buf[:self.n_comp] if self._has_idx and self.n_comp > 0 else None def get_swa(self): + """Return SWA KV and positions as views (no clone). Caller copies into gather_buf.""" if self.swa_len == 0: - return torch.zeros(0, self.hd, device=self.dev, dtype=torch.bfloat16), torch.zeros(0, device=self.dev, dtype=torch.long) - if self.swa_len < self.ws: return self.swa[:self.swa_len].clone(), self.swa_pos[:self.swa_len].clone() - idx = torch.arange(self.swa_head, self.swa_head + self.ws) % self.ws - return self.swa[idx].clone(), self.swa_pos[idx].clone() + return self.swa[:0], self.swa_pos[:0] + if self.swa_len < self.ws: + return self.swa[:self.swa_len], self.swa_pos[:self.swa_len] + # Ring buffer wrap — gather non-contiguous rows + idx = torch.arange(self.swa_head, self.swa_head + self.ws, device=self.dev) % self.ws + return self.swa[idx], self.swa_pos[idx] # ===================================================================== # HcHead @@ -590,17 +598,30 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, if indexer is not None and ratio == 4: topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li) - # 5. Gather KV + # 5. Gather KV — pre-allocated buffer, zero torch.cat on hot path _pt('gather_start') - swa_kv, swa_pos = kv_cache.get_swa() + swa_kv, _swa_pos = kv_cache.get_swa() + swa_len = swa_kv.shape[0] + gbuf = kv_cache.gather_buf # (indexer_top_k + window_size, hd) pre-allocated if kv_cache.comp_kv is not None and kv_cache.n_comp > 0: if ratio == 4: assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k — indexer is broken" tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1) - all_kv = torch.cat([kv_cache.comp_kv[tk], swa_kv], dim=0) - elif ratio > 4: all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0) - else: all_kv = swa_kv - else: all_kv = swa_kv + n_tk = tk.shape[0] + gbuf[:n_tk] = kv_cache.comp_kv[tk] + gbuf[n_tk:n_tk + swa_len] = swa_kv + all_kv = gbuf[:n_tk + swa_len] + elif ratio > 4: + n_comp = kv_cache.n_comp + gbuf[:n_comp] = kv_cache.comp_kv + gbuf[n_comp:n_comp + swa_len] = swa_kv + all_kv = gbuf[:n_comp + swa_len] + else: + gbuf[:swa_len] = swa_kv + all_kv = gbuf[:swa_len] + else: + gbuf[:swa_len] = swa_kv + all_kv = gbuf[:swa_len] seq_len = all_kv.shape[0] if seq_len == 0: return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a @@ -1059,9 +1080,14 @@ def main(): # KV caches, compressors, indexers kv_caches, compressors, indexers = {}, {}, {} n_ih = cfg.get("index_n_heads", 64); ihd = cfg.get("index_head_dim", 128); itk = cfg.get("index_topk", 1024) + max_ctx = _args.max_context + print(f" Max context: {max_ctx} tokens (governs KV cache pre-allocation)") for li in range(n_layers): dev = f"cuda:{li % NUM_GPUS}"; ratio = cr[li] if li < len(cr) else 128 - kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), device=dev, indexer_key_dim=ihd) + # C1: max_comp derived from target context and compress ratio + max_comp = (max_ctx + ratio - 1) // ratio if ratio > 0 else 0 + kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp, device=dev, + indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk) if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev) if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)