1258 lines
56 KiB
Python
1258 lines
56 KiB
Python
#!/usr/bin/env python3
|
||
"""Single-shot DSV4-Pro inference — Full 61-layer pipeline, 8-GPU.
|
||
|
||
This is a reference implementation that exercises the production kernel
|
||
stack end-to-end. It should be usable as ground truth when integrating
|
||
into vLLM or SGLang.
|
||
|
||
Architecture (paper §2):
|
||
X_l → mHC.pre_block → RMSNorm → Attention → F_attn → mHC.post_block → X_mid
|
||
X_mid → mHC.pre_block → RMSNorm → FFN(MoE) → F_ffn → mHC.post_block → X_{l+1}
|
||
|
||
Components exercised:
|
||
- mHC (Manifold-Constrained Hyper-Connections) — proper Sinkhorn-Knopp
|
||
- Low-rank Q projection (q_a → q_b) + KV projection (MQA, 1 KV head)
|
||
- Partial RoPE (last 64 dims, GPT-J interleaved)
|
||
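- Production FMHA kernel (6-warp multi-tile, C API + ctypes), used once the cached sequence reaches 120 tokens; shorter sequences take an explicit matmul/softmax path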
|
||
- Inverse RoPE on attention output (paper §2.3.3)
|
||
- Grouped output projection (wo_a BMM + wo_b NVFP4)
|
||
- Routed MoE (384 experts, top-6, hash + dense routing, SwiGLU clamp)
|
||
- Shared expert (NVFP4 gate/up/down)
|
||
- RMSNorm (pre-norm before each sub-block)
|
||
- KV cache across decode steps
|
||
|
||
Attention type simplification for this single-shot test:
    For short sequences (seq_len ≤ sliding_window = 128), this script treats
    all attention types (CSA/HCA/SWA) as dense MQA attention over the full KV
    cache. CSA's compressed branch and indexer only matter once
    seq_len > sliding_window, and HCA's compressed sequence is tiny at these
    lengths, so both are omitted. Treat this as a simplification that is
    expected to hold for short sequences, not as a proven exact equivalence.
    Note that cached sequences shorter than 120 tokens use an explicit
    matmul/softmax attention path; the FMHA kernel is only exercised once the
    cache holds 120 tokens or more.
|
||
|
||
Usage (on B200):
|
||
source /root/dsv4-nvfp4-workspace/venv/bin/activate
|
||
cd /root/dsv4-nvfp4-workspace/kernel
|
||
python3 single_shot_inference.py
|
||
"""
|
||
import os, sys, time, json, math, argparse
|
||
import torch
|
||
from pathlib import Path
|
||
|
||
# =====================================================================
|
||
# Configuration
|
||
# =====================================================================
|
||
|
||
def parse_args(argv=None):
    """Parse the command-line options of the single-shot inference script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before; passing a list lets
            tests and other scripts call this without touching sys.argv.

    Returns:
        argparse.Namespace with ``no_inverse_rope``, ``skip_moe``, ``skip_mhc``,
        ``no_mhc_diag``, ``no_thinking``, ``max_tokens`` and ``prompt``.
    """
    p = argparse.ArgumentParser(description='DSV4 Single-Shot Inference')
    p.add_argument('--no-inverse-rope', action='store_true', help='Skip inverse RoPE on attention output')
    p.add_argument('--skip-moe', action='store_true', help='Only use shared expert (skip routed)')
    # The module-level config reads args.skip_mhc; without this option the
    # script crashed with AttributeError before doing any work.
    p.add_argument('--skip-mhc', action='store_true', help='Bypass mHC and use simple residual connections (diagnostic)')
    p.add_argument('--no-mhc-diag', action='store_true', help='Disable per-layer mHC/attention diagnostic prints')
    p.add_argument('--no-thinking', action='store_true', help='Force model to skip thinking (use <|EOT|> instead of thinking tokens)')
    p.add_argument('--max-tokens', type=int, default=512, help='Max new tokens to generate')
    p.add_argument('--prompt', type=str, default=None, help='Override prompt')
    return p.parse_args(argv)
|
||
|
||
_args = parse_args()

CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
MAX_NEW_TOKENS = _args.max_tokens
SYSTEM_PROMPT = ""  # Empty system prompt for testing
PROMPT = _args.prompt or "The capital of France is"
NUM_GPUS = 8
SKIP_ROUTED_MOE = _args.skip_moe  # If True, only use shared expert (debug)

# INVERSE_ROPE: when True, forward_layer applies inverse RoPE at the query
# position to the attention output (converts absolute -> relative rotation).
# When False, the rotated dims reach the output projection unchanged (diagnostic).
# DSV4 partial RoPE only affects the last 64/512 dims; the first 448 are never rotated.
INVERSE_ROPE = not _args.no_inverse_rope

# If True, bypass mHC and use simple residual connections (diagnostic).
# getattr() keeps the script importable even if the parser lacks --skip-mhc
# (reading _args.skip_mhc directly raised AttributeError in that case).
SKIP_MHC = getattr(_args, "skip_mhc", False)

# If True, print per-layer mHC diagnostics (B_l row/col sums, C_l values).
# Defaults to on; --no-mhc-diag turns it off when the parser provides that flag.
MHC_DIAG = not getattr(_args, "no_mhc_diag", False)

print(f"Config: INVERSE_ROPE={INVERSE_ROPE}, SKIP_ROUTED_MOE={SKIP_ROUTED_MOE}, SKIP_MHC={SKIP_MHC}, MAX_NEW_TOKENS={MAX_NEW_TOKENS}")
|
||
|
||
# =====================================================================
|
||
# NVFP4 dequantization — matches checkpoint format exactly
|
||
# =====================================================================
|
||
|
||
# Magnitudes of the eight non-negative E2M1 (FP4) codes; bit 3 of a code is the sign.
FP4_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
    """Expand a packed NVFP4 weight matrix into dense BF16.

    Args:
        weight: (out_dim, in_dim // 2) packed codes, two FP4 values per byte;
            the low nibble holds the even column, the high nibble the odd one.
        weight_scale: (out_dim, in_dim // 16) per-16-element block scales
            (converted with ``.float()``).
        weight_scale_2: second-level scale multiplied into the block scales;
            must broadcast against ``weight_scale``.

    Returns:
        (out_dim, in_dim) BF16 tensor on ``weight.device``.
    """
    rows = weight.shape[0]
    cols = weight.shape[1] * 2

    # Signed 16-entry table: codes 0-7 map to +LUT, codes 8-15 to -LUT, so a
    # raw nibble can index it directly (code 8 yields -0.0, as sign*0 would).
    magnitudes = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    signed_table = torch.cat([magnitudes, -magnitudes])

    even_codes = (weight & 0x0F).long()
    odd_codes = ((weight >> 4) & 0x0F).long()
    decoded = torch.stack([signed_table[even_codes], signed_table[odd_codes]], dim=-1)
    decoded = decoded.reshape(rows, cols)

    combined_scale = weight_scale.float() * weight_scale_2.float()
    per_element_scale = combined_scale.repeat_interleave(16, dim=1)
    return (decoded * per_element_scale).bfloat16()
|
||
|
||
|
||
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Bias-free linear layer ``x @ W.T`` with W dequantized from NVFP4 to BF16.

    The full dense weight is materialized on every call via
    dequant_nvfp4_weight; no quantized GEMM is used.
    """
    linear = torch.nn.functional.linear
    return linear(x, dequant_nvfp4_weight(weight, weight_scale, weight_scale_2))
|
||
|
||
|
||
# =====================================================================
|
||
# RMSNorm — matches dsv4/layers/norm.py
|
||
# =====================================================================
|
||
|
||
class RMSNorm:
    """Root-mean-square normalization over the last dimension.

    The gain vector ``weight`` is created as all-ones FP32 on ``device``.
    Statistics and scaling run in FP32; the result is returned as BF16.
    """

    def __init__(self, hidden_size, eps=1e-6, device='cuda:0'):
        self.eps = eps
        self.weight = torch.ones(hidden_size, dtype=torch.float32, device=device)

    def forward(self, x):
        """Normalize ``x`` (..., hidden_size) and return BF16 of the same shape."""
        x32 = x.to(torch.float32)
        mean_square = torch.mean(torch.square(x32), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        scaled = x32 * inv_rms
        return (scaled * self.weight).to(dtype=torch.bfloat16)
|
||
|
||
|
||
# =====================================================================
|
||
# mHC — proper Sinkhorn-Knopp implementation
|
||
# =====================================================================
|
||
|
||
class mHCBlock:
    """Adapter that feeds DSV4 checkpoint tensors into ``dsv4.layers.mhc.mHCLayer``.

    The actual mHC math (Sinkhorn-Knopp projection, pre/post mixing) lives in
    the production mHCLayer. This class builds that layer, re-slices the
    fused checkpoint tensors into the keyword arguments that
    ``mHCLayer.load_weights`` takes, and forwards ``pre_block`` /
    ``post_block`` calls unchanged.
    """

    def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_iters=20, device='cuda:0'):
        from dsv4.layers.mhc import mHCLayer

        layer_kwargs = dict(
            hidden_dim=hidden_dim,
            n_hc=n_hc,
            t_max_sinkhorn=sinkhorn_iters,
            device=device,
            dtype=torch.bfloat16,
        )
        self._impl = mHCLayer(**layer_kwargs)
        self.device = device
        self.n_hc = n_hc
        self.hidden_dim = hidden_dim

    def load_from_checkpoint(self, fn, base, scale):
        """Split the fused mHC checkpoint tensors and load them into the layer.

        Checkpoint layout, per the original notes (stated to be verified
        against HuggingFace DeepseekV4HyperConnection), with n = n_hc:
            fn:    rows [pre(n), post(n), comb(n*n)]
                   -> FP32 W_pre, W_post, W_comb
            base:  same [pre, post, comb] ordering
                   -> BF16 S_pre (1, n), S_post (n, 1), S_comb (n, n)
            scale: [alpha_pre, alpha_post, alpha_comb] -> Python floats

        Every tensor is moved to ``self.device`` and made contiguous.
        """
        n = self.n_hc
        dev = self.device
        segments = (slice(0, n), slice(n, 2 * n), slice(2 * n, None))
        bias_shapes = ((1, n), (n, 1), (n, n))

        W_pre, W_post, W_comb = (
            fn[seg].to(device=dev, dtype=torch.float32).contiguous()
            for seg in segments
        )
        S_pre, S_post, S_comb = (
            base[seg].reshape(shape).to(device=dev, dtype=torch.bfloat16).contiguous()
            for seg, shape in zip(segments, bias_shapes)
        )
        alpha_pre, alpha_post, alpha_comb = (scale[i].item() for i in range(3))

        self._impl.load_weights(
            W_pre=W_pre, W_post=W_post, W_comb=W_comb,
            S_pre=S_pre, S_post=S_post, S_comb=S_comb,
            alpha_pre=alpha_pre, alpha_post=alpha_post, alpha_comb=alpha_comb)

    @staticmethod
    def init_state(embeddings, n_hc=4):
        """Delegate to ``mHCLayer.init_state(embeddings, n_hc)``."""
        from dsv4.layers.mhc import mHCLayer

        return mHCLayer.init_state(embeddings, n_hc)

    def pre_block(self, X_l):
        """Return the wrapped layer's ``pre_block(X_l)`` result unchanged."""
        return self._impl.pre_block(X_l)

    def post_block(self, X_l, F_out, ctx):
        """Return the wrapped layer's ``post_block(X_l, F_out, ctx)`` result unchanged."""
        return self._impl.post_block(X_l, F_out, ctx)
|
||
|
||
|
||
# =====================================================================
|
||
# RoPE — partial, GPT-J interleaved, last rope_dim dims
|
||
# =====================================================================
|
||
|
||
def build_rope_cache(max_pos, rope_dim, device, theta=10000.0,
                     rope_type="default", rope_factor=1.0,
                     original_max_pos=4096, beta_fast=32, beta_slow=1):
    """Build FP32 cos/sin tables for partial (GPT-J interleaved) RoPE.

    CRITICAL: FP32, not BF16! BF16 quantization breaks the cos^2 + sin^2 = 1
    identity needed for inverse RoPE (BF16 sums can be ~0.996), causing a
    ~3% round-trip error that accumulates across 61 layers.

    With ``rope_type == "yarn"`` and ``rope_factor > 1`` the rotary
    frequencies are rescaled with YaRN:
      - Pair index i gets a ramp value from 0 to 1 between two correction
        dims. Those are the pair indices whose frequency completes
        ``beta_fast`` / ``beta_slow`` full rotations over ``original_max_pos``.
      - High-frequency pairs (ramp 0) keep their frequency.
      - Low-frequency pairs (ramp 1) are divided by ``rope_factor``.
      - Pairs in between are linearly blended.
    The YaRN attention-temperature (mscale) term is NOT applied here.

    Args:
        max_pos: number of positions to tabulate.
        rope_dim: number of rotated dims (pairs = rope_dim // 2).
        device: target device for the returned tables.
        theta: RoPE base.
        rope_type: "default" or "yarn".
        rope_factor: YaRN context-extension factor.
        original_max_pos: pre-extension context length used by YaRN.
        beta_fast, beta_slow: YaRN rotation-count boundaries.

    Returns:
        (cos_cache, sin_cache), each (max_pos, rope_dim // 2) FP32 on ``device``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))

    if rope_type == "yarn" and rope_factor > 1.0:
        def correction_dim(num_rotations):
            # Pair index whose frequency makes `num_rotations` full turns over
            # the original context length.
            return (rope_dim * math.log(original_max_pos / (num_rotations * 2 * math.pi))) / (
                2 * math.log(theta))

        low = max(math.floor(correction_dim(beta_fast)), 0)
        high = min(math.ceil(correction_dim(beta_slow)), rope_dim - 1)
        if low == high:
            high += 0.001  # avoid a zero-width ramp
        pair_idx = torch.arange(freqs.numel(), dtype=torch.float32)
        ramp = ((pair_idx - low) / (high - low)).clamp(0.0, 1.0)
        keep = 1.0 - ramp  # 1 -> original frequency, 0 -> frequency / factor
        freqs = freqs / rope_factor * (1.0 - keep) + freqs * keep

    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
    return torch.cos(angles).to(device), torch.sin(angles).to(device)
|
||
|
||
|
||
def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Apply GPT-J interleaved RoPE to the last ``rope_dim`` dims of every head.

    Args:
        x: (T, n_heads, hd) tensor.
        positions: (T,) integer positions used to index the caches.
        cos_cache, sin_cache: (max_pos, rope_dim // 2) tables (see build_rope_cache).
        head_dim: unused; kept for call-site compatibility (hd is read from x).
        rope_dim: number of trailing dims to rotate, as (even, odd) pairs.

    Returns:
        A new tensor with x's shape and dtype. The first hd - rope_dim dims
        are copied unchanged. The rotation is computed in FP32 (inverse RoPE
        relies on cos^2 + sin^2 = 1) and cast back to x.dtype, so FP32 inputs
        are no longer rounded through BF16.
    """
    nope = x.shape[-1] - rope_dim
    cos = cos_cache[positions].unsqueeze(1)  # (T, 1, rope_dim // 2)
    sin = sin_cache[positions].unsqueeze(1)

    pairs = x[:, :, nope:].float()
    even = pairs[..., 0::2]
    odd = pairs[..., 1::2]

    rotated = torch.empty_like(pairs)
    rotated[..., 0::2] = even * cos - odd * sin
    rotated[..., 1::2] = even * sin + odd * cos

    result = x.clone()
    result[:, :, nope:] = rotated.to(x.dtype)
    return result
|
||
|
||
|
||
def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Undo the RoPE rotation (conjugate rotation) on the last ``rope_dim`` dims.

    Args:
        o: (T, n_heads, hd) tensor, e.g. an attention output.
        positions: (T,) integer positions used to index the caches.
        cos_cache, sin_cache: (max_pos, rope_dim // 2) tables (see build_rope_cache).
        head_dim: unused; kept for call-site compatibility (hd is read from o).
        rope_dim: number of trailing dims to rotate back, as (even, odd) pairs.

    Returns:
        A new tensor with o's shape and dtype. The first hd - rope_dim dims
        are copied unchanged. The rotation is computed in FP32 and cast back
        to o.dtype, so FP32 inputs are not rounded through BF16.
    """
    nope = o.shape[-1] - rope_dim
    cos = cos_cache[positions].unsqueeze(1)  # (T, 1, rope_dim // 2)
    sin = sin_cache[positions].unsqueeze(1)

    pairs = o[:, :, nope:].float()
    even = pairs[..., 0::2]
    odd = pairs[..., 1::2]

    unrotated = torch.empty_like(pairs)
    unrotated[..., 0::2] = even * cos + odd * sin
    unrotated[..., 1::2] = -even * sin + odd * cos

    result = o.clone()
    result[:, :, nope:] = unrotated.to(o.dtype)
    return result
|
||
|
||
class SimpleKVCache:
    """Per-layer decode KV cache backed by preallocated BF16 buffers.

    MQA (1 KV head): ``k`` and ``v`` are (1, max_seq, head_dim) buffers
    filled left to right; ``len`` is the number of valid entries.
    """

    def __init__(self, head_dim, max_seq=8192, device='cuda:0'):
        self.hd = head_dim
        self.max_seq = max_seq
        self.device = device
        self.k = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
        self.v = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
        self.len = 0

    def append(self, k_new, v_new):
        """Append T entries; ``k_new`` and ``v_new`` are (1, T, hd).

        Raises:
            RuntimeError: if the append would exceed ``max_seq``. The check
                runs before any write, so the cache is left unchanged.
        """
        T = k_new.shape[1]
        end = self.len + T
        if end > self.max_seq:
            raise RuntimeError(
                f"SimpleKVCache: appending {T} entries to {self.len} exceeds capacity "
                f"max_seq={self.max_seq}")
        self.k[0, self.len:end] = k_new[0]
        self.v[0, self.len:end] = v_new[0]
        self.len = end

    def get(self):
        """Return views (1, len, hd) of the valid K and V entries."""
        return self.k[:, :self.len], self.v[:, :self.len]
|
||
|
||
|
||
# =====================================================================
|
||
# Weight loading — streams safetensors shards, distributes to 8 GPUs
|
||
# =====================================================================
|
||
|
||
def load_weights_to_cpu(checkpoint_dir):
    """Load all checkpoint tensors into one CPU-resident dict.

    Weights stay on CPU; callers move them per layer to the GPUs. This avoids
    OOM from ~285K GPU allocations and allows streaming.

    Shard discovery:
      - If ``model.safetensors.index.json`` exists and has a non-empty
        ``weight_map``, its shard files are loaded.
      - Otherwise every ``model*.safetensors`` file in the directory is
        loaded (previously a fixed 95-shard name list was assumed).
    Shards named by the index but missing on disk are skipped with a warning,
    so their tensors are absent from the result.

    Args:
        checkpoint_dir: path to the checkpoint directory.

    Returns:
        all_weights: dict mapping tensor name -> CPU tensor.
    """
    from safetensors.torch import load_file

    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"
    weight_map = {}
    if index_path.exists():
        with open(index_path) as f:
            weight_map = json.load(f).get("weight_map", {})

    if weight_map:
        shard_names = sorted(set(weight_map.values()))
    else:
        shard_names = sorted(p.name for p in cdir.glob("model*.safetensors"))

    print(f"Loading {len(shard_names)} shards to CPU...")
    if not shard_names:
        print(f"  WARNING: no safetensors shards found in {cdir}")

    present = [name for name in shard_names if (cdir / name).exists()]
    n_missing = len(shard_names) - len(present)
    if n_missing:
        print(f"  WARNING: {n_missing} shard(s) listed in the index are missing from {cdir}; "
              f"their tensors will be absent")

    all_weights = {}
    for loaded, shard_name in enumerate(present, start=1):
        all_weights.update(load_file(str(cdir / shard_name)))
        if loaded % 20 == 0:
            print(f" {loaded}/{len(shard_names)} shards, {len(all_weights)} tensors")
    print(f" Done: {len(all_weights)} tensors on CPU")
    return all_weights
|
||
|
||
|
||
def get_layer_weights(all_weights, li, device):
    """Collect layer ``li``'s tensors and move them to ``device``.

    A key belongs to the layer when it starts with ``model.layers.{li}.``.
    The trailing dot keeps layer 1 from matching layers 10, 11, ...
    Tensors are copied with ``non_blocking=True``, and the iteration order of
    ``all_weights`` is preserved in the returned dict.
    """
    layer_prefix = "model.layers." + str(li) + "."
    return {
        name: tensor.to(device=device, non_blocking=True)
        for name, tensor in all_weights.items()
        if name.startswith(layer_prefix)
    }
|
||
|
||
|
||
def cache_all_layer_weights(all_weights, n_layers, devices):
    """Pre-load every layer's weights onto its target device.

    Layer ``li`` goes to ``devices[li % len(devices)]`` and stays there for
    the whole run. This avoids a per-token CPU->GPU transfer.

    Keys are bucketed by layer in a single pass over ``all_weights``. The
    previous version rescanned every key once per layer, which is
    O(n_layers * n_keys). A key belongs to layer ``li`` exactly when it
    starts with ``model.layers.{li}.``.

    Returns:
        dict {li: {key: tensor on device}} for li in range(n_layers). Layers
        without weights map to an empty dict.
    """
    print(f" Caching layer weights to GPUs...")
    layer_root = "model.layers."
    index_of = {str(li): li for li in range(n_layers)}
    keys_by_layer = {li: [] for li in range(n_layers)}
    for key in all_weights:
        if not key.startswith(layer_root):
            continue
        idx, sep, _ = key[len(layer_root):].partition(".")
        li = index_of.get(idx) if sep else None
        if li is not None:
            keys_by_layer[li].append(key)

    cached = {}
    for li in range(n_layers):
        dev = devices[li % len(devices)]
        cached[li] = {
            key: all_weights[key].to(device=dev, non_blocking=True)
            for key in keys_by_layer[li]
        }
        if (li + 1) % 10 == 0:
            print(f" {li+1}/{n_layers} layers cached")
    print(f" All {n_layers} layers cached to GPUs")
    return cached
|
||
|
||
|
||
# =====================================================================
|
||
# Single layer forward
|
||
# =====================================================================
|
||
|
||
# Cached sequences shorter than this use the explicit matmul/softmax path.
# Per the original notes, the FMHA kernel pads N up to a multiple of 128, and
# for N << 128 the padded zero-K entries dilute the softmax.
_EXPLICIT_ATTN_MAX_SEQ = 120


def _rms_normalize(x, weight=None, eps=1e-6):
    """Return BF16 ``x / RMS(x)`` over the last dim, times ``weight`` if given (FP32 math)."""
    x_f = x.float()
    out = x_f * x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    if weight is not None:
        out = out * weight.float()
    return out.bfloat16()


def _attention_explicit(q_input, k_full, v_full, sinks, scale, diag_li=None):
    """Dense MQA attention with optional per-head sink logits, computed explicitly.

    Args:
        q_input: (n_h, T, hd) BF16 queries (RoPE applied).
        k_full, v_full: (1, seq_len, hd) cached keys/values. The last T cache
            slots belong to the current queries.
        sinks: (n_h,) sink logits or None. Per the original notes (paper D5c /
            HF reference), sinks are an extra logit column that joins the
            softmax and is then DROPPED; they are not a dummy KV entry.
        scale: softmax scale.
        diag_li: if not None, print head-0 attention entropy for this layer index.

    Returns:
        (T, n_h, hd) BF16 attention output.
    """
    n_h, T, _ = q_input.shape
    seq_len = k_full.shape[1]
    k_expanded = k_full.expand(n_h, -1, -1).contiguous()
    v_expanded = v_full.expand(n_h, -1, -1).contiguous()

    scores = torch.matmul(q_input, k_expanded.transpose(-1, -2)) * scale  # (n_h, T, seq_len)
    if T > 1:
        # Causal mask: query i sits at cache slot seq_len - T + i and must not
        # see later slots. With T == 1 every cached entry is in the past.
        q_slot = torch.arange(seq_len - T, seq_len, device=scores.device).unsqueeze(1)
        k_slot = torch.arange(seq_len, device=scores.device).unsqueeze(0)
        scores = scores.masked_fill(k_slot > q_slot, float("-inf"))

    if sinks is not None:
        sink_logits = sinks.float().reshape(n_h, 1, 1).expand(-1, T, 1)
        logits = torch.cat([scores, sink_logits], dim=-1)  # (n_h, T, seq_len + 1)
        logits = logits - logits.max(dim=-1, keepdim=True).values  # stable softmax
        probs = torch.softmax(logits.float(), dim=-1).to(torch.bfloat16)
        attn_weights = probs[..., :-1]  # drop the sink column
    else:
        attn_weights = torch.softmax(scores.float(), dim=-1).to(torch.bfloat16)
    attn_out = torch.matmul(attn_weights, v_expanded).permute(1, 0, 2)  # (T, n_h, hd)

    if diag_li is not None:
        # Diagnostic: how spread out head 0's attention is (sinks not included).
        with torch.no_grad():
            weights = torch.softmax(scores.float(), dim=-1)
            w0 = weights[0, 0]
            top3_pos = torch.topk(w0, min(3, seq_len))
            entropy = -(w0 * (w0 + 1e-10).log()).sum().item()
            print(f" L{diag_li} attn: seq_len={seq_len} entropy={entropy:.2f} top3_pos={top3_pos.indices.tolist()} top3_w={top3_pos.values.tolist()}")
    return attn_out


def _attention_fmha(q_input, k_full, v_full, sinks, scale):
    """Attention through the production FMHA decode kernel, with sink correction.

    Returns (T, n_h, hd) BF16. The sink correction rescales the output by
    exp(lse) / (exp(lse) + exp(sink)), computed as sigmoid(lse - sink), which
    cannot overflow for large lse.
    NOTE(review): assumes the kernel's lse is a natural-log LSE shaped so that
    ``lse.squeeze(0).t()`` is (T, n_h), and that the decode kernel does not
    mask T > 1 queries causally — confirm against fmha_multitile_op.
    """
    from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw

    q_4d = q_input.unsqueeze(0).contiguous()
    k_4d = k_full.unsqueeze(0).contiguous()
    v_4d = v_full.unsqueeze(0).transpose(-1, -2).contiguous()
    o_4d, lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
    attn_out = o_4d.squeeze(0).permute(1, 0, 2)
    if sinks is not None:
        lse_2d = lse.squeeze(0).t().float()
        correction = torch.sigmoid(lse_2d - sinks.float().unsqueeze(0))
        attn_out = attn_out.float() * correction.unsqueeze(-1)
    return attn_out.bfloat16()


def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
                  attn_mhc, ffn_mhc, attn_norm, ffn_norm,
                  kv_cache, token_id, positions):
    """Forward one layer with mHC + Attention + FFN.

    Architecture (paper §2):
        X_l   -> mHC.pre_block(attn) -> RMSNorm -> Attention -> F_attn -> mHC.post_block -> X_mid
        X_mid -> mHC.pre_block(ffn)  -> RMSNorm -> MoE       -> F_ffn  -> mHC.post_block -> X_{l+1}

    Args:
        X_l: (T, n_hc, H) BF16 mHC residual state.
        w: this layer's weights keyed by full checkpoint name.
        li: layer index (key prefixes and diagnostics).
        cfg: model config dict.
        rope_cos, rope_sin: RoPE tables from build_rope_cache.
        attn_mhc, ffn_mhc: mHC wrappers for the two sub-blocks (unused when SKIP_MHC).
        attn_norm, ffn_norm: pre-norms for the two sub-blocks.
        kv_cache: this layer's SimpleKVCache; the new (RoPE'd) KV is appended.
        token_id: current token id(s), forwarded to moe_forward.
        positions: (T,) absolute positions of the current tokens.

    Returns:
        X_next: (T, n_hc, H) BF16.
    """
    device = X_l.device
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    o_rank = cfg.get("output_group_dim", 1024)
    o_groups = cfg.get("num_output_groups", 16)
    pre = f"model.layers.{li}.self_attn"
    T = X_l.shape[0]
    group_input_dim = (n_h // o_groups) * hd
    early_diag = MHC_DIAG and li < 3  # Q/KV/attention prints for the first layers only

    # ================= ATTENTION SUB-BLOCK =================
    if SKIP_MHC:
        x_in = X_l[:, 0, :]  # diagnostic: plain residual on stream 0
        attn_ctx = None
    else:
        x_in, attn_ctx = attn_mhc.pre_block(X_l)  # x_in: (T, H)
        if MHC_DIAG and attn_ctx is not None:
            print(f" L{li} pre_attn: |X_l|={X_l.abs().max().item():.2f} |x_in|={x_in.abs().max().item():.2f}", flush=True)

    x_normed = attn_norm.forward(x_in)  # (T, H) BF16

    # Q: low-rank down (q_a) -> optional weighted RMSNorm -> low-rank up (q_b).
    c_Q = nvfp4_linear(x_normed,
                       w[f"{pre}.q_a_proj.weight"],
                       w[f"{pre}.q_a_proj.weight_scale"],
                       w[f"{pre}.q_a_proj.weight_scale_2"])
    q_norm_w = w.get(f"{pre}.q_a_norm.weight")
    if q_norm_w is not None:
        c_Q = _rms_normalize(c_Q, q_norm_w)
    q = nvfp4_linear(c_Q,
                     w[f"{pre}.q_b_proj.weight"],
                     w[f"{pre}.q_b_proj.weight_scale"],
                     w[f"{pre}.q_b_proj.weight_scale_2"])  # (T, n_h * hd)
    # Unweighted RMSNorm after q_b_proj (paper §2.3.1), taken over all
    # n_h * hd dims at once. NOTE(review): confirm against the reference
    # whether this normalization should be per head.
    q = _rms_normalize(q)

    # KV: one MQA head (K = V) + optional weighted RMSNorm.
    kv = nvfp4_linear(x_normed,
                      w[f"{pre}.kv_proj.weight"],
                      w[f"{pre}.kv_proj.weight_scale"],
                      w[f"{pre}.kv_proj.weight_scale_2"])  # (T, hd)
    kv_norm_w = w.get(f"{pre}.kv_norm.weight")
    if kv_norm_w is not None:
        kv = _rms_normalize(kv, kv_norm_w)

    q_heads = q.reshape(T, n_h, hd)
    kv_new = kv.reshape(T, 1, hd)

    if early_diag:
        print(f" L{li} Q: |q|={q_heads.abs().max().item():.2f} mean={q_heads.float().abs().mean().item():.4f}")
        print(f" L{li} KV: |kv|={kv_new.abs().max().item():.2f} mean={kv_new.float().abs().mean().item():.4f}")

    # RoPE on Q and on KV, the latter BEFORE caching (DSV4 convention).
    positions_dev = positions.to(device)
    q_heads = apply_rope_partial(q_heads, positions_dev, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd)

    kv_bt = kv_new.permute(1, 0, 2)  # (1, T, hd); K = V
    kv_cache.append(kv_bt, kv_bt)
    k_full, v_full = kv_cache.get()  # (1, seq_len, hd) each, already RoPE'd
    seq_len = k_full.shape[1]

    q_input = q_heads.permute(1, 0, 2)  # (n_h, T, hd)
    scale = 1.0 / math.sqrt(hd)
    sinks = w.get(f"{pre}.sinks")
    if sinks is not None:
        sinks = sinks.to(device=device)

    if seq_len < _EXPLICIT_ATTN_MAX_SEQ:
        attn_out = _attention_explicit(q_input, k_full, v_full, sinks, scale,
                                       diag_li=li if early_diag else None)
    else:
        attn_out = _attention_fmha(q_input, k_full, v_full, sinks, scale)

    # Inverse RoPE on the attention output (paper §2.3.3). Only the last rd
    # dims are affected. Toggled by INVERSE_ROPE for diagnostics.
    if INVERSE_ROPE:
        attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)

    # Output projection: grouped wo_a (BMM) then wo_b (NVFP4).
    attn_for_bmm = attn_out.reshape(T, o_groups, group_input_dim).permute(1, 0, 2)  # (groups, T, group_dim)
    oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16()
    oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)  # (groups, o_rank, group_dim)
    grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2))  # (groups, T, o_rank)
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
    F_attn = nvfp4_linear(grouped_flat,
                          w[f"{pre}.o_b_proj.weight"],
                          w[f"{pre}.o_b_proj.weight_scale"],
                          w[f"{pre}.o_b_proj.weight_scale_2"])  # (T, H)

    if SKIP_MHC:
        X_mid = X_l[:, 0, :].unsqueeze(1).expand(-1, X_l.shape[1], -1) + F_attn.unsqueeze(1) * 0.1
    else:
        X_mid = attn_mhc.post_block(X_l, F_attn, attn_ctx)  # (T, n_hc, H)
        if MHC_DIAG and attn_ctx is not None:
            B_l, C_l = attn_ctx.B_l, attn_ctx.C_l
            print(f" L{li} attn: |X_l|={X_l.abs().max().item():.2f} |F_attn|={F_attn.abs().max().item():.2f} |B|={B_l.abs().max().item():.4f} |C|={C_l.abs().max().item():.4f} |X_mid|={X_mid.abs().max().item():.2f}")
            # B_l should be doubly stochastic: rows and columns sum to ~1.
            B_row_sums = B_l.sum(dim=-1)
            B_col_sums = B_l.sum(dim=-2)
            print(f" B row_sums={B_row_sums[0].tolist()} col_sums={B_col_sums[0].tolist()}")
            print(f" C_l={C_l[0].tolist()}")

    # ================= FFN SUB-BLOCK =================
    if SKIP_MHC:
        x_ffn = X_mid[:, 0, :]
        ffn_ctx = None
    else:
        x_ffn, ffn_ctx = ffn_mhc.pre_block(X_mid)  # (T, H)

    x_ffn_normed = ffn_norm.forward(x_ffn)
    F_ffn = moe_forward(x_ffn_normed, w, li, cfg, token_id, device)

    if SKIP_MHC:
        X_next = X_mid + F_ffn.unsqueeze(1) * 0.1
    else:
        X_next = ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx)  # (T, n_hc, H)
        if MHC_DIAG and ffn_ctx is not None:
            B_l_ffn, C_l_ffn = ffn_ctx.B_l, ffn_ctx.C_l
            print(f" L{li} ffn: |X_mid|={X_mid.abs().max().item():.2f} |F_ffn|={F_ffn.abs().max().item():.2f} |B|={B_l_ffn.abs().max().item():.4f} |C|={C_l_ffn.abs().max().item():.4f} |X_next|={X_next.abs().max().item():.2f}", flush=True)

    return X_next
|
||
|
||
|
||
# =====================================================================
|
||
# MoE forward — hash + dense routing, SwiGLU with clamping
|
||
# =====================================================================
|
||
|
||
def _swiglu_mlp(x, w, prefix, swiglu_limit):
    """Run one NVFP4 SwiGLU MLP whose gate/up/down weights live under ``prefix``.

    SwiGLU with clamping (paper §4.2.3): silu(gate) and up are clamped to
    [-limit, limit] when ``swiglu_limit`` is not None.
    """
    gate = nvfp4_linear(x,
                        w[f"{prefix}.gate_proj.weight"],
                        w[f"{prefix}.gate_proj.weight_scale"],
                        w[f"{prefix}.gate_proj.weight_scale_2"])
    up = nvfp4_linear(x,
                      w[f"{prefix}.up_proj.weight"],
                      w[f"{prefix}.up_proj.weight_scale"],
                      w[f"{prefix}.up_proj.weight_scale_2"])
    silu_out = torch.nn.functional.silu(gate.float())
    if swiglu_limit is not None:
        silu_out = silu_out.clamp(-swiglu_limit, swiglu_limit)
        up_f = up.float().clamp(-swiglu_limit, swiglu_limit)
    else:
        up_f = up.float()
    hidden = (silu_out * up_f).bfloat16()
    return nvfp4_linear(hidden,
                        w[f"{prefix}.down_proj.weight"],
                        w[f"{prefix}.down_proj.weight_scale"],
                        w[f"{prefix}.down_proj.weight_scale_2"])


def _route_token(x_t, w, li, token_id, t, T, top_k):
    """Pick experts for token ``t`` of ``T``; ``x_t`` is its (1, H) row.

    Returns:
        (expert_ids (top_k,), expert_weights (top_k,) FP32).

    Routing:
      - Hash (layers that have tid2eid but no e_score_correction_bias): a
        tid2eid lookup with uniform 1/top_k weights.
        NOTE(review): confirm against the reference gate whether hash layers
        should instead weight by gate scores.
      - Dense noaux_tc: sqrt(softplus(logits)) scores. e_score_correction_bias
        affects SELECTION only; weights come from the unbiased scores,
        normalized to sum to 1.
    """
    tid2eid_key = f"model.layers.{li}.mlp.gate.tid2eid"
    e_bias_key = f"model.layers.{li}.mlp.gate.e_score_correction_bias"

    if tid2eid_key in w and e_bias_key not in w:
        flat_ids = token_id.reshape(-1)
        # One id per token when available; otherwise fall back to the first id.
        tid = flat_ids[t].item() if flat_ids.numel() == T else flat_ids[0].item()
        expert_ids = w[tid2eid_key][tid]
        expert_weights = torch.ones(top_k, dtype=torch.float32, device=x_t.device) / top_k
        return expert_ids, expert_weights

    gate_w = w[f"model.layers.{li}.mlp.gate.weight"]  # (n_experts, H) as F.linear expects
    logits = torch.nn.functional.linear(x_t, gate_w.bfloat16())  # (1, n_experts)
    scores = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6)
    selection = scores + w[e_bias_key].float().unsqueeze(0) if e_bias_key in w else scores
    _, indices = selection.topk(top_k, dim=-1)  # (1, top_k)
    expert_weights = torch.gather(scores, -1, indices)
    expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)
    return indices[0], expert_weights[0]


def moe_forward(x, w, li, cfg, token_id, device):
    """Run routed MoE + shared expert.

    Args:
        x: (T, H) BF16 post-RMSNorm FFN input.
        w: this layer's weights keyed by full checkpoint name.
        li: layer index.
        cfg: model config dict (num_experts_per_tok, routed_scaling_factor,
            swiglu_limit).
        token_id: token id tensor; used only by hash-routed layers. One id
            per token is used when it has T elements, otherwise the first id.
        device: unused; kept for call-site compatibility.

    Returns:
        (T, H) BF16 routed output * routed_scaling_factor + shared-expert output.

    Each token is routed independently. Previously, multi-token input either
    raised NotImplementedError (dense layers) or silently reused token 0's
    experts (hash layers). Per-token expert outputs are accumulated in FP32
    and rounded to BF16 once.
    """
    top_k = cfg.get("num_experts_per_tok", 6)
    routed_scaling = cfg.get("routed_scaling_factor", 2.5)
    swiglu_limit = cfg.get("swiglu_limit", 10.0)
    T = x.shape[0]

    routed_out = torch.zeros_like(x)
    if not SKIP_ROUTED_MOE and T > 0:
        rows = []
        for t in range(T):
            x_t = x[t:t + 1]
            expert_ids, expert_weights = _route_token(x_t, w, li, token_id, t, T, top_k)
            acc = torch.zeros(x_t.shape, dtype=torch.float32, device=x.device)
            for eid, wt in zip(expert_ids.tolist(), expert_weights):
                out = _swiglu_mlp(x_t, w, f"model.layers.{li}.mlp.experts.{eid}", swiglu_limit)
                acc = acc + out.float() * wt
            rows.append((acc * routed_scaling).bfloat16())
        routed_out = torch.cat(rows, dim=0)

    se_pre = f"model.layers.{li}.mlp.shared_experts"
    if f"{se_pre}.gate_proj.weight" in w:
        shared_out = _swiglu_mlp(x, w, se_pre, swiglu_limit)
    else:
        shared_out = torch.zeros_like(x)

    return routed_out + shared_out
|
||
|
||
|
||
# =====================================================================
|
||
# Main
|
||
# =====================================================================
|
||
|
||
def _identity_fallback_mhc(H, n_hc, dev):
    """Build a near-identity mHC block for a layer whose checkpoint lacks mHC weights.

    All dynamic projections (W_pre / W_post / W_comb) are zero, the static
    comb term is the identity matrix and every alpha is 0.01 ("small alphas,
    identity comb").

    Args:
        H: hidden size.
        n_hc: number of mHC residual streams.
        dev: CUDA device string the tensors are created on.

    Returns:
        An ``mHCBlock`` whose ``_impl`` parameters have been overwritten.
    """
    mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
    impl = mhc._impl
    K = n_hc * H
    impl.W_pre = torch.zeros(n_hc, K, dtype=torch.float32, device=dev)
    impl.W_post = torch.zeros(n_hc, K, dtype=torch.float32, device=dev)
    impl.W_comb = torch.zeros(n_hc * n_hc, K, dtype=torch.float32, device=dev)
    impl.S_pre = torch.zeros(1, n_hc, dtype=torch.bfloat16, device=dev)
    impl.S_post = torch.ones(n_hc, 1, dtype=torch.bfloat16, device=dev) * 0.5
    impl.S_comb = torch.eye(n_hc, dtype=torch.bfloat16, device=dev)
    impl.alpha_pre = torch.tensor(0.01, dtype=torch.float32, device=dev)
    impl.alpha_post = torch.tensor(0.01, dtype=torch.float32, device=dev)
    impl.alpha_comb = torch.tensor(0.01, dtype=torch.float32, device=dev)
    return mhc


def _build_layer_mhc_and_norms(all_weights, cfg, n_layers, H, n_hc):
    """Create per-layer attn/ffn mHC blocks and pre-norm RMSNorms on each layer's GPU.

    Layer ``li`` is placed on ``cuda:{li % NUM_GPUS}``. When a layer's
    ``{attn,ffn}_hc.{fn,base,scale}`` weights are missing, a warning is printed
    and a near-identity fallback block is used. A missing RMSNorm weight
    leaves the freshly constructed ``RMSNorm`` untouched.

    Returns:
        ``(attn_mhc_blocks, ffn_mhc_blocks, attn_norms, ffn_norms)``, each a
        dict keyed by layer index.
    """
    eps = cfg.get('rms_norm_eps', 1e-6)
    attn_mhc_blocks, ffn_mhc_blocks = {}, {}
    attn_norms, ffn_norms = {}, {}

    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"

        # mHC blocks (small weights: fn (24, 28672) FP32 ≈ 2.6MB each)
        for prefix, blocks in ((f"model.layers.{li}.attn_hc", attn_mhc_blocks),
                               (f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)):
            keys = (f"{prefix}.fn", f"{prefix}.base", f"{prefix}.scale")
            if all(k in all_weights for k in keys):
                mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
                mhc.load_from_checkpoint(*(all_weights[k] for k in keys))
                blocks[li] = mhc
            else:
                print(f" WARNING: no mHC weights for {prefix}, using identity fallback")
                blocks[li] = _identity_fallback_mhc(H, n_hc, dev)

        # Pre-attention and pre-FFN RMSNorms, weights upcast to FP32.
        for norm_name, norms in (("input_layernorm", attn_norms),
                                 ("post_attention_layernorm", ffn_norms)):
            norm = RMSNorm(H, eps=eps, device=dev)
            w_key = f"model.layers.{li}.{norm_name}.weight"
            if w_key in all_weights:
                norm.weight = all_weights[w_key].to(device=dev, dtype=torch.float32)
            norms[li] = norm

    return attn_mhc_blocks, ffn_mhc_blocks, attn_norms, ffn_norms


def _trace_layer0_attention(layer_weights, cfg, rope_caches, attn_mhc_blocks,
                            attn_norms, embed, tokenizer, n_hc):
    """Hand-run the attention half of layer 0 for one token and print each stage.

    The token is the last id of ``tokenizer.encode("The")`` at position 0. The
    stages are: embedding → mHC pre_block → RMSNorm → low-rank Q (+ q_a_norm)
    → KV (+ kv_norm) → partial RoPE → single-token SDPA → optional inverse
    RoPE → grouped wo_a BMM → wo_b → mHC post_block.

    Attention uses ``scaled_dot_product_attention`` rather than the FMHA
    kernel, and no KV cache is read or written.
    """
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    li = 0
    dev = "cuda:0"
    w = layer_weights[li]
    pre = f"model.layers.{li}.self_attn"
    T_dim = 1
    positions = torch.tensor([0], dtype=torch.long, device=dev)
    rope_cos, rope_sin = rope_caches[0]

    # Start from the embedding (last id drops a BOS if the tokenizer adds one)
    tid = torch.tensor([tokenizer.encode("The")[-1]], dtype=torch.long, device=dev)
    emb = embed(tid)  # (1, H)
    X = mHCBlock.init_state(emb, n_hc)  # (1, n_hc, H)
    print(f" X after init_state: |X|={X.abs().max().item():.4f} stream0_mean={X[:,0,:].float().abs().mean().item():.6f}", flush=True)

    # mHC pre_block
    attn_mhc = attn_mhc_blocks[li]
    x_in, ctx = attn_mhc.pre_block(X)
    print(f" x_in (mHC pre_block): |x_in|={x_in.abs().max().item():.4f} mean={x_in.float().abs().mean().item():.6f}", flush=True)
    B_l = ctx.B_l
    C_l = ctx.C_l
    print(f" B_l row_sums={B_l[0].sum(dim=-1).tolist()}", flush=True)
    print(f" C_l={C_l[0].tolist()}", flush=True)

    # RMSNorm
    x_normed = attn_norms[li].forward(x_in)
    print(f" x_normed: |x|={x_normed.abs().max().item():.4f} mean={x_normed.float().abs().mean().item():.6f}", flush=True)

    # Q projection: q_a (down) → optional q_a_norm → q_b (up)
    c_Q = nvfp4_linear(x_normed,
                       w[f"{pre}.q_a_proj.weight"],
                       w[f"{pre}.q_a_proj.weight_scale"],
                       w[f"{pre}.q_a_proj.weight_scale_2"])
    print(f" c_Q (q_a_proj): |c_Q|={c_Q.abs().max().item():.4f} mean={c_Q.float().abs().mean().item():.6f}", flush=True)

    q_norm_w = w.get(f"{pre}.q_a_norm.weight")
    if q_norm_w is not None:
        c_Q_f = c_Q.float()
        c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
        c_Q = (c_Q_f * c_Q_rms * q_norm_w.float()).bfloat16()
        print(f" c_Q after q_a_norm: |c_Q|={c_Q.abs().max().item():.4f}", flush=True)

    q = nvfp4_linear(c_Q,
                     w[f"{pre}.q_b_proj.weight"],
                     w[f"{pre}.q_b_proj.weight_scale"],
                     w[f"{pre}.q_b_proj.weight_scale_2"])
    q_heads = q.reshape(T_dim, n_h, hd)
    print(f" q_heads: |q|={q_heads.abs().max().item():.4f} mean={q_heads.float().abs().mean().item():.6f}", flush=True)

    # KV projection (single shared KV head) → optional kv_norm
    kv = nvfp4_linear(x_normed,
                      w[f"{pre}.kv_proj.weight"],
                      w[f"{pre}.kv_proj.weight_scale"],
                      w[f"{pre}.kv_proj.weight_scale_2"])
    print(f" kv (kv_proj): |kv|={kv.abs().max().item():.4f} mean={kv.float().abs().mean().item():.6f}", flush=True)

    kv_norm_w = w.get(f"{pre}.kv_norm.weight")
    if kv_norm_w is not None:
        kv_f = kv.float()
        kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
        kv = (kv_f * kv_rms * kv_norm_w.float()).bfloat16()
        print(f" kv after kv_norm: |kv|={kv.abs().max().item():.4f}", flush=True)

    kv_new = kv.reshape(T_dim, 1, hd)  # (1, 1, hd)
    print(f" kv_new shape: {kv_new.shape}", flush=True)

    # Apply RoPE
    q_heads = apply_rope_partial(q_heads, positions, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv_new, positions, rope_cos, rope_sin, hd, rd)
    print(f" After RoPE: |q|={q_heads.abs().max().item():.4f} |kv|={kv_new.abs().max().item():.4f}", flush=True)

    # Self-attention over a single token (softmax weight is trivially 1.0)
    q_input = q_heads.permute(1, 0, 2)  # (n_h, 1, hd)
    k_input = kv_new.permute(1, 0, 2)  # (1, 1, hd), broadcast to all heads
    k_expanded = k_input.expand(n_h, -1, -1).contiguous()
    v_expanded = k_expanded.clone()  # K=V in DSV4 MQA
    attn_out = torch.nn.functional.scaled_dot_product_attention(
        q_input, k_expanded, v_expanded, scale=1.0 / math.sqrt(hd))
    attn_out = attn_out.permute(1, 0, 2)  # (1, n_h, hd)
    print(f" attn_out: |o|={attn_out.abs().max().item():.4f} mean={attn_out.float().abs().mean().item():.6f}", flush=True)

    # Inverse RoPE
    if INVERSE_ROPE:
        attn_out = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd)
        print(f" After inverse RoPE: |o|={attn_out.abs().max().item():.4f}", flush=True)

    # Grouped output projection: per-group wo_a via BMM, then wo_b (NVFP4)
    o_groups = cfg.get("num_output_groups", 16)
    o_rank = cfg.get("output_group_dim", 1024)
    heads_per_group = n_h // o_groups
    group_input_dim = heads_per_group * hd
    attn_flat = attn_out.reshape(T_dim, n_h * hd)
    attn_grouped = attn_flat.reshape(T_dim, o_groups, heads_per_group * hd)
    oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16()
    oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
    attn_for_bmm = attn_grouped.permute(1, 0, 2)
    grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(T_dim, o_groups * o_rank)
    print(f" grouped_out (wo_a): |o|={grouped_flat.abs().max().item():.4f} mean={grouped_flat.float().abs().mean().item():.6f}", flush=True)

    F_attn = nvfp4_linear(grouped_flat,
                          w[f"{pre}.o_b_proj.weight"],
                          w[f"{pre}.o_b_proj.weight_scale"],
                          w[f"{pre}.o_b_proj.weight_scale_2"])
    print(f" F_attn (wo_b): |F|={F_attn.abs().max().item():.4f} mean={F_attn.float().abs().mean().item():.6f}", flush=True)

    # mHC post_block
    X_mid = attn_mhc.post_block(X, F_attn, ctx)
    print(f" X_mid: |X|={X_mid.abs().max().item():.4f} stream0_mean={X_mid[:,0,:].float().abs().mean().item():.6f}", flush=True)
    print(f" Layer 0 trace complete.", flush=True)


def _lm_head_logits(X, final_norm_w, lm_w):
    """Read out mHC stream 0, apply the final RMSNorm (if any) and the LM head.

    Args:
        X: residual state indexed as ``X[:, stream, :]``.
        final_norm_w: final RMSNorm weight, or None to skip normalisation.
        lm_w: LM-head weight used with ``torch.nn.functional.linear``.

    Returns:
        Logits of shape ``(X.shape[0], lm_w.shape[0])``.
    """
    x_out = X[:, 0, :]  # (1, H)
    if final_norm_w is not None:
        xf = x_out.float()
        # NOTE(review): eps is fixed at 1e-6 rather than cfg['rms_norm_eps'].
        rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
        x_out = (xf * rms * final_norm_w.float()).bfloat16()
    return torch.nn.functional.linear(x_out, lm_w)


def main():
    """Run the full single-shot DSV4 pipeline end to end.

    Phases:
      1.   Load checkpoint weights to CPU. Build per-layer mHC blocks,
           RMSNorms, RoPE tables and KV caches (layer ``li`` lives on
           ``cuda:{li % NUM_GPUS}``). Copy all layer weights to their GPUs,
           then drop the CPU copies.
      2.   Warm up / JIT-compile the production FMHA kernel.
      2.5  Minimal single-token end-to-end test (``minimal_e2e_test``).
      2.6  Hand-traced attention half of layer 0.
      3.   Build a DeepSeek chat prompt, prefill all but its last token into
           the KV caches, then greedily decode up to ``MAX_NEW_TOKENS``
           tokens. Decoding stops on EOS/EOT or on NaN/Inf logits.
    """
    import gc

    t_start = time.time()
    print("=" * 70)
    print("DSV4 Single-Shot Inference — Full Pipeline (mHC+Attn+MoE)")
    print(" Proper Sinkhorn mHC, RMSNorm, inverse RoPE, production FMHA")
    print("=" * 70)

    with open(os.path.join(CHECKPOINT_DIR, "config.json"), encoding="utf-8") as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    n_hc = 4  # number of mHC residual streams
    max_ctx = 8192  # capacity of both the RoPE tables and the per-layer KV caches
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
    print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")

    # ==== Phase 1: Load weights to CPU ====
    print(f"\n{'='*70}\nPhase 1: Loading weights to CPU\n{'='*70}")
    all_weights = load_weights_to_cpu(CHECKPOINT_DIR)
    t_loaded = time.time()
    print(f"Weight loading: {t_loaded - t_start:.1f}s")

    # ==== Build mHC blocks + RMSNorms (small weights, keep on GPU) ====
    print("Building mHC blocks and RMSNorms...")
    attn_mhc_blocks, ffn_mhc_blocks, attn_norms, ffn_norms = \
        _build_layer_mhc_and_norms(all_weights, cfg, n_layers, H, n_hc)
    print(f" attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}")

    # ==== Global weights (small, keep on gpu0) ====
    torch.cuda.set_device(0)
    embed_w = all_weights.get("model.embed_tokens.weight")
    if embed_w is None:
        raise KeyError("model.embed_tokens.weight not found in checkpoint")
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
    # Tied embeddings: fall back to the input embedding when lm_head is absent.
    lm_w = all_weights.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
    final_norm_w = all_weights.get("model.norm.weight")
    if final_norm_w is not None:
        final_norm_w = final_norm_w.to('cuda:0')

    # Build RoPE caches with YaRN scaling from model config
    rope_params = cfg.get("rope_parameters", {})
    rope_type = rope_params.get("rope_type", "default")
    rope_factor = rope_params.get("factor", 1.0)
    rope_theta = rope_params.get("rope_theta", cfg.get("rope_theta", 10000.0))
    original_max_pos = rope_params.get("original_max_position_embeddings", 4096)
    beta_fast = rope_params.get("beta_fast", 32)
    beta_slow = rope_params.get("beta_slow", 1)
    print(f"RoPE: type={rope_type} factor={rope_factor} theta={rope_theta} "
          f"orig_max_pos={original_max_pos} beta_fast={beta_fast} beta_slow={beta_slow}", flush=True)
    rope_caches = {
        g: build_rope_cache(max_ctx, rd, f"cuda:{g}", theta=rope_theta,
                            rope_type=rope_type, rope_factor=rope_factor,
                            original_max_pos=original_max_pos,
                            beta_fast=beta_fast, beta_slow=beta_slow)
        for g in range(NUM_GPUS)
    }

    # ==== KV caches (one per layer on its GPU) ====
    kv_caches = {
        li: SimpleKVCache(head_dim=hd, max_seq=max_ctx, device=f"cuda:{li % NUM_GPUS}")
        for li in range(n_layers)
    }

    # ==== Cache ALL layer weights to GPUs (avoids per-token CPU→GPU transfer) ====
    print(f"\n Caching layer weights to GPUs (one-time transfer)...", flush=True)
    devices = [f"cuda:{g}" for g in range(NUM_GPUS)]
    layer_weights = cache_all_layer_weights(all_weights, n_layers, devices)
    print(f" Done. Freeing CPU weights...", flush=True)
    # embed_w also references a CPU tensor from all_weights; drop both.
    del all_weights, embed_w
    gc.collect()

    # ==== Phase 2: Compile FMHA ====
    print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
    t_compile_start = time.time()
    from dsv4.kernels.attention.production import dsv4_attention
    torch.cuda.set_device(0)
    dummy_q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    dummy_k = torch.randn(1, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    try:
        dsv4_attention(dummy_q, dummy_k, dummy_k.clone())
        print(" FMHA: compiled OK")
    except Exception as e:
        # Best-effort warm-up: the failure is reported, not raised.
        print(f" FMHA error: {e}")
    t_compiled = time.time()
    print(f"Compile: {t_compiled - t_compile_start:.1f}s")

    # ==== Phase 2.5: Minimal E2E test ====
    print(f"\n{'='*70}\nPhase 2.5: Minimal E2E Test (single token 'The')\n{'='*70}")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    minimal_e2e_test(layer_weights, cfg, rope_caches, attn_mhc_blocks,
                     ffn_mhc_blocks, attn_norms, ffn_norms, embed, lm_w,
                     final_norm_w, tokenizer)

    # ==== Phase 2.6: Single-layer trace ====
    print(f"\n{'='*70}\nPhase 2.6: Single-Layer Trace (layer 0, first prefill token)\n{'='*70}", flush=True)
    _trace_layer0_attention(layer_weights, cfg, rope_caches, attn_mhc_blocks,
                            attn_norms, embed, tokenizer, n_hc)

    def run_layer_stack(tid, positions):
        """Embed ``tid`` and run it through every layer, hopping GPUs per layer.

        Each layer receives its own entry of ``kv_caches``. The returned
        residual state is on cuda:0, and the current CUDA device is reset to 0.
        """
        X = mHCBlock.init_state(embed(tid), n_hc)  # (1, n_hc, H)
        for li in range(n_layers):
            gpu = li % NUM_GPUS
            dev = f"cuda:{gpu}"
            if X.device != torch.device(dev):
                X = X.to(dev)
            torch.cuda.set_device(gpu)
            rc, rs = rope_caches[gpu]
            X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
                              attn_mhc_blocks.get(li), ffn_mhc_blocks.get(li),
                              attn_norms[li], ffn_norms[li],
                              kv_caches[li], tid, positions)
        X = X.to('cuda:0')
        torch.cuda.set_device(0)
        return X

    # ==== Phase 3: Inference ====
    print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}")
    # DeepSeek V4 chat format: <|begin▁of▁sentence|><|User|>prompt<|Assistant|>
    # For reasoning models: <|User|>prompt<|Assistant|>fithinking...flanswer
    # Special token IDs: <|User|>=128803, <|Assistant|>=128804, <|EOT|>=128805
    # Thinking tokens: fi=128821, fl=128822
    USER_TOKEN = 128803
    ASSISTANT_TOKEN = 128804
    EOT_TOKEN = 128805
    THINK_START = 128821  # fi
    THINK_END = 128822  # fl
    think_ids = (THINK_START, THINK_END)

    # <BOS> <|User|> System prompt \n\n User prompt <|Assistant|>
    bos_id = tokenizer.bos_token_id or 0
    prompt_ids = [bos_id, USER_TOKEN]
    prompt_ids += tokenizer.encode(SYSTEM_PROMPT, add_special_tokens=False)
    prompt_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
    prompt_ids.append(ASSISTANT_TOKEN)
    print(f"DeepSeek chat format. Input: {len(prompt_ids)} tokens", flush=True)
    print(f"Decoded start: '{tokenizer.decode(prompt_ids[:20])}...'", flush=True)
    print(f"Decoded end: '...{tokenizer.decode(prompt_ids[-5:])}'", flush=True)
    if len(prompt_ids) + MAX_NEW_TOKENS > max_ctx:
        print(f" WARNING: prompt ({len(prompt_ids)}) + MAX_NEW_TOKENS ({MAX_NEW_TOKENS}) "
              f"exceeds the {max_ctx}-position RoPE/KV capacity", flush=True)

    # ==== Prefill: all prompt tokens except the last ====
    # The first decode step feeds the last prompt token and reads its logits.
    # Prefilling it as well would run it through the stack twice, and with an
    # append-style cache it would also store its K/V twice.
    prefill_ids = prompt_ids[:-1]
    t_prefill = time.time()
    print(f"Prefilling {len(prefill_ids)} prompt tokens...", flush=True)
    for prefill_idx, tid_val in enumerate(prefill_ids):
        t0 = time.time()
        tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
        positions = torch.tensor([prefill_idx], dtype=torch.long, device='cuda:0')
        run_layer_stack(tid, positions)
        if prefill_idx % 10 == 0:
            print(f" Token {prefill_idx}/{len(prefill_ids)}: {time.time()-t0:.2f}s", flush=True)
    print(f" Prefill done ({len(prefill_ids)} tokens, {time.time()-t_prefill:.1f}s)")

    # ==== Decode: generate new tokens (greedy) ====
    print(f"\nDecoding (max {MAX_NEW_TOKENS} new tokens)...")
    all_tokens = list(prompt_ids)
    new_tokens = []
    stop_ids = {tokenizer.eos_token_id, EOT_TOKEN}

    for step in range(MAX_NEW_TOKENS):
        t0 = time.time()
        # Step 0 consumes the last prompt token; later steps consume the previous output.
        decode_pos = len(all_tokens) - 1
        tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0')
        positions = torch.tensor([decode_pos], dtype=torch.long, device='cuda:0')

        X = run_layer_stack(tid, positions)
        logits = _lm_head_logits(X, final_norm_w, lm_w)

        # Top-20 predictions for debugging (includes thinking tokens)
        top20_vals, top20_ids = torch.topk(logits[0], 20)
        top5_str = ' '.join(f'{tokenizer.decode([t.item()])}({v.item():.1f})'
                            for t, v in zip(top20_ids[:5], top20_vals[:5]))
        think_hits = [(t.item(), v.item()) for t, v in zip(top20_ids, top20_vals)
                      if t.item() in think_ids]

        next_id = torch.argmax(logits, dim=-1).item()
        all_tokens.append(next_id)
        new_tokens.append(next_id)

        tok_str = tokenizer.decode([next_id])
        dt = time.time() - t0
        logits_f = logits.float()
        has_nan = torch.isnan(logits_f).any().item()
        has_inf = torch.isinf(logits_f).any().item()
        lmin, lmax = logits_f.min().item(), logits_f.max().item()
        x_max = X.abs().max().item()
        print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) "
              f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} "
              f"|X|={x_max:.3f} top5: {top5_str}", flush=True)
        for think_id, think_val in think_hits:
            print(f" THINK TOKEN: {think_id} logit={think_val:.3f}", flush=True)
        if step % 5 == 0:
            print(f" Top-20: {[(tokenizer.decode([t.item()]), f'{v.item():.2f}') for t, v in zip(top20_ids, top20_vals)]}", flush=True)

        if has_nan or has_inf:
            print(" Numerical issue — stopping")
            break
        if next_id in stop_ids:
            break

    out = tokenizer.decode(new_tokens, skip_special_tokens=True)
    total = time.time() - t_start
    print(f"\n{'='*70}")
    print(f"Input: '{PROMPT}'")
    print(f"Output: '{out}'")
    print(f"Total: {total:.1f}s")
    print(f"{'='*70}")
|
||
|
||
|
||
# =====================================================================
|
||
# Minimal end-to-end test — single token "The" through the model
|
||
# =====================================================================
|
||
|
||
def minimal_e2e_test(layer_weights, cfg, rope_caches, attn_mhc_blocks,
                     ffn_mhc_blocks, attn_norms, ffn_norms, embed, lm_w,
                     final_norm_w, tokenizer):
    """Run the single token 'The' through every layer and print diagnostics.

    This is a focused smoke test: if the pipeline cannot produce sane logits
    for one token, something is fundamentally wrong. It reports the
    following but asserts none of it:

      1. NaN/Inf in each layer's output. The layer loop stops at the first hit.
      2. Residual-stream statistics (max |X| over all streams; mean |x| and
         std of stream 0) for every 5th layer plus the last layer.
      3. The top-10 next-token predictions, for manual inspection.
      4. A warning when the logit spread (max - min) is below 1.0, meaning
         the output is essentially uniform, or when it is not finite.

    Each layer uses a fresh, throwaway ``SimpleKVCache``. No persistent KV
    cache is passed in or modified.

    Args:
        layer_weights: per-layer weight dicts indexed by layer id.
        cfg: parsed model ``config.json``.
        rope_caches: ``{gpu_index: (cos, sin)}`` RoPE tables.
        attn_mhc_blocks, ffn_mhc_blocks: per-layer mHC blocks (looked up
            with ``.get``, so a missing layer passes None).
        attn_norms, ffn_norms: per-layer RMSNorms.
        embed: token-embedding module.
        lm_w: LM-head weight for ``torch.nn.functional.linear``.
        final_norm_w: final RMSNorm weight, or None to skip the final norm.
        tokenizer: tokenizer providing ``encode``/``decode``.

    Returns:
        ``(logits, layer_diags)``: logits of shape ``(1, lm_w.shape[0])`` on
        cuda:0, and one diagnostic dict per layer that was run.
    """
    n_layers = cfg["num_hidden_layers"]
    hd = cfg["head_dim"]
    n_hc = 4  # number of mHC residual streams

    # Tokenize just "The"
    tid = torch.tensor(tokenizer.encode("The"), dtype=torch.long, device='cuda:0')
    if tid.numel() > 1:
        # If tokenizer adds BOS, take last token
        print(f" Note: 'The' tokenized to {tid.numel()} tokens, using last one")
        tid = tid[-1:]
    print(f" Token ID: {tid.item()} = '{tokenizer.decode(tid.tolist())}'")

    positions = torch.tensor([0], dtype=torch.long, device='cuda:0')
    X = mHCBlock.init_state(embed(tid), n_hc)  # (1, n_hc, H)

    layer_diags = []
    for li in range(n_layers):
        gpu = li % NUM_GPUS
        dev = f"cuda:{gpu}"
        if X.device != torch.device(dev):
            X = X.to(dev)
        torch.cuda.set_device(gpu)

        rc, rs = rope_caches[gpu]
        # Throwaway cache so this diagnostic leaves no state behind.
        kv_cache = SimpleKVCache(head_dim=hd, max_seq=8192, device=dev)
        X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
                          attn_mhc_blocks.get(li), ffn_mhc_blocks.get(li),
                          attn_norms[li], ffn_norms[li],
                          kv_cache, tid, positions)

        # Per-layer diagnostic (stream 0 is the primary stream)
        X_f = X.float()
        has_nan = torch.isnan(X_f).any().item()
        has_inf = torch.isinf(X_f).any().item()
        x0 = X_f[:, 0, :]
        layer_diags.append({
            'layer': li, 'gpu': gpu, 'x_max': X.abs().max().item(),
            'x0_mean': x0.abs().mean().item(), 'x0_std': x0.std().item(),
            'nan': has_nan, 'inf': has_inf,
        })

        if has_nan or has_inf:
            print(f" ❌ Layer {li}: NaN={has_nan} Inf={has_inf} — STOPPING")
            break

    X = X.to('cuda:0')
    torch.cuda.set_device(0)

    # Final norm + lm_head on stream 0
    x_out = X[:, 0, :]
    if final_norm_w is not None:
        xf = x_out.float()
        # NOTE(review): eps is fixed at 1e-6 rather than cfg['rms_norm_eps'].
        rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
        x_out = (xf * rms * final_norm_w.float()).bfloat16()
    logits = torch.nn.functional.linear(x_out, lm_w)

    # Results
    logits_f = logits.float()
    l_min = logits_f.min().item()
    l_max = logits_f.max().item()
    spread = l_max - l_min
    print(f"\n === Minimal E2E Test Results ===")
    print(f" Logits: min={l_min:.2f} max={l_max:.2f} spread={spread:.2f}")
    print(f" NaN={torch.isnan(logits_f).any().item()} "
          f"Inf={torch.isinf(logits_f).any().item()}")

    top10_vals, top10_ids = torch.topk(logits[0], 10)
    print(f" Top-10 predictions:")
    for rank, (tok_id, val) in enumerate(zip(top10_ids.tolist(), top10_vals.tolist()), start=1):
        print(f" {rank}. '{tokenizer.decode([tok_id])}' (id={tok_id}, logit={val:.3f})")

    # Residual stream evolution: every 5th layer, plus the last one (once).
    print(f"\n Residual stream evolution (stream 0):")
    shown = layer_diags[::5]
    if layer_diags and (len(layer_diags) - 1) % 5 != 0:
        shown.append(layer_diags[-1])
    for d in shown:
        print(f" L{d['layer']:2d}: |X|={d['x_max']:.1f} "
              f"mean={d['x0_mean']:.1f} std={d['x0_std']:.1f} "
              f"nan={d['nan']} inf={d['inf']}")

    # Check for reasonable output; a NaN spread must not count as "reasonable".
    if not math.isfinite(spread):
        print(f" ❌ Logit spread is {spread} — logits contain NaN/Inf")
    elif spread < 1.0:
        print(f" ⚠️ Logit spread {spread:.2f} is very low — model is essentially uniform")
    else:
        print(f" ✓ Logit spread {spread:.2f} looks reasonable")

    return logits, layer_diags
|
||
|
||
|
||
if __name__ == "__main__":
    # Start inference only when executed as a script, not on import.
    main()
|