From 5b4c496512004c648b1ca2433795da48e0c4d4aa Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 2 Jun 2026 05:53:10 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20three=20indexer=20bugs=20=E2=80=94=20wei?= =?UTF-8?q?ght=20path,=20comp=5Fidx=5Fbuf=20width,=20scoring=20einsum?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Indexer.load: weights at *.indexer.kv_proj not *.indexer.compressor.kv_proj 2. KVCache.comp_idx_buf: width=ihd (128) not head_dim (512); parametric via indexer_key_dim 3. Indexer.forward: stored keys are (n_comp, ihd) not (n_comp, n_ih, ihd); einsum changed from 'tnd,cnd->tnc' to 'tnd,cd->tnc' — key shared across indexer heads (paper's c_I = ihd = 128, one vector per compressed block) Also removed probe diagnostics (COMPRESSOR EARLY RETURN, COMPRESSOR BUFFERING, COMPRESSOR OUT, INDEXER SKIP, RESHAPE FAILURE, indexer load state) — served their purpose. --- single_shot_inference.py | 52 +++++++++++++--------------------------- 1 file changed, 17 insertions(+), 35 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 224c634d..9b95384f 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -313,9 +313,7 @@ class Compressor: self.kv_norm_w = w.get(f"{pfx}.kv_norm.weight") def forward(self, hidden_states, positions): - if self.ratio == 0 or self.kv_lin is None: - print(f" COMPRESSOR EARLY RETURN: ratio={self.ratio} kv_lin={self.kv_lin is not None} hd={self.hd} kv_dim={self.kv_dim}", flush=True) - return None, None, None + if self.ratio == 0 or self.kv_lin is None: return None, None, None T = hidden_states.shape[0]; r = self.ratio; dev = hidden_states.device # P7: Buffer decode steps until we have a complete block. 
@@ -332,7 +330,6 @@ class Compressor: self._pos_buffer[self._buf_len] = positions[0] if positions.numel() == 1 else positions[self._buf_len] self._buf_len += 1 if self._buf_len < r: - print(f" COMPRESSOR BUFFERING: hd={self.hd} buf_len={self._buf_len} r={r}", flush=True) return None, None, None # Not enough tokens yet # We have a full buffer — use it hidden_states = self._hs_buffer[:self._buf_len] @@ -362,11 +359,6 @@ class Compressor: bi = torch.arange(n_complete, device=dev) pos_idx = ((bi + 1) * r - 1).clamp(max=positions.numel() - 1) comp_pos = positions[pos_idx] - # INDEXER PROBE: Compressor output shape - ident = f"hd={self.hd} kv_dim={self.kv_dim} ratio={self.ratio} is_csa={self.is_csa}" - print(f" COMPRESSOR OUT [{ident}]: compressed.shape={tuple(compressed.shape)} " - f"dtype={compressed.dtype} stride={compressed.stride()} " - f"contig={compressed.is_contiguous()}", flush=True) return compressed, comp_pos, torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev) # ===================================================================== @@ -406,8 +398,6 @@ class Indexer: def forward(self, q_lora, hidden_states, comp_indexer_kv, positions, layer_idx=None): if self.q_b_lin is None or comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0: - print(f" INDEXER SKIP L{layer_idx}: q_b_lin={self.q_b_lin is not None} " - f"comp_idx_kv={tuple(comp_indexer_kv.shape) if comp_indexer_kv is not None else None}", flush=True) return None dev = q_lora.device; T = q_lora.shape[0]; n_comp = comp_indexer_kv.shape[0] # INDEXER PROBE: print shapes at layer_idx==0 only @@ -423,40 +413,40 @@ class Indexer: print(f" self.wp_lin.in_features={self.wp_lin.in_features} out_features={self.wp_lin.out_features}", flush=True) if self.compressor is not None: print(f" self.compressor.kv_dim={self.compressor.kv_dim} ratio={self.compressor.ratio} hd={self.compressor.hd}", flush=True) - q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd) + q_idx = self.q_b_lin(q_lora).reshape(T, 
self.n_ih, self.ihd) # (T, n_ih, ihd) w_h = self.wp_lin(hidden_states) # (T, n_ih) - # INDEXER PROBE: try reshape, catch failure - try: - k_idx = comp_indexer_kv.reshape(n_comp, self.n_ih, self.ihd) - except RuntimeError as e: - print(f" !!! RESHAPE FAILURE L{li} !!!", flush=True) - print(f" comp_indexer_kv.shape = {tuple(comp_indexer_kv.shape)}", flush=True) - print(f" tried to reshape to ({n_comp}, {self.n_ih}, {self.ihd})", flush=True) - print(f" total elements: have {comp_indexer_kv.numel()}, need {n_comp * self.n_ih * self.ihd}", flush=True) - raise + # Stored indexer keys are (n_comp, ihd) — one vector per compressed block, + # shared across all indexer heads (paper's c_I = ihd = 128). + # NOT (n_comp, n_ih, ihd) — there is no per-head key decomposition. + k_idx = comp_indexer_kv # (n_comp, ihd) if li == 0: print(f"--- INDEXER L0 SCORING TENSORS ---", flush=True) print(f" q_idx: shape={tuple(q_idx.shape)} dtype={q_idx.dtype}", flush=True) print(f" k_idx: shape={tuple(k_idx.shape)} dtype={k_idx.dtype}", flush=True) print(f" w_h: shape={tuple(w_h.shape)} dtype={w_h.dtype}", flush=True) - scores = torch.einsum('tnd,cnd->tnc', q_idx.float(), k_idx.float()) - scores = F.relu(scores); total = (scores * w_h.unsqueeze(-1).float()).sum(1) + # Weighted ReLU MQA scoring (eq. 
16): + # score(t, c) = sum_h w_h(t,h) * ReLU(q(t,h) · k(c)) + # k is shared across heads: einsum 'tnd,cd->tnc' (c=n_comp, d=ihd) + scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float()) # (T, n_ih, n_comp) + scores = F.relu(scores) + total = (scores * w_h.unsqueeze(-1).float()).sum(1) # (T, n_comp) tk = min(self.top_k, n_comp); _, idx = total.topk(tk, -1); return idx # ===================================================================== # KV Cache # ===================================================================== class KVCache: - def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0'): + def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0', indexer_key_dim=128): self.hd, self.ws, self.dev = head_dim, window_size, device + self.idx_key_dim = indexer_key_dim self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device) self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device) self.swa_len, self.swa_head = 0, 0 # P3: Pre-allocate compressed KV buffers (no more torch.cat / O(N²) growth) self.comp_kv_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device) self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device) - # Indexer compressed keys are width ihd (128), NOT head_dim (512) - self.comp_idx_buf = torch.zeros(max_comp, 128, dtype=torch.bfloat16, device=device) # PROBE: width=ihd + # Indexer compressed keys: width = ihd (c_I in the paper), NOT head_dim + self.comp_idx_buf = torch.zeros(max_comp, indexer_key_dim, dtype=torch.bfloat16, device=device) self.n_comp = 0 self._has_idx = False @@ -1070,7 +1060,7 @@ def main(): n_ih = cfg.get("index_n_heads", 64); ihd = cfg.get("index_head_dim", 128); itk = cfg.get("index_topk", 1024) for li in range(n_layers): dev = f"cuda:{li % NUM_GPUS}"; ratio = cr[li] if li < len(cr) else 128 - kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), device=dev) + kv_caches[li] = KVCache(hd, 
cfg.get("sliding_window", 128), device=dev, indexer_key_dim=ihd) if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev) if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev) @@ -1089,14 +1079,6 @@ def main(): if li in compressors: compressors[li].load(layer_w[li], pfx, dev=f"cuda:{li % NUM_GPUS}") if li in indexers: indexers[li].load(layer_w[li], f"{pfx}.indexer", dev=f"cuda:{li % NUM_GPUS}") print(" Compressors/indexers loaded") - # Probe: verify indexer compressor state - for li in sorted(indexers.keys())[:3]: - idx = indexers[li] - print(f" Indexer L{li}: q_b_lin={idx.q_b_lin is not None} wp_lin={idx.wp_lin is not None} " - f"compressor={idx.compressor is not None}", flush=True) - if idx.compressor is not None: - print(f" compressor: hd={idx.compressor.hd} kv_dim={idx.compressor.kv_dim} " - f"kv_lin={idx.compressor.kv_lin is not None} gate_lin={idx.compressor.gate_lin is not None}", flush=True) # ---- Phase 3: Inference ---- print(f"\nPhase 3: Inference")