fix: three indexer bugs — weight path, comp_idx_buf width, scoring einsum
1. Indexer.load: weights live at *.indexer.kv_proj, not *.indexer.compressor.kv_proj.
2. KVCache.comp_idx_buf: width is ihd (128), not head_dim (512); now parametric via indexer_key_dim.
3. Indexer.forward: stored keys are (n_comp, ihd), not (n_comp, n_ih, ihd); the einsum changed from 'tnd,cnd->tnc' to 'tnd,cd->tnc' — the key is shared across indexer heads (paper's c_I = ihd = 128, one vector per compressed block).

Also removed probe diagnostics (COMPRESSOR EARLY RETURN, COMPRESSOR BUFFERING, COMPRESSOR OUT, INDEXER SKIP, RESHAPE FAILURE, indexer load state) — they served their purpose.
This commit is contained in:
@@ -313,9 +313,7 @@ class Compressor:
|
||||
self.kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
|
||||
|
||||
def forward(self, hidden_states, positions):
|
||||
if self.ratio == 0 or self.kv_lin is None:
|
||||
print(f" COMPRESSOR EARLY RETURN: ratio={self.ratio} kv_lin={self.kv_lin is not None} hd={self.hd} kv_dim={self.kv_dim}", flush=True)
|
||||
return None, None, None
|
||||
if self.ratio == 0 or self.kv_lin is None: return None, None, None
|
||||
T = hidden_states.shape[0]; r = self.ratio; dev = hidden_states.device
|
||||
|
||||
# P7: Buffer decode steps until we have a complete block.
|
||||
@@ -332,7 +330,6 @@ class Compressor:
|
||||
self._pos_buffer[self._buf_len] = positions[0] if positions.numel() == 1 else positions[self._buf_len]
|
||||
self._buf_len += 1
|
||||
if self._buf_len < r:
|
||||
print(f" COMPRESSOR BUFFERING: hd={self.hd} buf_len={self._buf_len} r={r}", flush=True)
|
||||
return None, None, None # Not enough tokens yet
|
||||
# We have a full buffer — use it
|
||||
hidden_states = self._hs_buffer[:self._buf_len]
|
||||
@@ -362,11 +359,6 @@ class Compressor:
|
||||
bi = torch.arange(n_complete, device=dev)
|
||||
pos_idx = ((bi + 1) * r - 1).clamp(max=positions.numel() - 1)
|
||||
comp_pos = positions[pos_idx]
|
||||
# INDEXER PROBE: Compressor output shape
|
||||
ident = f"hd={self.hd} kv_dim={self.kv_dim} ratio={self.ratio} is_csa={self.is_csa}"
|
||||
print(f" COMPRESSOR OUT [{ident}]: compressed.shape={tuple(compressed.shape)} "
|
||||
f"dtype={compressed.dtype} stride={compressed.stride()} "
|
||||
f"contig={compressed.is_contiguous()}", flush=True)
|
||||
return compressed, comp_pos, torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev)
|
||||
|
||||
# =====================================================================
|
||||
@@ -406,8 +398,6 @@ class Indexer:
|
||||
|
||||
def forward(self, q_lora, hidden_states, comp_indexer_kv, positions, layer_idx=None):
|
||||
if self.q_b_lin is None or comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0:
|
||||
print(f" INDEXER SKIP L{layer_idx}: q_b_lin={self.q_b_lin is not None} "
|
||||
f"comp_idx_kv={tuple(comp_indexer_kv.shape) if comp_indexer_kv is not None else None}", flush=True)
|
||||
return None
|
||||
dev = q_lora.device; T = q_lora.shape[0]; n_comp = comp_indexer_kv.shape[0]
|
||||
# INDEXER PROBE: print shapes at layer_idx==0 only
|
||||
@@ -423,40 +413,40 @@ class Indexer:
|
||||
print(f" self.wp_lin.in_features={self.wp_lin.in_features} out_features={self.wp_lin.out_features}", flush=True)
|
||||
if self.compressor is not None:
|
||||
print(f" self.compressor.kv_dim={self.compressor.kv_dim} ratio={self.compressor.ratio} hd={self.compressor.hd}", flush=True)
|
||||
q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd)
|
||||
q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd) # (T, n_ih, ihd)
|
||||
w_h = self.wp_lin(hidden_states) # (T, n_ih)
|
||||
# INDEXER PROBE: try reshape, catch failure
|
||||
try:
|
||||
k_idx = comp_indexer_kv.reshape(n_comp, self.n_ih, self.ihd)
|
||||
except RuntimeError as e:
|
||||
print(f" !!! RESHAPE FAILURE L{li} !!!", flush=True)
|
||||
print(f" comp_indexer_kv.shape = {tuple(comp_indexer_kv.shape)}", flush=True)
|
||||
print(f" tried to reshape to ({n_comp}, {self.n_ih}, {self.ihd})", flush=True)
|
||||
print(f" total elements: have {comp_indexer_kv.numel()}, need {n_comp * self.n_ih * self.ihd}", flush=True)
|
||||
raise
|
||||
# Stored indexer keys are (n_comp, ihd) — one vector per compressed block,
|
||||
# shared across all indexer heads (paper's c_I = ihd = 128).
|
||||
# NOT (n_comp, n_ih, ihd) — there is no per-head key decomposition.
|
||||
k_idx = comp_indexer_kv # (n_comp, ihd)
|
||||
if li == 0:
|
||||
print(f"--- INDEXER L0 SCORING TENSORS ---", flush=True)
|
||||
print(f" q_idx: shape={tuple(q_idx.shape)} dtype={q_idx.dtype}", flush=True)
|
||||
print(f" k_idx: shape={tuple(k_idx.shape)} dtype={k_idx.dtype}", flush=True)
|
||||
print(f" w_h: shape={tuple(w_h.shape)} dtype={w_h.dtype}", flush=True)
|
||||
scores = torch.einsum('tnd,cnd->tnc', q_idx.float(), k_idx.float())
|
||||
scores = F.relu(scores); total = (scores * w_h.unsqueeze(-1).float()).sum(1)
|
||||
# Weighted ReLU MQA scoring (eq. 16):
|
||||
# score(t, c) = sum_h w_h(t,h) * ReLU(q(t,h) · k(c))
|
||||
# k is shared across heads: einsum 'tnd,cd->tnc' (c=n_comp, d=ihd)
|
||||
scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float()) # (T, n_ih, n_comp)
|
||||
scores = F.relu(scores)
|
||||
total = (scores * w_h.unsqueeze(-1).float()).sum(1) # (T, n_comp)
|
||||
tk = min(self.top_k, n_comp); _, idx = total.topk(tk, -1); return idx
|
||||
|
||||
# =====================================================================
|
||||
# KV Cache
|
||||
# =====================================================================
|
||||
class KVCache:
|
||||
def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0'):
|
||||
def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0', indexer_key_dim=128):
|
||||
self.hd, self.ws, self.dev = head_dim, window_size, device
|
||||
self.idx_key_dim = indexer_key_dim
|
||||
self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device)
|
||||
self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device)
|
||||
self.swa_len, self.swa_head = 0, 0
|
||||
# P3: Pre-allocate compressed KV buffers (no more torch.cat / O(N²) growth)
|
||||
self.comp_kv_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device)
|
||||
self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device)
|
||||
# Indexer compressed keys are width ihd (128), NOT head_dim (512)
|
||||
self.comp_idx_buf = torch.zeros(max_comp, 128, dtype=torch.bfloat16, device=device) # PROBE: width=ihd
|
||||
# Indexer compressed keys: width = ihd (c_I in the paper), NOT head_dim
|
||||
self.comp_idx_buf = torch.zeros(max_comp, indexer_key_dim, dtype=torch.bfloat16, device=device)
|
||||
self.n_comp = 0
|
||||
self._has_idx = False
|
||||
|
||||
@@ -1070,7 +1060,7 @@ def main():
|
||||
n_ih = cfg.get("index_n_heads", 64); ihd = cfg.get("index_head_dim", 128); itk = cfg.get("index_topk", 1024)
|
||||
for li in range(n_layers):
|
||||
dev = f"cuda:{li % NUM_GPUS}"; ratio = cr[li] if li < len(cr) else 128
|
||||
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), device=dev)
|
||||
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), device=dev, indexer_key_dim=ihd)
|
||||
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
|
||||
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)
|
||||
|
||||
@@ -1089,14 +1079,6 @@ def main():
|
||||
if li in compressors: compressors[li].load(layer_w[li], pfx, dev=f"cuda:{li % NUM_GPUS}")
|
||||
if li in indexers: indexers[li].load(layer_w[li], f"{pfx}.indexer", dev=f"cuda:{li % NUM_GPUS}")
|
||||
print(" Compressors/indexers loaded")
|
||||
# Probe: verify indexer compressor state
|
||||
for li in sorted(indexers.keys())[:3]:
|
||||
idx = indexers[li]
|
||||
print(f" Indexer L{li}: q_b_lin={idx.q_b_lin is not None} wp_lin={idx.wp_lin is not None} "
|
||||
f"compressor={idx.compressor is not None}", flush=True)
|
||||
if idx.compressor is not None:
|
||||
print(f" compressor: hd={idx.compressor.hd} kv_dim={idx.compressor.kv_dim} "
|
||||
f"kv_lin={idx.compressor.kv_lin is not None} gate_lin={idx.compressor.gate_lin is not None}", flush=True)
|
||||
|
||||
# ---- Phase 3: Inference ----
|
||||
print(f"\nPhase 3: Inference")
|
||||
|
||||
Reference in New Issue
Block a user