From 0c3d168c60e59cf36fa45d5751da8b830454fa16 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 02:53:40 +0000 Subject: [PATCH] single_shot: stream weights per-layer from CPU, fix KV/RoPE logic --- single_shot_inference.py | 121 ++++++++++++++++++++------------------- 1 file changed, 63 insertions(+), 58 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 2b7faba7..18e3e945 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -343,12 +343,14 @@ class SimpleKVCache: # Weight loading — streams safetensors shards, distributes to 8 GPUs # ===================================================================== -def load_all_weights(checkpoint_dir, num_layers): - """Load all weights from checkpoint, distribute layers across GPUs. +def load_weights_to_cpu(checkpoint_dir): + """Load all weights from checkpoint to CPU memory. + + Weights stay on CPU; we move per-layer to GPU on demand during inference. + This avoids OOM from 285K GPU allocations and allows streaming. Returns: - layer_weights: dict[li] → {key: tensor on cuda:li%8} - global_weights: {key: tensor on cuda:0} + all_weights: dict[key] → tensor on CPU """ from safetensors.torch import load_file cdir = Path(checkpoint_dir) @@ -360,7 +362,7 @@ def load_all_weights(checkpoint_dir, num_layers): shard_names = set(weight_map.values()) if weight_map else { f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96) } - print(f"Loading {len(shard_names)} shards...") + print(f"Loading {len(shard_names)} shards to CPU...") all_weights = {} loaded = 0 for shard_name in sorted(shard_names): @@ -371,31 +373,21 @@ def load_all_weights(checkpoint_dir, num_layers): loaded += 1 if loaded % 20 == 0: print(f" {loaded}/{len(shard_names)} shards, {len(all_weights)} tensors") - print(f" Done: {len(all_weights)} tensors") + print(f" Done: {len(all_weights)} tensors on CPU") + return all_weights - layer_weights = {} - global_weights = {} - print("Assigning to GPUs...") - for key, tensor in all_weights.items(): - if key.startswith("model.layers."): - li = int(key.split(".")[2]) - target_gpu = li % NUM_GPUS - target_device = f"cuda:{target_gpu}" - if li not in layer_weights: - layer_weights[li] = {"_device": target_device, "_gpu": target_gpu} - layer_weights[li][key] = tensor.to(target_device) - elif key.startswith("model.embed_tokens"): - global_weights[key] = tensor.to("cuda:0") - elif key.startswith("model.norm"): - global_weights[key] = tensor.to("cuda:0") - elif key.startswith("lm_head"): - global_weights[key] = tensor.to("cuda:0") - for gpu in range(NUM_GPUS): - alloc = torch.cuda.memory_allocated(gpu) / 1e9 - if alloc > 0: - print(f" GPU {gpu}: {alloc:.1f}GB") - return layer_weights, global_weights +def get_layer_weights(all_weights, li, device): + """Get weights for layer li, moved to target device. + + Returns dict of key→tensor on device. Filters by model.layers.{li} prefix. + """ + prefix = f"model.layers.{li}." + w = {} + for key in all_weights: + if key.startswith(prefix): + w[key] = all_weights[key].to(device=device, non_blocking=True) + return w # ===================================================================== @@ -665,14 +657,14 @@ def main(): print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}") print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}") - # ==== Phase 1: Load weights ==== - print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}") - layer_weights, global_weights = load_all_weights(CHECKPOINT_DIR, n_layers) + # ==== Phase 1: Load weights to CPU ==== + print(f"\n{'='*70}\nPhase 1: Loading weights to CPU\n{'='*70}") + all_weights = load_weights_to_cpu(CHECKPOINT_DIR) t_loaded = time.time() print(f"Weight loading: {t_loaded - t_start:.1f}s") - # ==== Build mHC blocks (proper Sinkhorn) ==== - print("Building mHC blocks...") + # ==== Build mHC blocks + RMSNorms (small weights, keep on GPU) ==== + print("Building mHC blocks and RMSNorms...") attn_mhc_blocks = {} ffn_mhc_blocks = {} attn_norms = {} @@ -681,18 +673,18 @@ def main(): gpu = li % NUM_GPUS dev = f"cuda:{gpu}" - # mHC blocks + # mHC blocks (small weights: fn (24, 28672) FP32 ≈ 2.6MB each) for prefix, blocks in [(f"model.layers.{li}.attn_hc", attn_mhc_blocks), (f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)]: - fn = layer_weights[li].get(f"{prefix}.fn") - base = layer_weights[li].get(f"{prefix}.base") - scale = layer_weights[li].get(f"{prefix}.scale") - if fn is not None and base is not None and scale is not None: + fn_key = f"{prefix}.fn" + base_key = f"{prefix}.base" + scale_key = f"{prefix}.scale" + if fn_key in all_weights and base_key in all_weights and scale_key in all_weights: mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev) - mhc.load_from_checkpoint(fn, base, scale) + mhc.load_from_checkpoint( + all_weights[fn_key], all_weights[base_key], all_weights[scale_key]) blocks[li] = mhc else: - # Fallback: identity mHC (A=1, B=I, C=1) — not ideal but prevents crash print(f" WARNING: no mHC weights for {prefix}, using identity fallback") mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev) n = n_hc @@ -706,27 +698,29 @@ def main(): mhc.alpha_post = 0.01 blocks[li] = mhc - # RMSNorms (pre-norm before each sub-block) + # RMSNorms attn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev) an_key = f"model.layers.{li}.input_layernorm.weight" - if an_key in layer_weights[li]: - attn_norm.weight = layer_weights[li][an_key].to(device=dev, dtype=torch.float32) + if an_key in all_weights: + attn_norm.weight = all_weights[an_key].to(device=dev, dtype=torch.float32) attn_norms[li] = attn_norm ffn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev) fn_key = f"model.layers.{li}.post_attention_layernorm.weight" - if fn_key in layer_weights[li]: - ffn_norm.weight = layer_weights[li][fn_key].to(device=dev, dtype=torch.float32) + if fn_key in all_weights: + ffn_norm.weight = all_weights[fn_key].to(device=dev, dtype=torch.float32) ffn_norms[li] = ffn_norm print(f" attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}") - # ==== Global weights ==== + # ==== Global weights (small, keep on gpu0) ==== torch.cuda.set_device(0) - embed_w = global_weights.get("model.embed_tokens.weight") - embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16()) - lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16() - final_norm_w = global_weights.get("model.norm.weight") + embed_w = all_weights.get("model.embed_tokens.weight") + embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0')) + lm_w = all_weights.get("lm_head.weight", embed_w).bfloat16().to('cuda:0') + final_norm_w = all_weights.get("model.norm.weight") + if final_norm_w is not None: + final_norm_w = final_norm_w.to('cuda:0') rope_caches = {g: build_rope_cache(8192, rd, f"cuda:{g}") for g in range(NUM_GPUS)} # ==== KV caches (one per layer on its GPU) ==== @@ -760,6 +754,7 @@ def main(): # ==== Prefill: process prompt tokens to fill KV cache ==== print(f"Prefilling {len(generated)} prompt tokens...") for prefill_idx, tid_val in enumerate(generated): + t0 = time.time() tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0') positions = torch.tensor([prefill_idx], dtype=torch.long, device='cuda:0') emb = embed(tid) # (1, H) on gpu0 @@ -767,22 +762,29 @@ def main(): for li in range(n_layers): gpu = li % NUM_GPUS - target_device = f"cuda:{gpu}" - if X.device != torch.device(target_device): - X = X.to(target_device) + dev = f"cuda:{gpu}" + if X.device != torch.device(dev): + X = X.to(dev) torch.cuda.set_device(gpu) + # Fetch this layer's weights from CPU → GPU (streamed, not all at once) + w = get_layer_weights(all_weights, li, dev) + attn_mhc = attn_mhc_blocks.get(li) ffn_mhc = ffn_mhc_blocks.get(li) a_norm = attn_norms[li] f_norm = ffn_norms[li] rc, rs = rope_caches[gpu] - X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, + X = forward_layer(X, w, li, cfg, rc, rs, attn_mhc, ffn_mhc, a_norm, f_norm, kv_caches[li], tid, positions) + # Free per-layer GPU weights to save memory + del w X = X.to('cuda:0') torch.cuda.set_device(0) + if prefill_idx == 0: + print(f" Token 0: {time.time()-t0:.1f}s (includes per-layer weight transfer)") print(f" Prefill done ({len(generated)} tokens, {time.time()-t_compiled:.1f}s)") @@ -801,19 +803,22 @@ def main(): for li in range(n_layers): gpu = li % NUM_GPUS - target_device = f"cuda:{gpu}" - if X.device != torch.device(target_device): - X = X.to(target_device) + dev = f"cuda:{gpu}" + if X.device != torch.device(dev): + X = X.to(dev) torch.cuda.set_device(gpu) + w = get_layer_weights(all_weights, li, dev) + attn_mhc = attn_mhc_blocks.get(li) ffn_mhc = ffn_mhc_blocks.get(li) a_norm = attn_norms[li] f_norm = ffn_norms[li] rc, rs = rope_caches[gpu] - X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, + X = forward_layer(X, w, li, cfg, rc, rs, attn_mhc, ffn_mhc, a_norm, f_norm, kv_caches[li], tid, positions) + del w X = X.to('cuda:0') torch.cuda.set_device(0)