Cache layer weights on GPU — eliminates per-token CPU→GPU transfer

Previously, each prefill/decode token re-transferred ALL layer weights
from CPU to GPU (66 tokens × 61 layers = 4026 transfers). This made
prefill ~36s/token and caused the test to appear stuck.

Now: one-time cache_all_layer_weights() loads all 61 layers to their
target GPUs. Prefill should be ~1-2s/token instead of ~36s.

Also added flush=True to print statements so progress is visible.
This commit is contained in:
2026-05-31 10:28:25 +00:00
parent ce3d6069cc
commit f86742ef8e

View File

@@ -316,6 +316,24 @@ def get_layer_weights(all_weights, li, device):
return w
def cache_all_layer_weights(all_weights, n_layers, devices):
"""Pre-load ALL layer weights to their target GPUs.
This avoids the per-token CPU→GPU transfer bottleneck. Each layer's
weights stay on its target GPU for the entire inference run.
"""
print(f" Caching layer weights to GPUs...")
cached = {}
for li in range(n_layers):
gpu = li % len(devices)
dev = devices[gpu]
cached[li] = get_layer_weights(all_weights, li, dev)
if (li + 1) % 10 == 0:
print(f" {li+1}/{n_layers} layers cached")
print(f" All {n_layers} layers cached to GPUs")
return cached
# =====================================================================
# Single layer forward
# =====================================================================
@@ -733,6 +751,14 @@ def main():
for li in range(n_layers):
kv_caches[li] = SimpleKVCache(head_dim=hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}")
# ==== Cache ALL layer weights to GPUs (avoids per-token CPU→GPU transfer) ====
print(f"\n Caching layer weights to GPUs (one-time transfer)...", flush=True)
devices = [f"cuda:{g}" for g in range(NUM_GPUS)]
layer_weights = cache_all_layer_weights(all_weights, n_layers, devices)
print(f" Done. Freeing CPU weights...", flush=True)
del all_weights
import gc; gc.collect()
# ==== Phase 2: Compile FMHA ====
print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
from dsv4.kernels.attention.production import dsv4_attention
@@ -751,7 +777,7 @@ def main():
print(f"\n{'='*70}\nPhase 2.5: Minimal E2E Test (single token 'The')\n{'='*70}")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks,
minimal_e2e_test(layer_weights, cfg, rope_caches, attn_mhc_blocks,
ffn_mhc_blocks, attn_norms, ffn_norms, embed, lm_w,
final_norm_w, tokenizer)
@@ -784,7 +810,7 @@ def main():
generated = input_ids[0].tolist()
# ==== Prefill: process prompt tokens to fill KV cache ====
print(f"Prefilling {len(generated)} prompt tokens...")
print(f"Prefilling {len(generated)} prompt tokens...", flush=True)
for prefill_idx, tid_val in enumerate(generated):
t0 = time.time()
tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
@@ -799,8 +825,7 @@ def main():
X = X.to(dev)
torch.cuda.set_device(gpu)
# Fetch this layer's weights from CPU → GPU (streamed, not all at once)
w = get_layer_weights(all_weights, li, dev)
w = layer_weights[li]
attn_mhc = attn_mhc_blocks.get(li)
ffn_mhc = ffn_mhc_blocks.get(li)
@@ -810,13 +835,11 @@ def main():
X = forward_layer(X, w, li, cfg, rc, rs,
attn_mhc, ffn_mhc, a_norm, f_norm,
kv_caches[li], tid, positions)
# Free per-layer GPU weights to save memory
del w
X = X.to('cuda:0')
torch.cuda.set_device(0)
if prefill_idx == 0:
print(f" Token 0: {time.time()-t0:.1f}s (includes per-layer weight transfer)")
if prefill_idx % 10 == 0:
print(f" Token {prefill_idx}/{len(generated)}: {time.time()-t0:.2f}s", flush=True)
print(f" Prefill done ({len(generated)} tokens, {time.time()-t_compiled:.1f}s)")
@@ -840,7 +863,7 @@ def main():
X = X.to(dev)
torch.cuda.set_device(gpu)
w = get_layer_weights(all_weights, li, dev)
w = layer_weights[li]
attn_mhc = attn_mhc_blocks.get(li)
ffn_mhc = ffn_mhc_blocks.get(li)
@@ -850,7 +873,6 @@ def main():
X = forward_layer(X, w, li, cfg, rc, rs,
attn_mhc, ffn_mhc, a_norm, f_norm,
kv_caches[li], tid, positions)
del w
X = X.to('cuda:0')
torch.cuda.set_device(0)
@@ -878,7 +900,7 @@ def main():
x_max = X.abs().max().item()
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) "
f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} "
f"|X|={x_max:.3f} top5: {top5_str}")
f"|X|={x_max:.3f} top5: {top5_str}", flush=True)
if has_nan or has_inf:
print(" Numerical issue — stopping")
@@ -899,7 +921,7 @@ def main():
# Minimal end-to-end test — single token "The" through the model
# =====================================================================
def minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks,
def minimal_e2e_test(layer_weights, cfg, rope_caches, attn_mhc_blocks,
ffn_mhc_blocks, attn_norms, ffn_norms, embed, lm_w,
final_norm_w, tokenizer):
"""Process a single token 'The' through the model and check output logits.
@@ -942,7 +964,8 @@ def minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks,
X = X.to(dev)
torch.cuda.set_device(gpu)
w = get_layer_weights(all_weights, li, dev)
w = layer_weights[li]
attn_mhc = attn_mhc_blocks.get(li)
ffn_mhc = ffn_mhc_blocks.get(li)
a_norm = attn_norms[li]
@@ -968,8 +991,6 @@ def minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks,
'nan': has_nan, 'inf': has_inf
})
del w
if has_nan or has_inf:
print(f" ❌ Layer {li}: NaN={has_nan} Inf={has_inf} — STOPPING")
break