diff --git a/FINAL_STRETCH.md b/FINAL_STRETCH.md index 602eb6ff..7a01ea07 100644 --- a/FINAL_STRETCH.md +++ b/FINAL_STRETCH.md @@ -91,6 +91,6 @@ Native Blackwell FP8 GEMM via tcgen05 for CSA Lightning Indexer scoring. No PyTo # PART D — Dangling TODOS - - Batched Prefill: ✅ DONE (T=1..128, mixed FP8/BF16 kernel) - - Need to wire prefill into single_shot_inference.py (replace T=1 token-by-token prefill) - - Need T>128 support (split into multiple launches) + - Batched Prefill: ✅ DONE (T=1..128, mixed FP8/BF16 kernel, chunked for T>128) + - Prefill wired into single_shot_inference.py: ✅ DONE (chunked batched prefill replaces T=1 token-by-token) + - T>128 support: ✅ DONE (splits into multiple launches of ≤128 tokens each) diff --git a/single_shot_inference.py b/single_shot_inference.py index ef86d48b..19b0d442 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -894,9 +894,10 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, # 6. Production FMHA — B1 mixed FP8/BF16 decode path. _pt('fmha_start') if li == 0: - print(f" L0 B1 verify: kv_nope_fp8 dtype={kv_nope_fp8.dtype} shape={tuple(kv_nope_fp8.shape)} " - f"kv_nope_scale dtype={kv_nope_scale.dtype} shape={tuple(kv_nope_scale.shape)} " - f"kv_rope_bf16 dtype={kv_rope_bf16.dtype} shape={tuple(kv_rope_bf16.shape)}", flush=True) + if VERBOSE >= 2: + print(f" L0 B1 verify: kv_nope_fp8 dtype={kv_nope_fp8.dtype} shape={tuple(kv_nope_fp8.shape)} " + f"kv_nope_scale dtype={kv_nope_scale.dtype} shape={tuple(kv_nope_scale.shape)} " + f"kv_rope_bf16 dtype={kv_rope_bf16.dtype} shape={tuple(kv_rope_bf16.shape)}", flush=True) assert kv_nope_fp8.dtype in (torch.uint8, torch.float8_e4m3fn), f"kv_nope_fp8 wrong dtype: {kv_nope_fp8.dtype}" assert kv_nope_scale.dtype == torch.float32, f"kv_nope_scale wrong dtype: {kv_nope_scale.dtype}" assert kv_rope_bf16.dtype == torch.bfloat16, f"kv_rope_bf16 wrong dtype: {kv_rope_bf16.dtype}"