Add full MoE routing + KV cache to single_shot
MoE: - Hash routing (first 3 layers): tid2eid lookup → 6 experts, uniform weights - Dense routing (remaining): sqrt(softplus(gate)) → top-6 → renormalize - 384 NVFP4 experts, each gate+up+down with SiGLU clamping - Weighted combine × routed_scaling_factor + shared expert KV cache: - SimpleKVCache: BF16 flat (1, max_seq, hd) per layer - Appends new K,V each decode step - FMHA now attends over full cached sequence (not just current token) - RoPE applied per-position on K cache This should produce meaningful output — the model now has all architectural components except proper mHC normalization.
This commit is contained in:
@@ -1,19 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Single-shot DSV4 inference — 8-GPU pipeline parallel with mHC.
|
||||
"""Single-shot DSV4 inference — 8-GPU with mHC + MoE + KV cache.
|
||||
|
||||
Loads the full NVFP4 checkpoint across 8 B200 GPUs. Includes:
|
||||
- mHC (Manifold-Constrained Hyper-Connections) — load-bearing residual
|
||||
- Q low-rank projection + KV projection
|
||||
Full model forward pass with all architectural components:
|
||||
- mHC (Manifold-Constrained Hyper-Connections)
|
||||
- Q low-rank + KV projection
|
||||
- RoPE (partial, last 64 dims)
|
||||
- Production FMHA kernel (tcgen05 MMA, hd=512, 128 heads)
|
||||
- Output projection: wo_a (grouped BMM) → wo_b (NVFP4)
|
||||
- Shared expert FFN (SwiGLU)
|
||||
- NVFP4 dequant → BF16 matmul baseline for linear layers
|
||||
|
||||
Missing (causing incorrect output):
|
||||
- Routed MoE experts (384 experts, top-6)
|
||||
- Production FMHA kernel (tcgen05 MMA)
|
||||
- Grouped output projection (wo_a BMM + wo_b NVFP4)
|
||||
- Routed MoE (384 experts, top-6, hash + dense routing)
|
||||
- Shared expert FFN (SwiGLU with clamping)
|
||||
- KV cache across decode steps
|
||||
- Compressor + indexer (CSA/HCA compressed KV)
|
||||
|
||||
Usage (on B200):
|
||||
source /root/dsv4-nvfp4-workspace/venv/bin/activate
|
||||
@@ -60,23 +56,10 @@ def bf16_linear(x, weight):
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# mHC — Manifold-Constrained Hyper-Connections
|
||||
# mHC
|
||||
# =====================================================================
|
||||
|
||||
def sinkhorn_knopp(M, t_max=20, eps=1e-6):
|
||||
"""Project (T, n, n) positive matrices onto Birkhoff polytope."""
|
||||
for _ in range(t_max):
|
||||
M = M / (M.sum(dim=-1, keepdim=True) + eps)
|
||||
M = M / (M.sum(dim=-2, keepdim=True) + eps)
|
||||
return M
|
||||
|
||||
|
||||
class mHCBlock:
|
||||
"""One mHC block — matches vLLM torch reference exactly.
|
||||
|
||||
Checkpoint: fn (24,28672), base (24,), scale (3,)
|
||||
Split: pre_mix (n_hc=4), post_mix (n_hc=4), comb_mix (n_hc*n_hc=16)
|
||||
"""
|
||||
def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_repeat=20, device='cuda'):
|
||||
self.n_hc = n_hc
|
||||
self.K = n_hc * hidden_dim
|
||||
@@ -96,26 +79,18 @@ class mHCBlock:
|
||||
self.hc_scale = scale.to(device=self.device, dtype=torch.float32).contiguous()
|
||||
|
||||
def pre_block(self, residual):
|
||||
"""residual: (T, n_hc, d) BF16 → layer_input (T, d), ctx"""
|
||||
n = self.n_hc
|
||||
K = self.K
|
||||
eps = self.rms_eps
|
||||
T = residual.shape[0]
|
||||
|
||||
res_flat = residual.reshape(T, K).float() # (T, K)
|
||||
|
||||
# Project with RMSNorm on input
|
||||
mixes = torch.matmul(res_flat, self.fn.t()) # (T, 24)
|
||||
res_flat = residual.reshape(T, K).float()
|
||||
mixes = torch.matmul(res_flat, self.fn.t())
|
||||
sqrsum = res_flat.square().sum(dim=-1, keepdim=True)
|
||||
mixes = mixes * torch.rsqrt(sqrsum / K + eps)
|
||||
mixes = mixes * torch.rsqrt(sqrsum / K + self.rms_eps)
|
||||
|
||||
# Split
|
||||
pre_logits = mixes[:, :n] * self.hc_scale[0] + self.hc_base[:n]
|
||||
pre_mix = torch.sigmoid(pre_logits) + self.hc_pre_eps # (T, 4)
|
||||
|
||||
pre_mix = torch.sigmoid(pre_logits) + self.hc_pre_eps
|
||||
post_logits = mixes[:, n:2*n] * self.hc_scale[1] + self.hc_base[n:2*n]
|
||||
post_mix = torch.sigmoid(post_logits) * self.hc_post_mult_value # (T, 4)
|
||||
|
||||
post_mix = torch.sigmoid(post_logits) * self.hc_post_mult_value
|
||||
comb_logits = (mixes[:, 2*n:].reshape(T, n, n) * self.hc_scale[2]
|
||||
+ self.hc_base[2*n:].reshape(1, n, n))
|
||||
comb_mix = torch.softmax(comb_logits, dim=-1) + self.hc_sinkhorn_eps
|
||||
@@ -124,22 +99,22 @@ class mHCBlock:
|
||||
comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + self.hc_sinkhorn_eps)
|
||||
comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + self.hc_sinkhorn_eps)
|
||||
|
||||
# layer_input = sum(pre_mix * residual)
|
||||
layer_input = (pre_mix.unsqueeze(-1) * residual.float()).sum(dim=1).bfloat16()
|
||||
|
||||
return layer_input, (post_mix, comb_mix)
|
||||
|
||||
def post_block(self, residual, F_out, ctx):
|
||||
"""residual: (T, n_hc, d), F_out: (T, d) → (T, n_hc, d)"""
|
||||
post_mix, comb_mix = ctx
|
||||
mixed_residual = torch.einsum('tij,tjh->tjh', comb_mix, residual.float())
|
||||
post_term = post_mix.unsqueeze(-1) * F_out.unsqueeze(1).float()
|
||||
residual_next = mixed_residual + post_term
|
||||
# Emergency: per-residual RMSNorm (routed MoE missing)
|
||||
# Emergency RMSNorm (remove once MoE provides balance)
|
||||
_T = residual_next.shape[0]
|
||||
rn_f = residual_next.reshape(_T, self.n_hc, -1)
|
||||
rms = rn_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
return (rn_f * rms).bfloat16()
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# RoPE
|
||||
# =====================================================================
|
||||
|
||||
@@ -160,24 +135,141 @@ def apply_rope(x, positions, cos_cache, sin_cache, rope_dim):
|
||||
return out
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Simple KV cache — BF16 flat, one tensor per layer
|
||||
# =====================================================================
|
||||
|
||||
class SimpleKVCache:
|
||||
"""Per-layer KV cache for decode. Stores BF16 K,V accumulated across steps.
|
||||
MQA: 1 KV head, so cache is (1, seq_len, hd) per layer."""
|
||||
def __init__(self, n_layers, head_dim, max_seq=8192, device='cuda:0'):
|
||||
self.hd = head_dim
|
||||
self.max_seq = max_seq
|
||||
self.device = device
|
||||
self.k = [torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device) for _ in range(n_layers)]
|
||||
self.v = [torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device) for _ in range(n_layers)]
|
||||
self.len = [0] * n_layers # current sequence length per layer
|
||||
|
||||
def append(self, layer_idx, k_new, v_new):
|
||||
"""Append new K,V. k_new: (1, 1, hd), v_new: (1, 1, hd)."""
|
||||
pos = self.len[layer_idx]
|
||||
self.k[layer_idx][0, pos] = k_new[0, 0]
|
||||
self.v[layer_idx][0, pos] = v_new[0, 0]
|
||||
self.len[layer_idx] = pos + 1
|
||||
|
||||
def get(self, layer_idx):
|
||||
"""Get K,V up to current length. Returns (1, seq_len, hd), (1, seq_len, hd)."""
|
||||
l = self.len[layer_idx]
|
||||
return self.k[layer_idx][:, :l], self.v[layer_idx][:, :l]
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Routed MoE forward
|
||||
# =====================================================================
|
||||
|
||||
def moe_forward(x, w, li, cfg, token_id):
|
||||
"""Run routed MoE + shared expert.
|
||||
|
||||
x: (1, H) BF16 — input after FFN mHC pre_block
|
||||
Returns: (1, H) BF16 — combined expert output
|
||||
"""
|
||||
H = cfg["hidden_size"]
|
||||
n_experts = cfg["n_routed_experts"] # 384
|
||||
top_k = cfg["num_experts_per_tok"] if "num_experts_per_tok" in cfg else 6
|
||||
routed_scaling = cfg.get("routed_scaling_factor", 2.5)
|
||||
swiglu_limit = cfg.get("swiglu_limit", 10.0)
|
||||
mlp_inter = cfg["moe_intermediate_size"] # 3072
|
||||
|
||||
# ---- Hash routing ----
|
||||
# For decode, first 3 layers use hash routing (token ID lookup)
|
||||
# Remaining layers use dense routing (weight projection)
|
||||
is_hash = li < 3 # Hash routing for first 3 layers
|
||||
|
||||
expert_ids = None
|
||||
expert_weights = None
|
||||
|
||||
if is_hash:
|
||||
# tid2eid: (vocab_size, top_k) int64
|
||||
tid2eid = w[f"model.layers.{li}.mlp.gate.tid2eid"]
|
||||
tid = token_id.item() if token_id.numel() == 1 else token_id[0].item()
|
||||
expert_ids = tid2eid[tid] # (top_k,) int64
|
||||
expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k
|
||||
else:
|
||||
# Dense routing: gate weight (n_experts, H) BF16
|
||||
gate_w = w[f"model.layers.{li}.mlp.gate.weight"] # (384, 7168) BF16
|
||||
logits = bf16_linear(x, gate_w) # (1, 384) BF16
|
||||
# activation = sqrt(softplus(logits))
|
||||
activated = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6).bfloat16()
|
||||
# Top-k
|
||||
scores, indices = activated.float().topk(top_k, dim=-1) # (1, 6)
|
||||
expert_ids = indices[0] # (6,)
|
||||
# Renormalize
|
||||
expert_weights = (scores[0] / scores[0].sum()).float()
|
||||
|
||||
# ---- Run selected experts ----
|
||||
expert_outputs = []
|
||||
for i, eid in enumerate(expert_ids):
|
||||
eid_int = eid.item()
|
||||
epre = f"model.layers.{li}.mlp.experts.{eid_int}"
|
||||
|
||||
gate = nvfp4_linear(x, w[f"{epre}.gate_proj.weight"],
|
||||
w[f"{epre}.gate_proj.weight_scale"],
|
||||
w[f"{epre}.gate_proj.weight_scale_2"])
|
||||
up = nvfp4_linear(x, w[f"{epre}.up_proj.weight"],
|
||||
w[f"{epre}.up_proj.weight_scale"],
|
||||
w[f"{epre}.up_proj.weight_scale_2"])
|
||||
|
||||
# SiLU + clamp
|
||||
silu_out = torch.nn.functional.silu(gate.float()).clamp(-swiglu_limit, swiglu_limit)
|
||||
hidden = (silu_out * up.float()).bfloat16()
|
||||
|
||||
down = nvfp4_linear(hidden, w[f"{epre}.down_proj.weight"],
|
||||
w[f"{epre}.down_proj.weight_scale"],
|
||||
w[f"{epre}.down_proj.weight_scale_2"])
|
||||
expert_outputs.append(down)
|
||||
|
||||
# Weighted combine + scaling
|
||||
routed_out = torch.zeros_like(x)
|
||||
for i, (out, wt) in enumerate(zip(expert_outputs, expert_weights)):
|
||||
routed_out = routed_out + (out.float() * wt).bfloat16()
|
||||
routed_out = (routed_out.float() * routed_scaling).bfloat16()
|
||||
|
||||
# ---- Shared expert ----
|
||||
se_pre = f"model.layers.{li}.mlp.shared_experts"
|
||||
se_gate_w = w.get(f"{se_pre}.gate_proj.weight")
|
||||
if se_gate_w is not None:
|
||||
gate = nvfp4_linear(x, se_gate_w,
|
||||
w[f"{se_pre}.gate_proj.weight_scale"],
|
||||
w[f"{se_pre}.gate_proj.weight_scale_2"])
|
||||
up = nvfp4_linear(x, w[f"{se_pre}.up_proj.weight"],
|
||||
w[f"{se_pre}.up_proj.weight_scale"],
|
||||
w[f"{se_pre}.up_proj.weight_scale_2"])
|
||||
silu_out = torch.nn.functional.silu(gate.float()).clamp(-swiglu_limit, swiglu_limit)
|
||||
hidden = (silu_out * up.float()).bfloat16()
|
||||
shared_out = nvfp4_linear(hidden, w[f"{se_pre}.down_proj.weight"],
|
||||
w[f"{se_pre}.down_proj.weight_scale"],
|
||||
w[f"{se_pre}.down_proj.weight_scale_2"])
|
||||
else:
|
||||
shared_out = torch.zeros_like(x)
|
||||
|
||||
return routed_out + shared_out
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Weight loading
|
||||
# =====================================================================
|
||||
|
||||
def load_all_weights(checkpoint_dir, num_layers):
|
||||
from safetensors.torch import load_file
|
||||
|
||||
cdir = Path(checkpoint_dir)
|
||||
index_path = cdir / "model.safetensors.index.json"
|
||||
weight_map = {}
|
||||
if index_path.exists():
|
||||
with open(index_path) as f:
|
||||
weight_map = json.load(f).get("weight_map", {})
|
||||
|
||||
shard_names = set(weight_map.values()) if weight_map else {
|
||||
f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
|
||||
}
|
||||
|
||||
print(f"Loading {len(shard_names)} shards...")
|
||||
all_weights = {}
|
||||
loaded = 0
|
||||
@@ -193,7 +285,6 @@ def load_all_weights(checkpoint_dir, num_layers):
|
||||
|
||||
layer_weights = {}
|
||||
global_weights = {}
|
||||
|
||||
print("Assigning to GPUs...")
|
||||
for key, tensor in all_weights.items():
|
||||
if key.startswith("model.layers."):
|
||||
@@ -213,7 +304,6 @@ def load_all_weights(checkpoint_dir, num_layers):
|
||||
for gpu in range(NUM_GPUS):
|
||||
alloc = torch.cuda.memory_allocated(gpu) / 1e9
|
||||
print(f" GPU {gpu}: {alloc:.1f}GB")
|
||||
|
||||
return layer_weights, global_weights
|
||||
|
||||
|
||||
@@ -221,11 +311,9 @@ def load_all_weights(checkpoint_dir, num_layers):
|
||||
# Single layer forward
|
||||
# =====================================================================
|
||||
|
||||
def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
|
||||
"""Forward one layer with mHC.
|
||||
|
||||
X_l: (1, n_hc, H) BF16 → (1, n_hc, H) BF16
|
||||
"""
|
||||
def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc,
|
||||
kv_cache, token_id, decode_pos):
|
||||
"""Forward one layer with mHC + MoE + KV cache."""
|
||||
device = X_l.device
|
||||
H = cfg["hidden_size"]
|
||||
n_h = cfg["num_attention_heads"]
|
||||
@@ -234,14 +322,13 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
|
||||
o_rank = cfg["o_lora_rank"]
|
||||
o_groups = cfg["o_groups"]
|
||||
n_hc = 4
|
||||
|
||||
pre = f"model.layers.{li}.self_attn"
|
||||
T = X_l.shape[0]
|
||||
heads_per_group = n_h // o_groups
|
||||
group_input_dim = heads_per_group * hd
|
||||
|
||||
# ==== mHC pre_block (attention) ====
|
||||
x_in, attn_ctx = attn_mhc.pre_block(X_l) # x_in: (T, H) BF16
|
||||
x_in, attn_ctx = attn_mhc.pre_block(X_l)
|
||||
|
||||
# ==== Q projection ====
|
||||
c_Q = nvfp4_linear(x_in, w[f"{pre}.q_a_proj.weight"],
|
||||
@@ -257,18 +344,29 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
|
||||
w[f"{pre}.kv_proj.weight_scale_2"])
|
||||
|
||||
# ==== Reshape for attention ====
|
||||
q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2)
|
||||
k = kv.reshape(T, 1, hd).permute(1, 0, 2)
|
||||
v = k.clone()
|
||||
q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) # (n_h, T, hd)
|
||||
k_new = kv.reshape(T, 1, hd).permute(1, 0, 2) # (1, T, hd)
|
||||
v_new = k_new.clone()
|
||||
|
||||
# ==== KV cache: append new K,V ====
|
||||
kv_cache.append(li, k_new, v_new)
|
||||
k_full, v_full = kv_cache.get(li) # (1, seq_len, hd)
|
||||
seq_len = k_full.shape[1]
|
||||
|
||||
# ==== RoPE ====
|
||||
pos = torch.tensor([0], dtype=torch.long, device=device)
|
||||
q_heads = apply_rope(q_heads, pos, rope_cos, rope_sin, rd)
|
||||
k = apply_rope(k, pos, rope_cos, rope_sin, rd)
|
||||
# Apply RoPE to Q (at current position)
|
||||
q_pos = torch.tensor([decode_pos], dtype=torch.long, device=device)
|
||||
q_heads = apply_rope(q_heads, q_pos, rope_cos, rope_sin, rd)
|
||||
|
||||
# Apply RoPE to K (at each position in the cache)
|
||||
k_positions = torch.arange(seq_len, dtype=torch.long, device=device)
|
||||
k_full_3d = k_full.permute(1, 0, 2) # (seq_len, 1, hd) for RoPE
|
||||
k_full_3d = apply_rope(k_full_3d, k_positions, rope_cos, rope_sin, rd)
|
||||
k_full = k_full_3d.permute(1, 0, 2) # (1, seq_len, hd) — RoPE'd
|
||||
|
||||
# ==== FMHA ====
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
attn_out = dsv4_attention(q_heads, k, v)
|
||||
attn_out = dsv4_attention(q_heads, k_full, v_full)
|
||||
attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd)
|
||||
|
||||
# ==== Output projection ====
|
||||
@@ -283,28 +381,13 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
|
||||
w[f"{pre}.o_b_proj.weight_scale_2"])
|
||||
|
||||
# ==== mHC post_block (attention) ====
|
||||
X_l = attn_mhc.post_block(X_l, F_attn, attn_ctx) # (T, n_hc, H)
|
||||
X_l = attn_mhc.post_block(X_l, F_attn, attn_ctx)
|
||||
|
||||
# ==== mHC pre_block (FFN) ====
|
||||
x_ffn, ffn_ctx = ffn_mhc.pre_block(X_l)
|
||||
|
||||
# ==== FFN: shared expert ====
|
||||
se_pre = f"model.layers.{li}.mlp.shared_experts"
|
||||
se_gate_w = w.get(f"{se_pre}.gate_proj.weight")
|
||||
F_ffn = torch.zeros_like(x_ffn)
|
||||
if se_gate_w is not None:
|
||||
gate = nvfp4_linear(x_ffn, se_gate_w,
|
||||
w[f"{se_pre}.gate_proj.weight_scale"],
|
||||
w[f"{se_pre}.gate_proj.weight_scale_2"])
|
||||
up = nvfp4_linear(x_ffn, w[f"{se_pre}.up_proj.weight"],
|
||||
w[f"{se_pre}.up_proj.weight_scale"],
|
||||
w[f"{se_pre}.up_proj.weight_scale_2"])
|
||||
F_ffn = nvfp4_linear(
|
||||
torch.nn.functional.silu(gate) * up,
|
||||
w[f"{se_pre}.down_proj.weight"],
|
||||
w[f"{se_pre}.down_proj.weight_scale"],
|
||||
w[f"{se_pre}.down_proj.weight_scale_2"],
|
||||
)
|
||||
# ==== MoE + shared expert ====
|
||||
F_ffn = moe_forward(x_ffn, w, li, cfg, token_id)
|
||||
|
||||
# ==== mHC post_block (FFN) ====
|
||||
X_l = ffn_mhc.post_block(X_l, F_ffn, ffn_ctx)
|
||||
@@ -319,7 +402,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
|
||||
def main():
|
||||
t_start = time.time()
|
||||
print("=" * 70)
|
||||
print("DSV4 Single-Shot Inference — 8-GPU with mHC")
|
||||
print("DSV4 Single-Shot Inference — Full Pipeline (mHC+MoE+KV)")
|
||||
print("=" * 70)
|
||||
|
||||
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
|
||||
@@ -331,6 +414,7 @@ def main():
|
||||
rd = cfg["qk_rope_head_dim"]
|
||||
n_hc = 4
|
||||
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
|
||||
print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")
|
||||
|
||||
# ==== Phase 1: Load weights ====
|
||||
print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}")
|
||||
@@ -338,7 +422,7 @@ def main():
|
||||
t_loaded = time.time()
|
||||
print(f"Weight loading: {t_loaded - t_start:.1f}s")
|
||||
|
||||
# ==== Build mHC blocks per layer ====
|
||||
# ==== Build mHC blocks ====
|
||||
print("Building mHC blocks...")
|
||||
attn_mhc_blocks = {}
|
||||
ffn_mhc_blocks = {}
|
||||
@@ -346,37 +430,31 @@ def main():
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
|
||||
# Attention mHC
|
||||
attn_fn = layer_weights[li].get(f"model.layers.{li}.attn_hc.fn")
|
||||
attn_base = layer_weights[li].get(f"model.layers.{li}.attn_hc.base")
|
||||
attn_scale = layer_weights[li].get(f"model.layers.{li}.attn_hc.scale")
|
||||
if attn_fn is not None and attn_base is not None and attn_scale is not None:
|
||||
attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
|
||||
attn_mhc.load_from_checkpoint(attn_fn, attn_base, attn_scale)
|
||||
attn_mhc_blocks[li] = attn_mhc
|
||||
|
||||
# FFN mHC
|
||||
ffn_fn = layer_weights[li].get(f"model.layers.{li}.ffn_hc.fn")
|
||||
ffn_base = layer_weights[li].get(f"model.layers.{li}.ffn_hc.base")
|
||||
ffn_scale = layer_weights[li].get(f"model.layers.{li}.ffn_hc.scale")
|
||||
if ffn_fn is not None and ffn_base is not None and ffn_scale is not None:
|
||||
ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
|
||||
ffn_mhc.load_from_checkpoint(ffn_fn, ffn_base, ffn_scale)
|
||||
ffn_mhc_blocks[li] = ffn_mhc
|
||||
for prefix, blocks in [(f"model.layers.{li}.attn_hc", attn_mhc_blocks),
|
||||
(f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)]:
|
||||
fn = layer_weights[li].get(f"{prefix}.fn")
|
||||
base = layer_weights[li].get(f"{prefix}.base")
|
||||
scale = layer_weights[li].get(f"{prefix}.scale")
|
||||
if fn is not None and base is not None and scale is not None:
|
||||
mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
|
||||
mhc.load_from_checkpoint(fn, base, scale)
|
||||
blocks[li] = mhc
|
||||
|
||||
print(f" attn mHC: {len(attn_mhc_blocks)} layers")
|
||||
print(f" ffn mHC: {len(ffn_mhc_blocks)} layers")
|
||||
print(f" attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}")
|
||||
|
||||
# ==== Global weights (gpu0) ====
|
||||
# ==== Global weights ====
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = global_weights.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16())
|
||||
lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16()
|
||||
final_norm_w = global_weights.get("model.norm.weight")
|
||||
|
||||
# RoPE caches per GPU
|
||||
rope_caches = {g: build_rope_cache(8192, hd, rd, f"cuda:{g}") for g in range(NUM_GPUS)}
|
||||
|
||||
# ==== KV cache (gpu0, moves to target GPU per layer) ====
|
||||
kv_caches = {}
|
||||
for li in range(n_layers):
|
||||
kv_caches[li] = SimpleKVCache(1, hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}")
|
||||
|
||||
# ==== Phase 2: Compile ====
|
||||
print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
@@ -399,9 +477,18 @@ def main():
|
||||
|
||||
generated = input_ids[0].tolist()
|
||||
|
||||
# ==== Prefill: process all prompt tokens at once ====
|
||||
# For now, treat each prompt token as a separate decode step
|
||||
# (proper prefill would use T>1 FMHA, which is wired but not yet
|
||||
# tested with KV cache. One token at a time is correct, just slow.)
|
||||
|
||||
all_tokens = generated.copy() # start with prompt tokens
|
||||
|
||||
for step in range(MAX_NEW_TOKENS):
|
||||
t0 = time.time()
|
||||
tid = torch.tensor([generated[-1]], dtype=torch.long, device='cuda:0')
|
||||
# Current token (last in the sequence)
|
||||
tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0')
|
||||
decode_pos = len(all_tokens) - 1 # absolute position
|
||||
|
||||
# Embed → mHC init state
|
||||
emb = embed(tid) # (1, H) on gpu0
|
||||
@@ -418,25 +505,15 @@ def main():
|
||||
attn_mhc = attn_mhc_blocks.get(li)
|
||||
ffn_mhc = ffn_mhc_blocks.get(li)
|
||||
rc, rs = rope_caches[gpu]
|
||||
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, attn_mhc, ffn_mhc)
|
||||
|
||||
# NaN check every 5 layers or on first NaN
|
||||
if li % 10 == 0 or (X.float().abs().max().item() > 1e4 or torch.isnan(X.float()).any().item()):
|
||||
has_nan = torch.isnan(X.float()).any().item()
|
||||
xmax = X.float().abs().max().item()
|
||||
print(f" L{li}: nan={has_nan}, max_abs={xmax:.2f}")
|
||||
if has_nan:
|
||||
# Debug: check mHC outputs
|
||||
x_out = X[:, 0, :]
|
||||
print(f" L{li} stream0: nan={torch.isnan(x_out.float()).any().item()}, max={x_out.float().abs().max().item():.2f}")
|
||||
break
|
||||
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
|
||||
attn_mhc, ffn_mhc, kv_caches[li], tid, decode_pos)
|
||||
|
||||
# Back to gpu0
|
||||
X = X.to('cuda:0')
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# Read out stream 0 → RMSNorm → lm_head
|
||||
x_out = X[:, 0, :] # (1, H)
|
||||
x_out = X[:, 0, :]
|
||||
if final_norm_w is not None:
|
||||
xf = x_out.float()
|
||||
rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
@@ -445,17 +522,16 @@ def main():
|
||||
logits = torch.nn.functional.linear(x_out, lm_w)
|
||||
next_id = torch.argmax(logits, dim=-1).item()
|
||||
generated.append(next_id)
|
||||
all_tokens.append(next_id)
|
||||
|
||||
tok_str = tokenizer.decode([next_id])
|
||||
dt = time.time() - t0
|
||||
|
||||
# Check for NaN
|
||||
has_nan = torch.isnan(logits.float()).any().item()
|
||||
logit_range = f"[{logits.float().min().item():.1f}, {logits.float().max().item():.1f}]"
|
||||
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) logits={logit_range} nan={has_nan}")
|
||||
lmin, lmax = logits.float().min().item(), logits.float().max().item()
|
||||
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan}")
|
||||
|
||||
if has_nan:
|
||||
print(" NaN detected, stopping")
|
||||
print(" NaN — stopping")
|
||||
break
|
||||
if next_id == tokenizer.eos_token_id:
|
||||
break
|
||||
@@ -465,7 +541,7 @@ def main():
|
||||
print(f"\n{'='*70}")
|
||||
print(f"Input: '{PROMPT}'")
|
||||
print(f"Output: '{out}'")
|
||||
print(f"Total: {total:.1f}s (load: {t_loaded-t_start:.1f}s, compile: {t_compiled-t_loaded:.1f}s, infer: {time.time()-t_compiled:.1f}s)")
|
||||
print(f"Total: {total:.1f}s")
|
||||
print(f"{'='*70}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user