Fix KV cache: use index 0 (one-layer cache per layer instance)
This commit is contained in:
@@ -349,8 +349,8 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc,
|
||||
v_new = k_new.clone()
|
||||
|
||||
# ==== KV cache: append new K,V ====
|
||||
-kv_cache.append(li, k_new, v_new)
|
||||
-k_full, v_full = kv_cache.get(li) # (1, seq_len, hd)
|
||||
+kv_cache.append(0, k_new, v_new)
|
||||
+k_full, v_full = kv_cache.get(0) # (1, seq_len, hd)
|
||||
seq_len = k_full.shape[1]
|
||||
|
||||
# ==== RoPE ====
|
||||
@@ -453,7 +453,7 @@ def main():
|
||||
# ==== KV cache (gpu0, moves to target GPU per layer) ====
|
||||
kv_caches = {}
|
||||
for li in range(n_layers):
|
||||
-kv_caches[li] = SimpleKVCache(1, hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}")
|
||||
+kv_caches[li] = SimpleKVCache(n_layers=1, head_dim=hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}")
|
||||
|
||||
# ==== Phase 2: Compile ====
|
||||
print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
|
||||
|
||||
Reference in New Issue
Block a user