single_shot: stream weights per-layer from CPU, fix KV/RoPE logic
This commit is contained in:
@@ -343,12 +343,14 @@ class SimpleKVCache:
|
||||
# Weight loading — streams safetensors shards, distributes to 8 GPUs
|
||||
# =====================================================================
|
||||
|
||||
def load_all_weights(checkpoint_dir, num_layers):
|
||||
"""Load all weights from checkpoint, distribute layers across GPUs.
|
||||
def load_weights_to_cpu(checkpoint_dir):
|
||||
"""Load all weights from checkpoint to CPU memory.
|
||||
|
||||
Weights stay on CPU; we move per-layer to GPU on demand during inference.
|
||||
This avoids OOM from 285K GPU allocations and allows streaming.
|
||||
|
||||
Returns:
|
||||
layer_weights: dict[li] → {key: tensor on cuda:li%8}
|
||||
global_weights: {key: tensor on cuda:0}
|
||||
all_weights: dict[key] → tensor on CPU
|
||||
"""
|
||||
from safetensors.torch import load_file
|
||||
cdir = Path(checkpoint_dir)
|
||||
@@ -360,7 +362,7 @@ def load_all_weights(checkpoint_dir, num_layers):
|
||||
shard_names = set(weight_map.values()) if weight_map else {
|
||||
f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
|
||||
}
|
||||
print(f"Loading {len(shard_names)} shards...")
|
||||
print(f"Loading {len(shard_names)} shards to CPU...")
|
||||
all_weights = {}
|
||||
loaded = 0
|
||||
for shard_name in sorted(shard_names):
|
||||
@@ -371,31 +373,21 @@ def load_all_weights(checkpoint_dir, num_layers):
|
||||
loaded += 1
|
||||
if loaded % 20 == 0:
|
||||
print(f" {loaded}/{len(shard_names)} shards, {len(all_weights)} tensors")
|
||||
print(f" Done: {len(all_weights)} tensors")
|
||||
print(f" Done: {len(all_weights)} tensors on CPU")
|
||||
return all_weights
|
||||
|
||||
layer_weights = {}
|
||||
global_weights = {}
|
||||
print("Assigning to GPUs...")
|
||||
for key, tensor in all_weights.items():
|
||||
if key.startswith("model.layers."):
|
||||
li = int(key.split(".")[2])
|
||||
target_gpu = li % NUM_GPUS
|
||||
target_device = f"cuda:{target_gpu}"
|
||||
if li not in layer_weights:
|
||||
layer_weights[li] = {"_device": target_device, "_gpu": target_gpu}
|
||||
layer_weights[li][key] = tensor.to(target_device)
|
||||
elif key.startswith("model.embed_tokens"):
|
||||
global_weights[key] = tensor.to("cuda:0")
|
||||
elif key.startswith("model.norm"):
|
||||
global_weights[key] = tensor.to("cuda:0")
|
||||
elif key.startswith("lm_head"):
|
||||
global_weights[key] = tensor.to("cuda:0")
|
||||
|
||||
for gpu in range(NUM_GPUS):
|
||||
alloc = torch.cuda.memory_allocated(gpu) / 1e9
|
||||
if alloc > 0:
|
||||
print(f" GPU {gpu}: {alloc:.1f}GB")
|
||||
return layer_weights, global_weights
|
||||
def get_layer_weights(all_weights, li, device):
    """Collect the weights of layer ``li`` and copy them to ``device``.

    Args:
        all_weights: Mapping of checkpoint key to tensor. Only entries whose
            key starts with ``model.layers.{li}.`` are touched.
        li: Index of the layer whose weights are requested.
        device: Target device, passed straight through to ``Tensor.to``.

    Returns:
        A new dict of key -> tensor on ``device`` holding only this layer's
        entries. ``all_weights`` itself is not modified.
    """
    # The trailing dot keeps layer 1 from also matching layers 10, 11, ...
    prefix = f"model.layers.{li}."
    # NOTE(review): non_blocking=True only overlaps the host-to-device copy
    # with other work when the source tensor lives in pinned memory; whether
    # the checkpoint tensors are pinned is not visible here, so confirm
    # before relying on asynchronous transfer.
    return {
        key: tensor.to(device=device, non_blocking=True)
        for key, tensor in all_weights.items()
        if key.startswith(prefix)
    }
|
||||
|
||||
|
||||
# =====================================================================
|
||||
@@ -665,14 +657,14 @@ def main():
|
||||
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
|
||||
print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")
|
||||
|
||||
# ==== Phase 1: Load weights ====
|
||||
print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}")
|
||||
layer_weights, global_weights = load_all_weights(CHECKPOINT_DIR, n_layers)
|
||||
# ==== Phase 1: Load weights to CPU ====
|
||||
print(f"\n{'='*70}\nPhase 1: Loading weights to CPU\n{'='*70}")
|
||||
all_weights = load_weights_to_cpu(CHECKPOINT_DIR)
|
||||
t_loaded = time.time()
|
||||
print(f"Weight loading: {t_loaded - t_start:.1f}s")
|
||||
|
||||
# ==== Build mHC blocks (proper Sinkhorn) ====
|
||||
print("Building mHC blocks...")
|
||||
# ==== Build mHC blocks + RMSNorms (small weights, keep on GPU) ====
|
||||
print("Building mHC blocks and RMSNorms...")
|
||||
attn_mhc_blocks = {}
|
||||
ffn_mhc_blocks = {}
|
||||
attn_norms = {}
|
||||
@@ -681,18 +673,18 @@ def main():
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
|
||||
# mHC blocks
|
||||
# mHC blocks (small weights: fn (24, 28672) FP32 ≈ 2.6MB each)
|
||||
for prefix, blocks in [(f"model.layers.{li}.attn_hc", attn_mhc_blocks),
|
||||
(f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)]:
|
||||
fn = layer_weights[li].get(f"{prefix}.fn")
|
||||
base = layer_weights[li].get(f"{prefix}.base")
|
||||
scale = layer_weights[li].get(f"{prefix}.scale")
|
||||
if fn is not None and base is not None and scale is not None:
|
||||
fn_key = f"{prefix}.fn"
|
||||
base_key = f"{prefix}.base"
|
||||
scale_key = f"{prefix}.scale"
|
||||
if fn_key in all_weights and base_key in all_weights and scale_key in all_weights:
|
||||
mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
|
||||
mhc.load_from_checkpoint(fn, base, scale)
|
||||
mhc.load_from_checkpoint(
|
||||
all_weights[fn_key], all_weights[base_key], all_weights[scale_key])
|
||||
blocks[li] = mhc
|
||||
else:
|
||||
# Fallback: identity mHC (A=1, B=I, C=1) — not ideal but prevents crash
|
||||
print(f" WARNING: no mHC weights for {prefix}, using identity fallback")
|
||||
mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
|
||||
n = n_hc
|
||||
@@ -706,27 +698,29 @@ def main():
|
||||
mhc.alpha_post = 0.01
|
||||
blocks[li] = mhc
|
||||
|
||||
# RMSNorms (pre-norm before each sub-block)
|
||||
# RMSNorms
|
||||
attn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev)
|
||||
an_key = f"model.layers.{li}.input_layernorm.weight"
|
||||
if an_key in layer_weights[li]:
|
||||
attn_norm.weight = layer_weights[li][an_key].to(device=dev, dtype=torch.float32)
|
||||
if an_key in all_weights:
|
||||
attn_norm.weight = all_weights[an_key].to(device=dev, dtype=torch.float32)
|
||||
attn_norms[li] = attn_norm
|
||||
|
||||
ffn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev)
|
||||
fn_key = f"model.layers.{li}.post_attention_layernorm.weight"
|
||||
if fn_key in layer_weights[li]:
|
||||
ffn_norm.weight = layer_weights[li][fn_key].to(device=dev, dtype=torch.float32)
|
||||
if fn_key in all_weights:
|
||||
ffn_norm.weight = all_weights[fn_key].to(device=dev, dtype=torch.float32)
|
||||
ffn_norms[li] = ffn_norm
|
||||
|
||||
print(f" attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}")
|
||||
|
||||
# ==== Global weights ====
|
||||
# ==== Global weights (small, keep on gpu0) ====
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = global_weights.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16())
|
||||
lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16()
|
||||
final_norm_w = global_weights.get("model.norm.weight")
|
||||
embed_w = all_weights.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
|
||||
lm_w = all_weights.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
|
||||
final_norm_w = all_weights.get("model.norm.weight")
|
||||
if final_norm_w is not None:
|
||||
final_norm_w = final_norm_w.to('cuda:0')
|
||||
rope_caches = {g: build_rope_cache(8192, rd, f"cuda:{g}") for g in range(NUM_GPUS)}
|
||||
|
||||
# ==== KV caches (one per layer on its GPU) ====
|
||||
@@ -760,6 +754,7 @@ def main():
|
||||
# ==== Prefill: process prompt tokens to fill KV cache ====
|
||||
print(f"Prefilling {len(generated)} prompt tokens...")
|
||||
for prefill_idx, tid_val in enumerate(generated):
|
||||
t0 = time.time()
|
||||
tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
|
||||
positions = torch.tensor([prefill_idx], dtype=torch.long, device='cuda:0')
|
||||
emb = embed(tid) # (1, H) on gpu0
|
||||
@@ -767,22 +762,29 @@ def main():
|
||||
|
||||
for li in range(n_layers):
|
||||
gpu = li % NUM_GPUS
|
||||
target_device = f"cuda:{gpu}"
|
||||
if X.device != torch.device(target_device):
|
||||
X = X.to(target_device)
|
||||
dev = f"cuda:{gpu}"
|
||||
if X.device != torch.device(dev):
|
||||
X = X.to(dev)
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
# Fetch this layer's weights from CPU → GPU (streamed, not all at once)
|
||||
w = get_layer_weights(all_weights, li, dev)
|
||||
|
||||
attn_mhc = attn_mhc_blocks.get(li)
|
||||
ffn_mhc = ffn_mhc_blocks.get(li)
|
||||
a_norm = attn_norms[li]
|
||||
f_norm = ffn_norms[li]
|
||||
rc, rs = rope_caches[gpu]
|
||||
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
|
||||
X = forward_layer(X, w, li, cfg, rc, rs,
|
||||
attn_mhc, ffn_mhc, a_norm, f_norm,
|
||||
kv_caches[li], tid, positions)
|
||||
# Free per-layer GPU weights to save memory
|
||||
del w
|
||||
|
||||
X = X.to('cuda:0')
|
||||
torch.cuda.set_device(0)
|
||||
if prefill_idx == 0:
|
||||
print(f" Token 0: {time.time()-t0:.1f}s (includes per-layer weight transfer)")
|
||||
|
||||
print(f" Prefill done ({len(generated)} tokens, {time.time()-t_compiled:.1f}s)")
|
||||
|
||||
@@ -801,19 +803,22 @@ def main():
|
||||
|
||||
for li in range(n_layers):
|
||||
gpu = li % NUM_GPUS
|
||||
target_device = f"cuda:{gpu}"
|
||||
if X.device != torch.device(target_device):
|
||||
X = X.to(target_device)
|
||||
dev = f"cuda:{gpu}"
|
||||
if X.device != torch.device(dev):
|
||||
X = X.to(dev)
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
w = get_layer_weights(all_weights, li, dev)
|
||||
|
||||
attn_mhc = attn_mhc_blocks.get(li)
|
||||
ffn_mhc = ffn_mhc_blocks.get(li)
|
||||
a_norm = attn_norms[li]
|
||||
f_norm = ffn_norms[li]
|
||||
rc, rs = rope_caches[gpu]
|
||||
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
|
||||
X = forward_layer(X, w, li, cfg, rc, rs,
|
||||
attn_mhc, ffn_mhc, a_norm, f_norm,
|
||||
kv_caches[li], tid, positions)
|
||||
del w
|
||||
|
||||
X = X.to('cuda:0')
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
Reference in New Issue
Block a user