fix: three indexer bugs — weight path, comp_idx_buf width, scoring einsum

1. Indexer.load: weights at *.indexer.kv_proj not *.indexer.compressor.kv_proj
2. KVCache.comp_idx_buf: width=ihd (128) not head_dim (512); parametric via indexer_key_dim
3. Indexer.forward: stored keys are (n_comp, ihd) not (n_comp, n_ih, ihd);
   einsum changed from 'tnd,cnd->tnc' to 'tnd,cd->tnc' — key shared across indexer heads
   (paper's c_I = ihd = 128, one vector per compressed block)

Also removed probe diagnostics (COMPRESSOR BUFFERING, COMPRESSOR OUT, INDEXER SKIP,
RESHAPE FAILURE, indexer load state) — served their purpose.
This commit is contained in:
2026-06-02 05:53:10 +00:00
parent 0fbf28dd54
commit bb82cdec85

View File

@@ -313,9 +313,7 @@ class Compressor:
self.kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
def forward(self, hidden_states, positions):
if self.ratio == 0 or self.kv_lin is None:
print(f" COMPRESSOR EARLY RETURN: ratio={self.ratio} kv_lin={self.kv_lin is not None} hd={self.hd} kv_dim={self.kv_dim}", flush=True)
return None, None, None
if self.ratio == 0 or self.kv_lin is None: return None, None, None
T = hidden_states.shape[0]; r = self.ratio; dev = hidden_states.device
# P7: Buffer decode steps until we have a complete block.
@@ -332,7 +330,6 @@ class Compressor:
self._pos_buffer[self._buf_len] = positions[0] if positions.numel() == 1 else positions[self._buf_len]
self._buf_len += 1
if self._buf_len < r:
print(f" COMPRESSOR BUFFERING: hd={self.hd} buf_len={self._buf_len} r={r}", flush=True)
return None, None, None # Not enough tokens yet
# We have a full buffer — use it
hidden_states = self._hs_buffer[:self._buf_len]
@@ -362,11 +359,6 @@ class Compressor:
bi = torch.arange(n_complete, device=dev)
pos_idx = ((bi + 1) * r - 1).clamp(max=positions.numel() - 1)
comp_pos = positions[pos_idx]
# INDEXER PROBE: Compressor output shape
ident = f"hd={self.hd} kv_dim={self.kv_dim} ratio={self.ratio} is_csa={self.is_csa}"
print(f" COMPRESSOR OUT [{ident}]: compressed.shape={tuple(compressed.shape)} "
f"dtype={compressed.dtype} stride={compressed.stride()} "
f"contig={compressed.is_contiguous()}", flush=True)
return compressed, comp_pos, torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev)
# =====================================================================
@@ -406,8 +398,6 @@ class Indexer:
def forward(self, q_lora, hidden_states, comp_indexer_kv, positions, layer_idx=None):
if self.q_b_lin is None or comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0:
print(f" INDEXER SKIP L{layer_idx}: q_b_lin={self.q_b_lin is not None} "
f"comp_idx_kv={tuple(comp_indexer_kv.shape) if comp_indexer_kv is not None else None}", flush=True)
return None
dev = q_lora.device; T = q_lora.shape[0]; n_comp = comp_indexer_kv.shape[0]
# INDEXER PROBE: print shapes at layer_idx==0 only
@@ -423,40 +413,40 @@ class Indexer:
print(f" self.wp_lin.in_features={self.wp_lin.in_features} out_features={self.wp_lin.out_features}", flush=True)
if self.compressor is not None:
print(f" self.compressor.kv_dim={self.compressor.kv_dim} ratio={self.compressor.ratio} hd={self.compressor.hd}", flush=True)
q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd)
q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd) # (T, n_ih, ihd)
w_h = self.wp_lin(hidden_states) # (T, n_ih)
# INDEXER PROBE: try reshape, catch failure
try:
k_idx = comp_indexer_kv.reshape(n_comp, self.n_ih, self.ihd)
except RuntimeError as e:
print(f" !!! RESHAPE FAILURE L{li} !!!", flush=True)
print(f" comp_indexer_kv.shape = {tuple(comp_indexer_kv.shape)}", flush=True)
print(f" tried to reshape to ({n_comp}, {self.n_ih}, {self.ihd})", flush=True)
print(f" total elements: have {comp_indexer_kv.numel()}, need {n_comp * self.n_ih * self.ihd}", flush=True)
raise
# Stored indexer keys are (n_comp, ihd) — one vector per compressed block,
# shared across all indexer heads (paper's c_I = ihd = 128).
# NOT (n_comp, n_ih, ihd) — there is no per-head key decomposition.
k_idx = comp_indexer_kv # (n_comp, ihd)
if li == 0:
print(f"--- INDEXER L0 SCORING TENSORS ---", flush=True)
print(f" q_idx: shape={tuple(q_idx.shape)} dtype={q_idx.dtype}", flush=True)
print(f" k_idx: shape={tuple(k_idx.shape)} dtype={k_idx.dtype}", flush=True)
print(f" w_h: shape={tuple(w_h.shape)} dtype={w_h.dtype}", flush=True)
scores = torch.einsum('tnd,cnd->tnc', q_idx.float(), k_idx.float())
scores = F.relu(scores); total = (scores * w_h.unsqueeze(-1).float()).sum(1)
# Weighted ReLU MQA scoring (eq. 16):
# score(t, c) = sum_h w_h(t,h) * ReLU(q(t,h) · k(c))
# k is shared across heads: einsum 'tnd,cd->tnc' (c=n_comp, d=ihd)
scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float()) # (T, n_ih, n_comp)
scores = F.relu(scores)
total = (scores * w_h.unsqueeze(-1).float()).sum(1) # (T, n_comp)
tk = min(self.top_k, n_comp); _, idx = total.topk(tk, -1); return idx
# =====================================================================
# KV Cache
# =====================================================================
class KVCache:
def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0'):
def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0', indexer_key_dim=128):
self.hd, self.ws, self.dev = head_dim, window_size, device
self.idx_key_dim = indexer_key_dim
self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device)
self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device)
self.swa_len, self.swa_head = 0, 0
# P3: Pre-allocate compressed KV buffers (no more torch.cat / O(N²) growth)
self.comp_kv_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device)
self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device)
# Indexer compressed keys are width ihd (128), NOT head_dim (512)
self.comp_idx_buf = torch.zeros(max_comp, 128, dtype=torch.bfloat16, device=device) # PROBE: width=ihd
# Indexer compressed keys: width = ihd (c_I in the paper), NOT head_dim
self.comp_idx_buf = torch.zeros(max_comp, indexer_key_dim, dtype=torch.bfloat16, device=device)
self.n_comp = 0
self._has_idx = False
@@ -1070,7 +1060,7 @@ def main():
n_ih = cfg.get("index_n_heads", 64); ihd = cfg.get("index_head_dim", 128); itk = cfg.get("index_topk", 1024)
for li in range(n_layers):
dev = f"cuda:{li % NUM_GPUS}"; ratio = cr[li] if li < len(cr) else 128
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), device=dev)
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), device=dev, indexer_key_dim=ihd)
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)
@@ -1089,14 +1079,6 @@ def main():
if li in compressors: compressors[li].load(layer_w[li], pfx, dev=f"cuda:{li % NUM_GPUS}")
if li in indexers: indexers[li].load(layer_w[li], f"{pfx}.indexer", dev=f"cuda:{li % NUM_GPUS}")
print(" Compressors/indexers loaded")
# Probe: verify indexer compressor state
for li in sorted(indexers.keys())[:3]:
idx = indexers[li]
print(f" Indexer L{li}: q_b_lin={idx.q_b_lin is not None} wp_lin={idx.wp_lin is not None} "
f"compressor={idx.compressor is not None}", flush=True)
if idx.compressor is not None:
print(f" compressor: hd={idx.compressor.hd} kv_dim={idx.compressor.kv_dim} "
f"kv_lin={idx.compressor.kv_lin is not None} gate_lin={idx.compressor.gate_lin is not None}", flush=True)
# ---- Phase 3: Inference ----
print(f"\nPhase 3: Inference")