Add prefill: process prompt tokens to fill KV cache before decoding
This commit is contained in:
@@ -477,12 +477,32 @@ def main():
|
||||
|
||||
generated = input_ids[0].tolist()
|
||||
|
||||
# ==== Prefill: process prompt tokens one at a time ====
|
||||
# For now, treat each prompt token as a separate decode step
|
||||
# (proper prefill would use T>1 FMHA, which is wired but not yet
|
||||
# tested with KV cache. One token at a time is correct, just slow.)
|
||||
# ==== Prefill: process prompt tokens to fill KV cache ====
|
||||
print(f"Prefilling {len(generated)} prompt tokens...")
|
||||
for prefill_idx, tid_val in enumerate(generated):
|
||||
tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
|
||||
emb = embed(tid)
|
||||
X = emb.unsqueeze(1).expand(-1, n_hc, -1).clone()
|
||||
|
||||
for li in range(n_layers):
|
||||
gpu = li % NUM_GPUS
|
||||
target_device = f"cuda:{gpu}"
|
||||
if X.device != torch.device(target_device):
|
||||
X = X.to(target_device)
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
attn_mhc = attn_mhc_blocks.get(li)
|
||||
ffn_mhc = ffn_mhc_blocks.get(li)
|
||||
rc, rs = rope_caches[gpu]
|
||||
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
|
||||
attn_mhc, ffn_mhc, kv_caches[li], tid, prefill_idx)
|
||||
|
||||
X = X.to('cuda:0')
|
||||
torch.cuda.set_device(0)
|
||||
print(f" Prefill done ({len(generated)} tokens, {time.time()-t_compiled:.1f}s)")
|
||||
|
||||
all_tokens = generated.copy() # start with prompt tokens
|
||||
# ==== Decode: generate new tokens ====
|
||||
print(f"\nDecoding (max {MAX_NEW_TOKENS} new tokens)...")
|
||||
|
||||
for step in range(MAX_NEW_TOKENS):
|
||||
t0 = time.time()
|
||||
|
||||
Reference in New Issue
Block a user