Cache layer weights on GPU — eliminates per-token CPU→GPU transfer
Previously, every prefill/decode token re-transferred ALL layer weights from CPU to GPU (66 tokens × 61 layers = 4026 transfers). This made prefill take ~36s/token and made the test appear stuck. Now a one-time call to cache_all_layer_weights() loads all 61 layers onto their target GPUs, so prefill should take ~1-2s/token instead of ~36s. Also added flush=True to print statements so progress is visible.
This commit is contained in:
@@ -316,6 +316,24 @@ def get_layer_weights(all_weights, li, device):
|
||||
return w
|
||||
|
||||
|
||||
def cache_all_layer_weights(all_weights, n_layers, devices):
    """Pre-load ALL layer weights to their target GPUs.

    This avoids the per-token CPU→GPU transfer bottleneck. Each layer's
    weights stay on its target GPU for the entire inference run.

    Layers are placed round-robin: layer ``li`` is fetched for
    ``devices[li % len(devices)]``.

    Args:
        all_weights: Weight container, passed through unchanged to
            ``get_layer_weights``.
        n_layers: Number of layers to cache (indices ``0 .. n_layers - 1``).
        devices: Sequence of target devices; must be non-empty whenever
            there is at least one layer to place.

    Returns:
        Dict mapping each layer index to the value ``get_layer_weights``
        returned for that layer on its assigned device.

    Raises:
        ValueError: If ``devices`` is empty while ``n_layers`` is positive.
    """
    n_devices = len(devices)
    if n_layers > 0 and n_devices == 0:
        # Fail with a clear message instead of the ZeroDivisionError the
        # modulo below would otherwise produce.
        raise ValueError(
            "cache_all_layer_weights requires at least one device"
        )

    # Flush every progress line: caching can take a long time, and buffered
    # output makes the run look stuck.
    print(" Caching layer weights to GPUs...", flush=True)
    cached = {}
    for li in range(n_layers):
        dev = devices[li % n_devices]
        cached[li] = get_layer_weights(all_weights, li, dev)
        if (li + 1) % 10 == 0:
            print(f" {li+1}/{n_layers} layers cached", flush=True)
    print(f" All {n_layers} layers cached to GPUs", flush=True)
    return cached
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Single layer forward
|
||||
# =====================================================================
|
||||
@@ -733,6 +751,14 @@ def main():
|
||||
for li in range(n_layers):
|
||||
kv_caches[li] = SimpleKVCache(head_dim=hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}")
|
||||
|
||||
# ==== Cache ALL layer weights to GPUs (avoids per-token CPU→GPU transfer) ====
|
||||
print(f"\n Caching layer weights to GPUs (one-time transfer)...", flush=True)
|
||||
devices = [f"cuda:{g}" for g in range(NUM_GPUS)]
|
||||
layer_weights = cache_all_layer_weights(all_weights, n_layers, devices)
|
||||
print(f" Done. Freeing CPU weights...", flush=True)
|
||||
del all_weights
|
||||
import gc; gc.collect()
|
||||
|
||||
# ==== Phase 2: Compile FMHA ====
|
||||
print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
@@ -751,7 +777,7 @@ def main():
|
||||
print(f"\n{'='*70}\nPhase 2.5: Minimal E2E Test (single token 'The')\n{'='*70}")
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks,
|
||||
minimal_e2e_test(layer_weights, cfg, rope_caches, attn_mhc_blocks,
|
||||
ffn_mhc_blocks, attn_norms, ffn_norms, embed, lm_w,
|
||||
final_norm_w, tokenizer)
|
||||
|
||||
@@ -784,7 +810,7 @@ def main():
|
||||
generated = input_ids[0].tolist()
|
||||
|
||||
# ==== Prefill: process prompt tokens to fill KV cache ====
|
||||
print(f"Prefilling {len(generated)} prompt tokens...")
|
||||
print(f"Prefilling {len(generated)} prompt tokens...", flush=True)
|
||||
for prefill_idx, tid_val in enumerate(generated):
|
||||
t0 = time.time()
|
||||
tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
|
||||
@@ -799,8 +825,7 @@ def main():
|
||||
X = X.to(dev)
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
# Fetch this layer's weights from CPU → GPU (streamed, not all at once)
|
||||
w = get_layer_weights(all_weights, li, dev)
|
||||
w = layer_weights[li]
|
||||
|
||||
attn_mhc = attn_mhc_blocks.get(li)
|
||||
ffn_mhc = ffn_mhc_blocks.get(li)
|
||||
@@ -810,13 +835,11 @@ def main():
|
||||
X = forward_layer(X, w, li, cfg, rc, rs,
|
||||
attn_mhc, ffn_mhc, a_norm, f_norm,
|
||||
kv_caches[li], tid, positions)
|
||||
# Free per-layer GPU weights to save memory
|
||||
del w
|
||||
|
||||
X = X.to('cuda:0')
|
||||
torch.cuda.set_device(0)
|
||||
if prefill_idx == 0:
|
||||
print(f" Token 0: {time.time()-t0:.1f}s (includes per-layer weight transfer)")
|
||||
if prefill_idx % 10 == 0:
|
||||
print(f" Token {prefill_idx}/{len(generated)}: {time.time()-t0:.2f}s", flush=True)
|
||||
|
||||
print(f" Prefill done ({len(generated)} tokens, {time.time()-t_compiled:.1f}s)")
|
||||
|
||||
@@ -840,7 +863,7 @@ def main():
|
||||
X = X.to(dev)
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
w = get_layer_weights(all_weights, li, dev)
|
||||
w = layer_weights[li]
|
||||
|
||||
attn_mhc = attn_mhc_blocks.get(li)
|
||||
ffn_mhc = ffn_mhc_blocks.get(li)
|
||||
@@ -850,7 +873,6 @@ def main():
|
||||
X = forward_layer(X, w, li, cfg, rc, rs,
|
||||
attn_mhc, ffn_mhc, a_norm, f_norm,
|
||||
kv_caches[li], tid, positions)
|
||||
del w
|
||||
|
||||
X = X.to('cuda:0')
|
||||
torch.cuda.set_device(0)
|
||||
@@ -878,7 +900,7 @@ def main():
|
||||
x_max = X.abs().max().item()
|
||||
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) "
|
||||
f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} "
|
||||
f"|X|={x_max:.3f} top5: {top5_str}")
|
||||
f"|X|={x_max:.3f} top5: {top5_str}", flush=True)
|
||||
|
||||
if has_nan or has_inf:
|
||||
print(" Numerical issue — stopping")
|
||||
@@ -899,7 +921,7 @@ def main():
|
||||
# Minimal end-to-end test — single token "The" through the model
|
||||
# =====================================================================
|
||||
|
||||
def minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks,
|
||||
def minimal_e2e_test(layer_weights, cfg, rope_caches, attn_mhc_blocks,
|
||||
ffn_mhc_blocks, attn_norms, ffn_norms, embed, lm_w,
|
||||
final_norm_w, tokenizer):
|
||||
"""Process a single token 'The' through the model and check output logits.
|
||||
@@ -942,7 +964,8 @@ def minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks,
|
||||
X = X.to(dev)
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
w = get_layer_weights(all_weights, li, dev)
|
||||
w = layer_weights[li]
|
||||
|
||||
attn_mhc = attn_mhc_blocks.get(li)
|
||||
ffn_mhc = ffn_mhc_blocks.get(li)
|
||||
a_norm = attn_norms[li]
|
||||
@@ -968,8 +991,6 @@ def minimal_e2e_test(all_weights, cfg, rope_caches, attn_mhc_blocks,
|
||||
'nan': has_nan, 'inf': has_inf
|
||||
})
|
||||
|
||||
del w
|
||||
|
||||
if has_nan or has_inf:
|
||||
print(f" ❌ Layer {li}: NaN={has_nan} Inf={has_inf} — STOPPING")
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user