From 60309ef1247858bbf2aeef48dae9cc84b5800016 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 07:39:37 +0000 Subject: [PATCH] =?UTF-8?q?Batched=20prefill:=20replace=20T=3D1=20token-by?= =?UTF-8?q?-token=20with=20chunked=20T=E2=89=A4128=20batch=20processing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Process prefill tokens in chunks of up to 128 (FMHA T≤128 constraint) - Each chunk goes through ALL 61 layers before the next chunk - KV cache append_swa, compressor, indexer all already support T>1 - FMHA dispatches to dsv4_attention_mixed_fp8_prefill for T>1 - For T>128: splits into multiple launches automatically - mHC, Router, MoE, Nvfp4Linear all handle M>1 natively - Eliminates ~N_prefill * 61 per-token overhead from the old loop --- single_shot_inference.py | 47 +++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index c9246a8c..ef86d48b 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1473,18 +1473,30 @@ def main(): all_tokens = generated.copy() print(f"Input: {len(generated)} tokens") - # Prefill — one token at a time (decode-style; TODO: batched prefill) - print(f"Prefilling {len(generated)} tokens...") - # Pre-allocate prefill buffers — no per-step torch.tensor() - pre_tid_buf = torch.zeros(1, dtype=torch.long, device='cuda:0') - pre_tid32_buf = torch.zeros(1, dtype=torch.int32, device='cuda:0') - pre_pos_buf = torch.zeros(1, dtype=torch.long, device='cuda:0') - for pi, tid_val in enumerate(generated): + # Batched prefill — process tokens in chunks of up to 128 (FMHA T≤128 constraint) + PREFILL_CHUNK = 128 # max T per FMHA launch; split larger prefills into chunks + n_prefill = len(generated) + print(f"Batched prefill: {n_prefill} tokens, chunk_size={PREFILL_CHUNK}") + prefill_ids = torch.tensor(generated, dtype=torch.long, device='cuda:0') + prefill_ids32 = prefill_ids.to(torch.int32) + all_positions = torch.arange(n_prefill, dtype=torch.long, device='cuda:0') + + # Process chunks: each chunk goes through ALL 61 layers before the next chunk. + # This ensures KV cache is populated correctly for each layer. + chunk_starts = list(range(0, n_prefill, PREFILL_CHUNK)) + X = None # will be set by first chunk's embedding + for ci, cs in enumerate(chunk_starts): + ce = min(cs + PREFILL_CHUNK, n_prefill) + chunk_len = ce - cs t1 = time.time() - pre_tid_buf[0] = tid_val - pre_tid32_buf[0] = tid_val - pre_pos_buf[0] = pi - X = mHCLayer.init_state(embed(pre_tid_buf)) + + # Embed chunk tokens: (chunk_len, d) + chunk_ids = prefill_ids[cs:ce] + chunk_ids32 = prefill_ids32[cs:ce] + chunk_positions = all_positions[cs:ce] + chunk_embed = embed(chunk_ids) # (chunk_len, d) BF16 + X = mHCLayer.init_state(chunk_embed) # (chunk_len, n_hc, d) BF16 + for li in range(n_layers): gpu = li % NUM_GPUS if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") @@ -1493,7 +1505,7 @@ def main(): X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu], attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li), - kv_caches[li], pre_pos_buf, pre_tid32_buf, + kv_caches[li], chunk_positions, chunk_ids32, compressors.get(li), indexers.get(li), moe_runners.get(li), se_runners.get(li), routers.get(li), prod_lin=prod_lins.get(li), @@ -1501,15 +1513,14 @@ def main(): ) except Exception as e: torch.cuda.synchronize() - err = torch.cuda.current_stream(gpu).query() - print(f" CRASH at token {pi} layer {li} gpu {gpu}: {e}", flush=True) + print(f" CRASH at chunk {ci} (tokens {cs}-{ce-1}) layer {li} gpu {gpu}: {e}", flush=True) raise - if VERBOSE >= 2 and pi == 0 and li < 3: + if VERBOSE >= 2 and ci == 0 and li < 3: torch.cuda.synchronize(gpu) - print(f" Token {pi} L{li}: OK |X|={X.abs().max().item():.1f}", flush=True) + print(f" Chunk {ci} L{li}: OK |X|={X.abs().max().item():.1f}", flush=True) X = X.to('cuda:0'); torch.cuda.set_device(0) - if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True) - print(f" Prefill done ({time.time()-t0:.1f}s)") + print(f" Chunk {ci+1}/{len(chunk_starts)} tokens {cs}-{ce-1} ({chunk_len} tok): {time.time()-t1:.2f}s", flush=True) + print(f" Batched prefill done ({time.time()-t0:.1f}s)") if _args.prefill_only: print("Prefill-only mode, stopping."); return