Add prefill: process prompt tokens to fill KV cache before decoding

This commit is contained in:
2026-05-31 00:18:55 +00:00
parent 178fb5483a
commit b1dd59293a

View File

@@ -477,12 +477,32 @@ def main():
generated = input_ids[0].tolist()
# ==== Prefill: process all prompt tokens at once ====
# For now, treat each prompt token as a separate decode step
# (proper prefill would use T>1 FMHA, which is wired but not yet
# tested with KV cache. One token at a time is correct, just slow.)
# ==== Prefill: process prompt tokens to fill KV cache ====
print(f"Prefilling {len(generated)} prompt tokens...")
for prefill_idx, tid_val in enumerate(generated):
tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
emb = embed(tid)
X = emb.unsqueeze(1).expand(-1, n_hc, -1).clone()
for li in range(n_layers):
gpu = li % NUM_GPUS
target_device = f"cuda:{gpu}"
if X.device != torch.device(target_device):
X = X.to(target_device)
torch.cuda.set_device(gpu)
attn_mhc = attn_mhc_blocks.get(li)
ffn_mhc = ffn_mhc_blocks.get(li)
rc, rs = rope_caches[gpu]
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
attn_mhc, ffn_mhc, kv_caches[li], tid, prefill_idx)
X = X.to('cuda:0')
torch.cuda.set_device(0)
print(f" Prefill done ({len(generated)} tokens, {time.time()-t_compiled:.1f}s)")
all_tokens = generated.copy() # start with prompt tokens
# ==== Decode: generate new tokens ====
print(f"\nDecoding (max {MAX_NEW_TOKENS} new tokens)...")
for step in range(MAX_NEW_TOKENS):
t0 = time.time()