From f86742ef8ea681aa681f8f415582210d20ee2be1 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 10:28:25 +0000 Subject: [PATCH] =?UTF-8?q?Cache=20layer=20weights=20on=20GPU=20=E2=80=94?= =?UTF-8?q?=20eliminates=20per-token=20CPU=E2=86=92GPU=20transfer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, each prefill/decode token re-transferred ALL layer weights from CPU to GPU (66 tokens × 61 layers = 4026 transfers). This made prefill ~36s/token and caused the test to appear stuck. Now: one-time cache_all_layer_weights() loads all 61 layers to their target GPUs. Prefill should be ~1-2s/token instead of ~36s. Also added flush=True to print statements so progress is visible. --- single_shot_inference.py | 51 ++++++++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 66bcfc12..5a68ebac 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -316,6 +316,24 @@ def get_layer_weights(all_weights, li, device): return w +def cache_all_layer_weights(all_weights, n_layers, devices): + """Pre-load ALL layer weights to their target GPUs. + + This avoids the per-token CPU→GPU transfer bottleneck. Each layer's + weights stay on its target GPU for the entire inference run. + """ + print(f" Caching layer weights to GPUs...") + cached = {} + for li in range(n_layers): + gpu = li % len(devices) + dev = devices[gpu] + cached[li] = get_layer_weights(all_weights, li, dev) + if (li + 1) % 10 == 0: + print(f" {li+1}/{n_layers} layers cached") + print(f" All {n_layers} layers cached to GPUs") + return cached + + # ===================================================================== # Single layer forward # ===================================================================== @@ -733,6 +751,14 @@ def main(): for li in range(n_layers): kv_caches[li] = SimpleKVCache(head_dim=hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}") + # ==== Cache ALL layer weights to GPUs (avoids per-token CPU→GPU transfer) ==== + print(f"\n Caching layer weights to GPUs (one-time transfer)...", flush=True) + devices = [f"cuda:{g}" for g in range(NUM_GPUS)] + layer_weights = cache_all_layer_weights(all_weights, n_layers, devices) + print(f" Done. Freeing CPU weights...", flush=True) + del all_weights + import gc; gc.collect() + # ==== Phase 2: Compile FMHA ==== print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}") from dsv4.kernels.attention.production import dsv4_attention @@ -751,7 +777,7 @@ def main(): print(f"\n{'='*70}\nPhase 2.5: Minimal E2E Test (single token 'The')\n{'='*70}") from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) - minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks, + minimal_e2e_test(layer_weights, cfg, rope_caches, attn_mhc_blocks, ffn_mhc_blocks, attn_norms, ffn_norms, embed, lm_w, final_norm_w, tokenizer) @@ -784,7 +810,7 @@ def main(): generated = input_ids[0].tolist() # ==== Prefill: process prompt tokens to fill KV cache ==== - print(f"Prefilling {len(generated)} prompt tokens...") + print(f"Prefilling {len(generated)} prompt tokens...", flush=True) for prefill_idx, tid_val in enumerate(generated): t0 = time.time() tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0') @@ -799,8 +825,7 @@ def main(): X = X.to(dev) torch.cuda.set_device(gpu) - # Fetch this layer's weights from CPU → GPU (streamed, not all at once) - w = get_layer_weights(all_weights, li, dev) + w = layer_weights[li] attn_mhc = attn_mhc_blocks.get(li) ffn_mhc = ffn_mhc_blocks.get(li) @@ -810,13 +835,11 @@ def main(): X = forward_layer(X, w, li, cfg, rc, rs, attn_mhc, ffn_mhc, a_norm, f_norm, kv_caches[li], tid, positions) - # Free per-layer GPU weights to save memory - del w X = X.to('cuda:0') torch.cuda.set_device(0) - if prefill_idx == 0: - print(f" Token 0: {time.time()-t0:.1f}s (includes per-layer weight transfer)") + if prefill_idx % 10 == 0: + print(f" Token {prefill_idx}/{len(generated)}: {time.time()-t0:.2f}s", flush=True) print(f" Prefill done ({len(generated)} tokens, {time.time()-t_compiled:.1f}s)") @@ -840,7 +863,7 @@ def main(): X = X.to(dev) torch.cuda.set_device(gpu) - w = get_layer_weights(all_weights, li, dev) + w = layer_weights[li] attn_mhc = attn_mhc_blocks.get(li) ffn_mhc = ffn_mhc_blocks.get(li) @@ -850,7 +873,6 @@ def main(): X = forward_layer(X, w, li, cfg, rc, rs, attn_mhc, ffn_mhc, a_norm, f_norm, kv_caches[li], tid, positions) - del w X = X.to('cuda:0') torch.cuda.set_device(0) @@ -878,7 +900,7 @@ def main(): x_max = X.abs().max().item() print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) " f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} " - f"|X|={x_max:.3f} top5: {top5_str}") + f"|X|={x_max:.3f} top5: {top5_str}", flush=True) if has_nan or has_inf: print(" Numerical issue — stopping") @@ -899,7 +921,7 @@ def main(): # Minimal end-to-end test — single token "The" through the model # ===================================================================== -def minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks, +def minimal_e2e_test(layer_weights, cfg, rope_caches, attn_mhc_blocks, ffn_mhc_blocks, attn_norms, ffn_norms, embed, lm_w, final_norm_w, tokenizer): """Process a single token 'The' through the model and check output logits. @@ -942,7 +964,8 @@ def minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks, X = X.to(dev) torch.cuda.set_device(gpu) - w = get_layer_weights(all_weights, li, dev) + w = layer_weights[li] + attn_mhc = attn_mhc_blocks.get(li) ffn_mhc = ffn_mhc_blocks.get(li) a_norm = attn_norms[li] @@ -968,8 +991,6 @@ def minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks, 'nan': has_nan, 'inf': has_inf }) - del w - if has_nan or has_inf: print(f" ❌ Layer {li}: NaN={has_nan} Inf={has_inf} — STOPPING") break