diff --git a/single_shot_inference.py b/single_shot_inference.py index f1843525..4b7f8504 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -349,8 +349,8 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc, v_new = k_new.clone() # ==== KV cache: append new K,V ==== - kv_cache.append(li, k_new, v_new) - k_full, v_full = kv_cache.get(li) # (1, seq_len, hd) + kv_cache.append(0, k_new, v_new) + k_full, v_full = kv_cache.get(0) # (1, seq_len, hd) seq_len = k_full.shape[1] # ==== RoPE ==== @@ -453,7 +453,7 @@ def main(): # ==== KV cache (gpu0, moves to target GPU per layer) ==== kv_caches = {} for li in range(n_layers): - kv_caches[li] = SimpleKVCache(1, hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}") + kv_caches[li] = SimpleKVCache(n_layers=1, head_dim=hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}") # ==== Phase 2: Compile ==== print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")