diff --git a/single_shot_inference.py b/single_shot_inference.py index c9246a8c..ef86d48b 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1473,18 +1473,30 @@ def main(): all_tokens = generated.copy() print(f"Input: {len(generated)} tokens") - # Prefill — one token at a time (decode-style; TODO: batched prefill) - print(f"Prefilling {len(generated)} tokens...") - # Pre-allocate prefill buffers — no per-step torch.tensor() - pre_tid_buf = torch.zeros(1, dtype=torch.long, device='cuda:0') - pre_tid32_buf = torch.zeros(1, dtype=torch.int32, device='cuda:0') - pre_pos_buf = torch.zeros(1, dtype=torch.long, device='cuda:0') - for pi, tid_val in enumerate(generated): + # Batched prefill — process tokens in chunks of up to 128 (FMHA T≤128 constraint) + PREFILL_CHUNK = 128 # max T per FMHA launch; split larger prefills into chunks + n_prefill = len(generated) + print(f"Batched prefill: {n_prefill} tokens, chunk_size={PREFILL_CHUNK}") + prefill_ids = torch.tensor(generated, dtype=torch.long, device='cuda:0') + prefill_ids32 = prefill_ids.to(torch.int32) + all_positions = torch.arange(n_prefill, dtype=torch.long, device='cuda:0') + + # Process chunks: each chunk goes through ALL 61 layers before the next chunk. + # This ensures KV cache is populated correctly for each layer. + chunk_starts = list(range(0, n_prefill, PREFILL_CHUNK)) + X = None # will be set by first chunk's embedding + for ci, cs in enumerate(chunk_starts): + ce = min(cs + PREFILL_CHUNK, n_prefill) + chunk_len = ce - cs t1 = time.time() - pre_tid_buf[0] = tid_val - pre_tid32_buf[0] = tid_val - pre_pos_buf[0] = pi - X = mHCLayer.init_state(embed(pre_tid_buf)) + + # Embed chunk tokens: (chunk_len, d) + chunk_ids = prefill_ids[cs:ce] + chunk_ids32 = prefill_ids32[cs:ce] + chunk_positions = all_positions[cs:ce] + chunk_embed = embed(chunk_ids) # (chunk_len, d) BF16 + X = mHCLayer.init_state(chunk_embed) # (chunk_len, n_hc, d) BF16 + for li in range(n_layers): gpu = li % NUM_GPUS if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") @@ -1493,7 +1505,7 @@ def main(): X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu], attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li), - kv_caches[li], pre_pos_buf, pre_tid32_buf, + kv_caches[li], chunk_positions, chunk_ids32, compressors.get(li), indexers.get(li), moe_runners.get(li), se_runners.get(li), routers.get(li), prod_lin=prod_lins.get(li), @@ -1501,15 +1513,14 @@ def main(): ) except Exception as e: torch.cuda.synchronize() - err = torch.cuda.current_stream(gpu).query() - print(f" CRASH at token {pi} layer {li} gpu {gpu}: {e}", flush=True) + print(f" CRASH at chunk {ci} (tokens {cs}-{ce-1}) layer {li} gpu {gpu}: {e}", flush=True) raise - if VERBOSE >= 2 and pi == 0 and li < 3: + if VERBOSE >= 2 and ci == 0 and li < 3: torch.cuda.synchronize(gpu) - print(f" Token {pi} L{li}: OK |X|={X.abs().max().item():.1f}", flush=True) + print(f" Chunk {ci} L{li}: OK |X|={X.abs().max().item():.1f}", flush=True) X = X.to('cuda:0'); torch.cuda.set_device(0) - if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True) - print(f" Prefill done ({time.time()-t0:.1f}s)") + print(f" Chunk {ci+1}/{len(chunk_starts)} tokens {cs}-{ce-1} ({chunk_len} tok): {time.time()-t1:.2f}s", flush=True) + print(f" Batched prefill done ({time.time()-t0:.1f}s)") if _args.prefill_only: print("Prefill-only mode, stopping."); return