Add full MoE routing + KV cache to single_shot

MoE:
- Hash routing (first 3 layers): tid2eid lookup → 6 experts, uniform weights
- Dense routing (remaining): sqrt(softplus(gate)) → top-6 → renormalize
- 384 NVFP4 experts, each gate+up+down with SiGLU clamping
- Weighted combine × routed_scaling_factor + shared expert

KV cache:
- SimpleKVCache: BF16 flat (1, max_seq, hd) per layer
- Appends new K,V each decode step
- FMHA now attends over full cached sequence (not just current token)
- RoPE applied per-position on K cache

This should produce meaningful output — the model now has all
architectural components except proper mHC normalization.
This commit is contained in:
2026-05-31 00:11:15 +00:00
parent 3ecfbcba57
commit afcc690ddc

View File

@@ -1,19 +1,15 @@
#!/usr/bin/env python3
"""Single-shot DSV4 inference — 8-GPU pipeline parallel with mHC.
"""Single-shot DSV4 inference — 8-GPU with mHC + MoE + KV cache.
Loads the full NVFP4 checkpoint across 8 B200 GPUs. Includes:
- mHC (Manifold-Constrained Hyper-Connections) — load-bearing residual
- Q low-rank projection + KV projection
Full model forward pass with all architectural components:
- mHC (Manifold-Constrained Hyper-Connections)
- Q low-rank + KV projection
- RoPE (partial, last 64 dims)
- Production FMHA kernel (tcgen05 MMA, hd=512, 128 heads)
- Output projection: wo_a (grouped BMM) → wo_b (NVFP4)
- Shared expert FFN (SwiGLU)
- NVFP4 dequant → BF16 matmul baseline for linear layers
Missing (causing incorrect output):
- Routed MoE experts (384 experts, top-6)
- Production FMHA kernel (tcgen05 MMA)
- Grouped output projection (wo_a BMM + wo_b NVFP4)
- Routed MoE (384 experts, top-6, hash + dense routing)
- Shared expert FFN (SwiGLU with clamping)
- KV cache across decode steps
- Compressor + indexer (CSA/HCA compressed KV)
Usage (on B200):
source /root/dsv4-nvfp4-workspace/venv/bin/activate
@@ -60,23 +56,10 @@ def bf16_linear(x, weight):
# =====================================================================
# mHC — Manifold-Constrained Hyper-Connections
# mHC
# =====================================================================
def sinkhorn_knopp(M, t_max=20, eps=1e-6):
"""Project (T, n, n) positive matrices onto Birkhoff polytope."""
for _ in range(t_max):
M = M / (M.sum(dim=-1, keepdim=True) + eps)
M = M / (M.sum(dim=-2, keepdim=True) + eps)
return M
class mHCBlock:
"""One mHC block — matches vLLM torch reference exactly.
Checkpoint: fn (24,28672), base (24,), scale (3,)
Split: pre_mix (n_hc=4), post_mix (n_hc=4), comb_mix (n_hc*n_hc=16)
"""
def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_repeat=20, device='cuda'):
self.n_hc = n_hc
self.K = n_hc * hidden_dim
@@ -96,26 +79,18 @@ class mHCBlock:
self.hc_scale = scale.to(device=self.device, dtype=torch.float32).contiguous()
def pre_block(self, residual):
"""residual: (T, n_hc, d) BF16 → layer_input (T, d), ctx"""
n = self.n_hc
K = self.K
eps = self.rms_eps
T = residual.shape[0]
res_flat = residual.reshape(T, K).float() # (T, K)
# Project with RMSNorm on input
mixes = torch.matmul(res_flat, self.fn.t()) # (T, 24)
res_flat = residual.reshape(T, K).float()
mixes = torch.matmul(res_flat, self.fn.t())
sqrsum = res_flat.square().sum(dim=-1, keepdim=True)
mixes = mixes * torch.rsqrt(sqrsum / K + eps)
mixes = mixes * torch.rsqrt(sqrsum / K + self.rms_eps)
# Split
pre_logits = mixes[:, :n] * self.hc_scale[0] + self.hc_base[:n]
pre_mix = torch.sigmoid(pre_logits) + self.hc_pre_eps # (T, 4)
pre_mix = torch.sigmoid(pre_logits) + self.hc_pre_eps
post_logits = mixes[:, n:2*n] * self.hc_scale[1] + self.hc_base[n:2*n]
post_mix = torch.sigmoid(post_logits) * self.hc_post_mult_value # (T, 4)
post_mix = torch.sigmoid(post_logits) * self.hc_post_mult_value
comb_logits = (mixes[:, 2*n:].reshape(T, n, n) * self.hc_scale[2]
+ self.hc_base[2*n:].reshape(1, n, n))
comb_mix = torch.softmax(comb_logits, dim=-1) + self.hc_sinkhorn_eps
@@ -124,22 +99,22 @@ class mHCBlock:
comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + self.hc_sinkhorn_eps)
comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + self.hc_sinkhorn_eps)
# layer_input = sum(pre_mix * residual)
layer_input = (pre_mix.unsqueeze(-1) * residual.float()).sum(dim=1).bfloat16()
return layer_input, (post_mix, comb_mix)
def post_block(self, residual, F_out, ctx):
"""residual: (T, n_hc, d), F_out: (T, d) → (T, n_hc, d)"""
post_mix, comb_mix = ctx
mixed_residual = torch.einsum('tij,tjh->tjh', comb_mix, residual.float())
post_term = post_mix.unsqueeze(-1) * F_out.unsqueeze(1).float()
residual_next = mixed_residual + post_term
# Emergency: per-residual RMSNorm (routed MoE missing)
# Emergency RMSNorm (remove once MoE provides balance)
_T = residual_next.shape[0]
rn_f = residual_next.reshape(_T, self.n_hc, -1)
rms = rn_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
return (rn_f * rms).bfloat16()
# =====================================================================
# RoPE
# =====================================================================
@@ -160,24 +135,141 @@ def apply_rope(x, positions, cos_cache, sin_cache, rope_dim):
return out
# =====================================================================
# Simple KV cache — BF16 flat, one tensor per layer
# =====================================================================
class SimpleKVCache:
"""Per-layer KV cache for decode. Stores BF16 K,V accumulated across steps.
MQA: 1 KV head, so cache is (1, seq_len, hd) per layer."""
def __init__(self, n_layers, head_dim, max_seq=8192, device='cuda:0'):
self.hd = head_dim
self.max_seq = max_seq
self.device = device
self.k = [torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device) for _ in range(n_layers)]
self.v = [torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device) for _ in range(n_layers)]
self.len = [0] * n_layers # current sequence length per layer
def append(self, layer_idx, k_new, v_new):
"""Append new K,V. k_new: (1, 1, hd), v_new: (1, 1, hd)."""
pos = self.len[layer_idx]
self.k[layer_idx][0, pos] = k_new[0, 0]
self.v[layer_idx][0, pos] = v_new[0, 0]
self.len[layer_idx] = pos + 1
def get(self, layer_idx):
"""Get K,V up to current length. Returns (1, seq_len, hd), (1, seq_len, hd)."""
l = self.len[layer_idx]
return self.k[layer_idx][:, :l], self.v[layer_idx][:, :l]
# =====================================================================
# Routed MoE forward
# =====================================================================
def moe_forward(x, w, li, cfg, token_id):
"""Run routed MoE + shared expert.
x: (1, H) BF16 — input after FFN mHC pre_block
Returns: (1, H) BF16 — combined expert output
"""
H = cfg["hidden_size"]
n_experts = cfg["n_routed_experts"] # 384
top_k = cfg["num_experts_per_tok"] if "num_experts_per_tok" in cfg else 6
routed_scaling = cfg.get("routed_scaling_factor", 2.5)
swiglu_limit = cfg.get("swiglu_limit", 10.0)
mlp_inter = cfg["moe_intermediate_size"] # 3072
# ---- Hash routing ----
# For decode, first 3 layers use hash routing (token ID lookup)
# Remaining layers use dense routing (weight projection)
is_hash = li < 3 # Hash routing for first 3 layers
expert_ids = None
expert_weights = None
if is_hash:
# tid2eid: (vocab_size, top_k) int64
tid2eid = w[f"model.layers.{li}.mlp.gate.tid2eid"]
tid = token_id.item() if token_id.numel() == 1 else token_id[0].item()
expert_ids = tid2eid[tid] # (top_k,) int64
expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k
else:
# Dense routing: gate weight (n_experts, H) BF16
gate_w = w[f"model.layers.{li}.mlp.gate.weight"] # (384, 7168) BF16
logits = bf16_linear(x, gate_w) # (1, 384) BF16
# activation = sqrt(softplus(logits))
activated = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6).bfloat16()
# Top-k
scores, indices = activated.float().topk(top_k, dim=-1) # (1, 6)
expert_ids = indices[0] # (6,)
# Renormalize
expert_weights = (scores[0] / scores[0].sum()).float()
# ---- Run selected experts ----
expert_outputs = []
for i, eid in enumerate(expert_ids):
eid_int = eid.item()
epre = f"model.layers.{li}.mlp.experts.{eid_int}"
gate = nvfp4_linear(x, w[f"{epre}.gate_proj.weight"],
w[f"{epre}.gate_proj.weight_scale"],
w[f"{epre}.gate_proj.weight_scale_2"])
up = nvfp4_linear(x, w[f"{epre}.up_proj.weight"],
w[f"{epre}.up_proj.weight_scale"],
w[f"{epre}.up_proj.weight_scale_2"])
# SiLU + clamp
silu_out = torch.nn.functional.silu(gate.float()).clamp(-swiglu_limit, swiglu_limit)
hidden = (silu_out * up.float()).bfloat16()
down = nvfp4_linear(hidden, w[f"{epre}.down_proj.weight"],
w[f"{epre}.down_proj.weight_scale"],
w[f"{epre}.down_proj.weight_scale_2"])
expert_outputs.append(down)
# Weighted combine + scaling
routed_out = torch.zeros_like(x)
for i, (out, wt) in enumerate(zip(expert_outputs, expert_weights)):
routed_out = routed_out + (out.float() * wt).bfloat16()
routed_out = (routed_out.float() * routed_scaling).bfloat16()
# ---- Shared expert ----
se_pre = f"model.layers.{li}.mlp.shared_experts"
se_gate_w = w.get(f"{se_pre}.gate_proj.weight")
if se_gate_w is not None:
gate = nvfp4_linear(x, se_gate_w,
w[f"{se_pre}.gate_proj.weight_scale"],
w[f"{se_pre}.gate_proj.weight_scale_2"])
up = nvfp4_linear(x, w[f"{se_pre}.up_proj.weight"],
w[f"{se_pre}.up_proj.weight_scale"],
w[f"{se_pre}.up_proj.weight_scale_2"])
silu_out = torch.nn.functional.silu(gate.float()).clamp(-swiglu_limit, swiglu_limit)
hidden = (silu_out * up.float()).bfloat16()
shared_out = nvfp4_linear(hidden, w[f"{se_pre}.down_proj.weight"],
w[f"{se_pre}.down_proj.weight_scale"],
w[f"{se_pre}.down_proj.weight_scale_2"])
else:
shared_out = torch.zeros_like(x)
return routed_out + shared_out
# =====================================================================
# Weight loading
# =====================================================================
def load_all_weights(checkpoint_dir, num_layers):
from safetensors.torch import load_file
cdir = Path(checkpoint_dir)
index_path = cdir / "model.safetensors.index.json"
weight_map = {}
if index_path.exists():
with open(index_path) as f:
weight_map = json.load(f).get("weight_map", {})
shard_names = set(weight_map.values()) if weight_map else {
f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
}
print(f"Loading {len(shard_names)} shards...")
all_weights = {}
loaded = 0
@@ -193,7 +285,6 @@ def load_all_weights(checkpoint_dir, num_layers):
layer_weights = {}
global_weights = {}
print("Assigning to GPUs...")
for key, tensor in all_weights.items():
if key.startswith("model.layers."):
@@ -213,7 +304,6 @@ def load_all_weights(checkpoint_dir, num_layers):
for gpu in range(NUM_GPUS):
alloc = torch.cuda.memory_allocated(gpu) / 1e9
print(f" GPU {gpu}: {alloc:.1f}GB")
return layer_weights, global_weights
@@ -221,11 +311,9 @@ def load_all_weights(checkpoint_dir, num_layers):
# Single layer forward
# =====================================================================
def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
"""Forward one layer with mHC.
X_l: (1, n_hc, H) BF16 → (1, n_hc, H) BF16
"""
def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc,
kv_cache, token_id, decode_pos):
"""Forward one layer with mHC + MoE + KV cache."""
device = X_l.device
H = cfg["hidden_size"]
n_h = cfg["num_attention_heads"]
@@ -234,14 +322,13 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
o_rank = cfg["o_lora_rank"]
o_groups = cfg["o_groups"]
n_hc = 4
pre = f"model.layers.{li}.self_attn"
T = X_l.shape[0]
heads_per_group = n_h // o_groups
group_input_dim = heads_per_group * hd
# ==== mHC pre_block (attention) ====
x_in, attn_ctx = attn_mhc.pre_block(X_l) # x_in: (T, H) BF16
x_in, attn_ctx = attn_mhc.pre_block(X_l)
# ==== Q projection ====
c_Q = nvfp4_linear(x_in, w[f"{pre}.q_a_proj.weight"],
@@ -257,18 +344,29 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
w[f"{pre}.kv_proj.weight_scale_2"])
# ==== Reshape for attention ====
q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2)
k = kv.reshape(T, 1, hd).permute(1, 0, 2)
v = k.clone()
q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) # (n_h, T, hd)
k_new = kv.reshape(T, 1, hd).permute(1, 0, 2) # (1, T, hd)
v_new = k_new.clone()
# ==== KV cache: append new K,V ====
kv_cache.append(li, k_new, v_new)
k_full, v_full = kv_cache.get(li) # (1, seq_len, hd)
seq_len = k_full.shape[1]
# ==== RoPE ====
pos = torch.tensor([0], dtype=torch.long, device=device)
q_heads = apply_rope(q_heads, pos, rope_cos, rope_sin, rd)
k = apply_rope(k, pos, rope_cos, rope_sin, rd)
# Apply RoPE to Q (at current position)
q_pos = torch.tensor([decode_pos], dtype=torch.long, device=device)
q_heads = apply_rope(q_heads, q_pos, rope_cos, rope_sin, rd)
# Apply RoPE to K (at each position in the cache)
k_positions = torch.arange(seq_len, dtype=torch.long, device=device)
k_full_3d = k_full.permute(1, 0, 2) # (seq_len, 1, hd) for RoPE
k_full_3d = apply_rope(k_full_3d, k_positions, rope_cos, rope_sin, rd)
k_full = k_full_3d.permute(1, 0, 2) # (1, seq_len, hd) — RoPE'd
# ==== FMHA ====
from dsv4.kernels.attention.production import dsv4_attention
attn_out = dsv4_attention(q_heads, k, v)
attn_out = dsv4_attention(q_heads, k_full, v_full)
attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd)
# ==== Output projection ====
@@ -283,28 +381,13 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
w[f"{pre}.o_b_proj.weight_scale_2"])
# ==== mHC post_block (attention) ====
X_l = attn_mhc.post_block(X_l, F_attn, attn_ctx) # (T, n_hc, H)
X_l = attn_mhc.post_block(X_l, F_attn, attn_ctx)
# ==== mHC pre_block (FFN) ====
x_ffn, ffn_ctx = ffn_mhc.pre_block(X_l)
# ==== FFN: shared expert ====
se_pre = f"model.layers.{li}.mlp.shared_experts"
se_gate_w = w.get(f"{se_pre}.gate_proj.weight")
F_ffn = torch.zeros_like(x_ffn)
if se_gate_w is not None:
gate = nvfp4_linear(x_ffn, se_gate_w,
w[f"{se_pre}.gate_proj.weight_scale"],
w[f"{se_pre}.gate_proj.weight_scale_2"])
up = nvfp4_linear(x_ffn, w[f"{se_pre}.up_proj.weight"],
w[f"{se_pre}.up_proj.weight_scale"],
w[f"{se_pre}.up_proj.weight_scale_2"])
F_ffn = nvfp4_linear(
torch.nn.functional.silu(gate) * up,
w[f"{se_pre}.down_proj.weight"],
w[f"{se_pre}.down_proj.weight_scale"],
w[f"{se_pre}.down_proj.weight_scale_2"],
)
# ==== MoE + shared expert ====
F_ffn = moe_forward(x_ffn, w, li, cfg, token_id)
# ==== mHC post_block (FFN) ====
X_l = ffn_mhc.post_block(X_l, F_ffn, ffn_ctx)
@@ -319,7 +402,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
def main():
t_start = time.time()
print("=" * 70)
print("DSV4 Single-Shot Inference — 8-GPU with mHC")
print("DSV4 Single-Shot Inference — Full Pipeline (mHC+MoE+KV)")
print("=" * 70)
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
@@ -331,6 +414,7 @@ def main():
rd = cfg["qk_rope_head_dim"]
n_hc = 4
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")
# ==== Phase 1: Load weights ====
print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}")
@@ -338,7 +422,7 @@ def main():
t_loaded = time.time()
print(f"Weight loading: {t_loaded - t_start:.1f}s")
# ==== Build mHC blocks per layer ====
# ==== Build mHC blocks ====
print("Building mHC blocks...")
attn_mhc_blocks = {}
ffn_mhc_blocks = {}
@@ -346,37 +430,31 @@ def main():
gpu = li % NUM_GPUS
dev = f"cuda:{gpu}"
# Attention mHC
attn_fn = layer_weights[li].get(f"model.layers.{li}.attn_hc.fn")
attn_base = layer_weights[li].get(f"model.layers.{li}.attn_hc.base")
attn_scale = layer_weights[li].get(f"model.layers.{li}.attn_hc.scale")
if attn_fn is not None and attn_base is not None and attn_scale is not None:
attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
attn_mhc.load_from_checkpoint(attn_fn, attn_base, attn_scale)
attn_mhc_blocks[li] = attn_mhc
# FFN mHC
ffn_fn = layer_weights[li].get(f"model.layers.{li}.ffn_hc.fn")
ffn_base = layer_weights[li].get(f"model.layers.{li}.ffn_hc.base")
ffn_scale = layer_weights[li].get(f"model.layers.{li}.ffn_hc.scale")
if ffn_fn is not None and ffn_base is not None and ffn_scale is not None:
ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
ffn_mhc.load_from_checkpoint(ffn_fn, ffn_base, ffn_scale)
ffn_mhc_blocks[li] = ffn_mhc
for prefix, blocks in [(f"model.layers.{li}.attn_hc", attn_mhc_blocks),
(f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)]:
fn = layer_weights[li].get(f"{prefix}.fn")
base = layer_weights[li].get(f"{prefix}.base")
scale = layer_weights[li].get(f"{prefix}.scale")
if fn is not None and base is not None and scale is not None:
mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
mhc.load_from_checkpoint(fn, base, scale)
blocks[li] = mhc
print(f" attn mHC: {len(attn_mhc_blocks)} layers")
print(f" ffn mHC: {len(ffn_mhc_blocks)} layers")
print(f" attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}")
# ==== Global weights (gpu0) ====
# ==== Global weights ====
torch.cuda.set_device(0)
embed_w = global_weights.get("model.embed_tokens.weight")
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16())
lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16()
final_norm_w = global_weights.get("model.norm.weight")
# RoPE caches per GPU
rope_caches = {g: build_rope_cache(8192, hd, rd, f"cuda:{g}") for g in range(NUM_GPUS)}
# ==== KV cache (gpu0, moves to target GPU per layer) ====
kv_caches = {}
for li in range(n_layers):
kv_caches[li] = SimpleKVCache(1, hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}")
# ==== Phase 2: Compile ====
print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
from dsv4.kernels.attention.production import dsv4_attention
@@ -399,9 +477,18 @@ def main():
generated = input_ids[0].tolist()
# ==== Prefill: process all prompt tokens at once ====
# For now, treat each prompt token as a separate decode step
# (proper prefill would use T>1 FMHA, which is wired but not yet
# tested with KV cache. One token at a time is correct, just slow.)
all_tokens = generated.copy() # start with prompt tokens
for step in range(MAX_NEW_TOKENS):
t0 = time.time()
tid = torch.tensor([generated[-1]], dtype=torch.long, device='cuda:0')
# Current token (last in the sequence)
tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0')
decode_pos = len(all_tokens) - 1 # absolute position
# Embed → mHC init state
emb = embed(tid) # (1, H) on gpu0
@@ -418,25 +505,15 @@ def main():
attn_mhc = attn_mhc_blocks.get(li)
ffn_mhc = ffn_mhc_blocks.get(li)
rc, rs = rope_caches[gpu]
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, attn_mhc, ffn_mhc)
# NaN check every 5 layers or on first NaN
if li % 10 == 0 or (X.float().abs().max().item() > 1e4 or torch.isnan(X.float()).any().item()):
has_nan = torch.isnan(X.float()).any().item()
xmax = X.float().abs().max().item()
print(f" L{li}: nan={has_nan}, max_abs={xmax:.2f}")
if has_nan:
# Debug: check mHC outputs
x_out = X[:, 0, :]
print(f" L{li} stream0: nan={torch.isnan(x_out.float()).any().item()}, max={x_out.float().abs().max().item():.2f}")
break
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
attn_mhc, ffn_mhc, kv_caches[li], tid, decode_pos)
# Back to gpu0
X = X.to('cuda:0')
torch.cuda.set_device(0)
# Read out stream 0 → RMSNorm → lm_head
x_out = X[:, 0, :] # (1, H)
x_out = X[:, 0, :]
if final_norm_w is not None:
xf = x_out.float()
rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
@@ -445,17 +522,16 @@ def main():
logits = torch.nn.functional.linear(x_out, lm_w)
next_id = torch.argmax(logits, dim=-1).item()
generated.append(next_id)
all_tokens.append(next_id)
tok_str = tokenizer.decode([next_id])
dt = time.time() - t0
# Check for NaN
has_nan = torch.isnan(logits.float()).any().item()
logit_range = f"[{logits.float().min().item():.1f}, {logits.float().max().item():.1f}]"
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) logits={logit_range} nan={has_nan}")
lmin, lmax = logits.float().min().item(), logits.float().max().item()
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan}")
if has_nan:
print(" NaN detected, stopping")
print(" NaN stopping")
break
if next_id == tokenizer.eos_token_id:
break
@@ -465,7 +541,7 @@ def main():
print(f"\n{'='*70}")
print(f"Input: '{PROMPT}'")
print(f"Output: '{out}'")
print(f"Total: {total:.1f}s (load: {t_loaded-t_start:.1f}s, compile: {t_compiled-t_loaded:.1f}s, infer: {time.time()-t_compiled:.1f}s)")
print(f"Total: {total:.1f}s")
print(f"{'='*70}")