Batched prefill: replace T=1 token-by-token with chunked T≤128 batch processing

- Process prefill tokens in chunks of up to 128 (FMHA T≤128 constraint)
- Each chunk goes through ALL 61 layers before the next chunk
- KV cache append_swa, compressor, indexer all already support T>1
- FMHA dispatches to dsv4_attention_mixed_fp8_prefill for T>1
- For T>128: splits into multiple launches automatically
- mHC, Router, MoE, Nvfp4Linear all handle M>1 natively
- Eliminates ~N_prefill * 61 per-token overhead from the old loop
This commit is contained in:
2026-06-03 07:39:37 +00:00
parent 0bf276f8c9
commit 60309ef124

View File

@@ -1473,18 +1473,30 @@ def main():
all_tokens = generated.copy()
print(f"Input: {len(generated)} tokens")
# Prefill — one token at a time (decode-style; TODO: batched prefill)
print(f"Prefilling {len(generated)} tokens...")
# Pre-allocate prefill buffers — no per-step torch.tensor()
pre_tid_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
pre_tid32_buf = torch.zeros(1, dtype=torch.int32, device='cuda:0')
pre_pos_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
for pi, tid_val in enumerate(generated):
# Batched prefill — process tokens in chunks of up to 128 (FMHA T≤128 constraint)
PREFILL_CHUNK = 128 # max T per FMHA launch; split larger prefills into chunks
n_prefill = len(generated)
print(f"Batched prefill: {n_prefill} tokens, chunk_size={PREFILL_CHUNK}")
prefill_ids = torch.tensor(generated, dtype=torch.long, device='cuda:0')
prefill_ids32 = prefill_ids.to(torch.int32)
all_positions = torch.arange(n_prefill, dtype=torch.long, device='cuda:0')
# Process chunks: each chunk goes through ALL 61 layers before the next chunk.
# This ensures KV cache is populated correctly for each layer.
chunk_starts = list(range(0, n_prefill, PREFILL_CHUNK))
X = None # will be set by first chunk's embedding
for ci, cs in enumerate(chunk_starts):
ce = min(cs + PREFILL_CHUNK, n_prefill)
chunk_len = ce - cs
t1 = time.time()
pre_tid_buf[0] = tid_val
pre_tid32_buf[0] = tid_val
pre_pos_buf[0] = pi
X = mHCLayer.init_state(embed(pre_tid_buf))
# Embed chunk tokens: (chunk_len, d)
chunk_ids = prefill_ids[cs:ce]
chunk_ids32 = prefill_ids32[cs:ce]
chunk_positions = all_positions[cs:ce]
chunk_embed = embed(chunk_ids) # (chunk_len, d) BF16
X = mHCLayer.init_state(chunk_embed) # (chunk_len, n_hc, d) BF16
for li in range(n_layers):
gpu = li % NUM_GPUS
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
@@ -1493,7 +1505,7 @@ def main():
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
attn_mhcs.get(li), ffn_mhcs.get(li),
attn_norms.get(li), ffn_norms.get(li),
kv_caches[li], pre_pos_buf, pre_tid32_buf,
kv_caches[li], chunk_positions, chunk_ids32,
compressors.get(li), indexers.get(li),
moe_runners.get(li), se_runners.get(li), routers.get(li),
prod_lin=prod_lins.get(li),
@@ -1501,15 +1513,14 @@ def main():
)
except Exception as e:
torch.cuda.synchronize()
err = torch.cuda.current_stream(gpu).query()
print(f" CRASH at token {pi} layer {li} gpu {gpu}: {e}", flush=True)
print(f" CRASH at chunk {ci} (tokens {cs}-{ce-1}) layer {li} gpu {gpu}: {e}", flush=True)
raise
if VERBOSE >= 2 and pi == 0 and li < 3:
if VERBOSE >= 2 and ci == 0 and li < 3:
torch.cuda.synchronize(gpu)
print(f" Token {pi} L{li}: OK |X|={X.abs().max().item():.1f}", flush=True)
print(f" Chunk {ci} L{li}: OK |X|={X.abs().max().item():.1f}", flush=True)
X = X.to('cuda:0'); torch.cuda.set_device(0)
if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True)
print(f" Prefill done ({time.time()-t0:.1f}s)")
print(f" Chunk {ci+1}/{len(chunk_starts)} tokens {cs}-{ce-1} ({chunk_len} tok): {time.time()-t1:.2f}s", flush=True)
print(f" Batched prefill done ({time.time()-t0:.1f}s)")
if _args.prefill_only: print("Prefill-only mode, stopping."); return