single_shot: stream weights per-layer from CPU, fix KV/RoPE logic

This commit is contained in:
2026-05-31 02:53:40 +00:00
parent 61160ace13
commit 0c3d168c60

View File

@@ -343,12 +343,14 @@ class SimpleKVCache:
# Weight loading — streams safetensors shards, distributes to 8 GPUs
# =====================================================================
def load_all_weights(checkpoint_dir, num_layers):
"""Load all weights from checkpoint, distribute layers across GPUs.
def load_weights_to_cpu(checkpoint_dir):
"""Load all weights from checkpoint to CPU memory.
Weights stay on CPU; we move per-layer to GPU on demand during inference.
This avoids OOM from 285K GPU allocations and allows streaming.
Returns:
layer_weights: dict[li] → {key: tensor on cuda:li%8}
global_weights: {key: tensor on cuda:0}
all_weights: dict[key] → tensor on CPU
"""
from safetensors.torch import load_file
cdir = Path(checkpoint_dir)
@@ -360,7 +362,7 @@ def load_all_weights(checkpoint_dir, num_layers):
shard_names = set(weight_map.values()) if weight_map else {
f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
}
print(f"Loading {len(shard_names)} shards...")
print(f"Loading {len(shard_names)} shards to CPU...")
all_weights = {}
loaded = 0
for shard_name in sorted(shard_names):
@@ -371,31 +373,21 @@ def load_all_weights(checkpoint_dir, num_layers):
loaded += 1
if loaded % 20 == 0:
print(f" {loaded}/{len(shard_names)} shards, {len(all_weights)} tensors")
print(f" Done: {len(all_weights)} tensors")
print(f" Done: {len(all_weights)} tensors on CPU")
return all_weights
layer_weights = {}
global_weights = {}
print("Assigning to GPUs...")
for key, tensor in all_weights.items():
if key.startswith("model.layers."):
li = int(key.split(".")[2])
target_gpu = li % NUM_GPUS
target_device = f"cuda:{target_gpu}"
if li not in layer_weights:
layer_weights[li] = {"_device": target_device, "_gpu": target_gpu}
layer_weights[li][key] = tensor.to(target_device)
elif key.startswith("model.embed_tokens"):
global_weights[key] = tensor.to("cuda:0")
elif key.startswith("model.norm"):
global_weights[key] = tensor.to("cuda:0")
elif key.startswith("lm_head"):
global_weights[key] = tensor.to("cuda:0")
for gpu in range(NUM_GPUS):
alloc = torch.cuda.memory_allocated(gpu) / 1e9
if alloc > 0:
print(f" GPU {gpu}: {alloc:.1f}GB")
return layer_weights, global_weights
def get_layer_weights(all_weights, li, device):
    """Collect the tensors of a single layer and transfer them to ``device``.

    An entry of ``all_weights`` is selected when its key begins with
    ``model.layers.{li}.``; the trailing dot in the prefix keeps layer 1
    from also matching layers 10, 11, and so on.

    Args:
        all_weights: mapping of checkpoint key -> tensor.
        li: layer index used to build the key prefix.
        device: target device handed to ``Tensor.to``.

    Returns:
        A new dict of key -> tensor on ``device``, in the iteration order
        of ``all_weights``. Empty when no key matches the layer prefix.
    """
    layer_prefix = f"model.layers.{li}."
    return {
        name: tensor.to(device=device, non_blocking=True)
        for name, tensor in all_weights.items()
        if name.startswith(layer_prefix)
    }
# =====================================================================
@@ -665,14 +657,14 @@ def main():
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")
# ==== Phase 1: Load weights ====
print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}")
layer_weights, global_weights = load_all_weights(CHECKPOINT_DIR, n_layers)
# ==== Phase 1: Load weights to CPU ====
print(f"\n{'='*70}\nPhase 1: Loading weights to CPU\n{'='*70}")
all_weights = load_weights_to_cpu(CHECKPOINT_DIR)
t_loaded = time.time()
print(f"Weight loading: {t_loaded - t_start:.1f}s")
# ==== Build mHC blocks (proper Sinkhorn) ====
print("Building mHC blocks...")
# ==== Build mHC blocks + RMSNorms (small weights, keep on GPU) ====
print("Building mHC blocks and RMSNorms...")
attn_mhc_blocks = {}
ffn_mhc_blocks = {}
attn_norms = {}
@@ -681,18 +673,18 @@ def main():
gpu = li % NUM_GPUS
dev = f"cuda:{gpu}"
# mHC blocks
# mHC blocks (small weights: fn (24, 28672) FP32 ≈ 2.6MB each)
for prefix, blocks in [(f"model.layers.{li}.attn_hc", attn_mhc_blocks),
(f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)]:
fn = layer_weights[li].get(f"{prefix}.fn")
base = layer_weights[li].get(f"{prefix}.base")
scale = layer_weights[li].get(f"{prefix}.scale")
if fn is not None and base is not None and scale is not None:
fn_key = f"{prefix}.fn"
base_key = f"{prefix}.base"
scale_key = f"{prefix}.scale"
if fn_key in all_weights and base_key in all_weights and scale_key in all_weights:
mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
mhc.load_from_checkpoint(fn, base, scale)
mhc.load_from_checkpoint(
all_weights[fn_key], all_weights[base_key], all_weights[scale_key])
blocks[li] = mhc
else:
# Fallback: identity mHC (A=1, B=I, C=1) — not ideal but prevents crash
print(f" WARNING: no mHC weights for {prefix}, using identity fallback")
mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
n = n_hc
@@ -706,27 +698,29 @@ def main():
mhc.alpha_post = 0.01
blocks[li] = mhc
# RMSNorms (pre-norm before each sub-block)
# RMSNorms
attn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev)
an_key = f"model.layers.{li}.input_layernorm.weight"
if an_key in layer_weights[li]:
attn_norm.weight = layer_weights[li][an_key].to(device=dev, dtype=torch.float32)
if an_key in all_weights:
attn_norm.weight = all_weights[an_key].to(device=dev, dtype=torch.float32)
attn_norms[li] = attn_norm
ffn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev)
fn_key = f"model.layers.{li}.post_attention_layernorm.weight"
if fn_key in layer_weights[li]:
ffn_norm.weight = layer_weights[li][fn_key].to(device=dev, dtype=torch.float32)
if fn_key in all_weights:
ffn_norm.weight = all_weights[fn_key].to(device=dev, dtype=torch.float32)
ffn_norms[li] = ffn_norm
print(f" attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}")
# ==== Global weights ====
# ==== Global weights (small, keep on gpu0) ====
torch.cuda.set_device(0)
embed_w = global_weights.get("model.embed_tokens.weight")
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16())
lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16()
final_norm_w = global_weights.get("model.norm.weight")
embed_w = all_weights.get("model.embed_tokens.weight")
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
lm_w = all_weights.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
final_norm_w = all_weights.get("model.norm.weight")
if final_norm_w is not None:
final_norm_w = final_norm_w.to('cuda:0')
rope_caches = {g: build_rope_cache(8192, rd, f"cuda:{g}") for g in range(NUM_GPUS)}
# ==== KV caches (one per layer on its GPU) ====
@@ -760,6 +754,7 @@ def main():
# ==== Prefill: process prompt tokens to fill KV cache ====
print(f"Prefilling {len(generated)} prompt tokens...")
for prefill_idx, tid_val in enumerate(generated):
t0 = time.time()
tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
positions = torch.tensor([prefill_idx], dtype=torch.long, device='cuda:0')
emb = embed(tid) # (1, H) on gpu0
@@ -767,22 +762,29 @@ def main():
for li in range(n_layers):
gpu = li % NUM_GPUS
target_device = f"cuda:{gpu}"
if X.device != torch.device(target_device):
X = X.to(target_device)
dev = f"cuda:{gpu}"
if X.device != torch.device(dev):
X = X.to(dev)
torch.cuda.set_device(gpu)
# Fetch this layer's weights from CPU → GPU (streamed, not all at once)
w = get_layer_weights(all_weights, li, dev)
attn_mhc = attn_mhc_blocks.get(li)
ffn_mhc = ffn_mhc_blocks.get(li)
a_norm = attn_norms[li]
f_norm = ffn_norms[li]
rc, rs = rope_caches[gpu]
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
X = forward_layer(X, w, li, cfg, rc, rs,
attn_mhc, ffn_mhc, a_norm, f_norm,
kv_caches[li], tid, positions)
# Free per-layer GPU weights to save memory
del w
X = X.to('cuda:0')
torch.cuda.set_device(0)
if prefill_idx == 0:
print(f" Token 0: {time.time()-t0:.1f}s (includes per-layer weight transfer)")
print(f" Prefill done ({len(generated)} tokens, {time.time()-t_compiled:.1f}s)")
@@ -801,19 +803,22 @@ def main():
for li in range(n_layers):
gpu = li % NUM_GPUS
target_device = f"cuda:{gpu}"
if X.device != torch.device(target_device):
X = X.to(target_device)
dev = f"cuda:{gpu}"
if X.device != torch.device(dev):
X = X.to(dev)
torch.cuda.set_device(gpu)
w = get_layer_weights(all_weights, li, dev)
attn_mhc = attn_mhc_blocks.get(li)
ffn_mhc = ffn_mhc_blocks.get(li)
a_norm = attn_norms[li]
f_norm = ffn_norms[li]
rc, rs = rope_caches[gpu]
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
X = forward_layer(X, w, li, cfg, rc, rs,
attn_mhc, ffn_mhc, a_norm, f_norm,
kv_caches[li], tid, positions)
del w
X = X.to('cuda:0')
torch.cuda.set_device(0)