diff --git a/single_shot_inference.py b/single_shot_inference.py index 4b7f8504..ef384ef8 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -477,12 +477,32 @@ def main(): generated = input_ids[0].tolist() - # ==== Prefill: process all prompt tokens at once ==== - # For now, treat each prompt token as a separate decode step - # (proper prefill would use T>1 FMHA, which is wired but not yet - # tested with KV cache. One token at a time is correct, just slow.) + # ==== Prefill: process prompt tokens to fill KV cache ==== + print(f"Prefilling {len(generated)} prompt tokens...") + for prefill_idx, tid_val in enumerate(generated): + tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0') + emb = embed(tid) + X = emb.unsqueeze(1).expand(-1, n_hc, -1).clone() + + for li in range(n_layers): + gpu = li % NUM_GPUS + target_device = f"cuda:{gpu}" + if X.device != torch.device(target_device): + X = X.to(target_device) + torch.cuda.set_device(gpu) + + attn_mhc = attn_mhc_blocks.get(li) + ffn_mhc = ffn_mhc_blocks.get(li) + rc, rs = rope_caches[gpu] + X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, + attn_mhc, ffn_mhc, kv_caches[li], tid, prefill_idx) + + X = X.to('cuda:0') + torch.cuda.set_device(0) + print(f" Prefill done ({len(generated)} tokens, {time.time()-t_compiled:.1f}s)") - all_tokens = generated.copy() # start with prompt tokens + # ==== Decode: generate new tokens ==== + print(f"\nDecoding (max {MAX_NEW_TOKENS} new tokens)...") for step in range(MAX_NEW_TOKENS): t0 = time.time()