Files
nvfp4-megamoe-kernel/single_shot_inference.py

570 lines
23 KiB
Python

#!/usr/bin/env python3
"""Single-shot DSV4 inference — 8-GPU with mHC + MoE + KV cache.
Full model forward pass with all architectural components:
- mHC (Manifold-Constrained Hyper-Connections)
- Q low-rank + KV projection
- RoPE (partial, last 64 dims)
- Production FMHA kernel (tcgen05 MMA)
- Grouped output projection (wo_a BMM + wo_b NVFP4)
- Routed MoE (384 experts, top-6, hash + dense routing)
- Shared expert FFN (SwiGLU with clamping)
- KV cache across decode steps
Usage (on B200):
source /root/dsv4-nvfp4-workspace/venv/bin/activate
cd /root/dsv4-nvfp4-workspace/kernel
python3 single_shot_inference.py
"""
import os, sys, time, json, math
import torch
from pathlib import Path
# Run configuration. Each value may be overridden through an environment
# variable so the script can be pointed at another checkpoint / prompt / GPU
# count without editing the file; the defaults are the original settings.
CHECKPOINT_DIR = os.environ.get("DSV4_CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
MAX_NEW_TOKENS = int(os.environ.get("DSV4_MAX_NEW_TOKENS", "10"))
PROMPT = os.environ.get("DSV4_PROMPT", "The capital of France is")
NUM_GPUS = int(os.environ.get("DSV4_NUM_GPUS", "8"))
# =====================================================================
# NVFP4 dequantization
# =====================================================================
# Magnitudes of the NVFP4 element format (FP4 E2M1: 2 exponent bits, 1 mantissa
# bit), indexed by the low three bits of a 4-bit code; bit 3 is the sign bit.
FP4_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
    """Dequantize a packed NVFP4 weight matrix to BF16.

    Args:
        weight: (out_dim, in_features // 2) integer tensor holding two 4-bit
            codes per byte. The low nibble is the even input column and the
            high nibble is the following odd column.
        weight_scale: (out_dim, in_features // 16) block scales, one per 16
            consecutive input columns.
        weight_scale_2: global scale multiplied into every block scale (any
            tensor broadcastable against ``weight_scale``).

    Returns:
        (out_dim, in_features) BF16 tensor.
    """
    out_dim, in_packed = weight.shape
    in_features = in_packed * 2
    # Mask both nibbles so each code is in 0..15 regardless of whether the
    # packed dtype is signed (a signed right shift sign-extends).
    low = (weight & 0x0F).to(torch.int64)
    high = ((weight >> 4) & 0x0F).to(torch.int64)
    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)

    def decode(code):
        # Sign from bit 3, magnitude from the E2M1 table.
        magnitude = lut[code & 0x07]
        return torch.where((code & 0x08).bool(), -magnitude, magnitude)

    w_f = torch.stack([decode(low), decode(high)], dim=-1).reshape(out_dim, in_features)
    scale_f = weight_scale.float() * weight_scale_2.float()
    # Broadcast each block scale over its 16 input columns.
    scale_expanded = scale_f.repeat_interleave(16, dim=1)
    return (w_f * scale_expanded).bfloat16()
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Compute ``x @ W.T`` for an NVFP4-packed weight ``W``.

    The weight is dequantized to BF16 on every call (nothing is cached), so
    ``x`` needs a dtype that ``torch.nn.functional.linear`` accepts together
    with a BF16 weight.
    """
    dense_weight = dequant_nvfp4_weight(
        weight,
        weight_scale,
        weight_scale_2,
    )
    return torch.nn.functional.linear(x, dense_weight)
def bf16_linear(x, weight):
    """Return ``x @ weight.T`` with ``weight`` cast to BF16 first."""
    weight_bf16 = weight.to(dtype=torch.bfloat16)
    return torch.nn.functional.linear(x, weight_bf16)
# =====================================================================
# mHC
# =====================================================================
class mHCBlock:
    """Manifold-constrained hyper-connection (mHC) wrapper for one sub-layer.

    The residual is carried as ``n_hc`` parallel streams of width
    ``hidden_dim``, i.e. a (T, n_hc, hidden_dim) tensor. ``pre_block`` reduces
    the streams to the sub-layer input and computes the mixing coefficients.
    ``post_block`` mixes the streams with a Sinkhorn-normalized
    (n_hc x n_hc) matrix and adds the gated sub-layer output to every stream.
    """

    def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_repeat=20, device='cuda',
                 emergency_post_norm=True):
        self.n_hc = n_hc
        self.K = n_hc * hidden_dim  # width of the flattened multi-stream residual
        self.sinkhorn_repeat = sinkhorn_repeat
        self.device = device
        # Set by load_from_checkpoint().
        self.fn = None
        self.hc_scale = None
        self.hc_base = None
        self.rms_eps = 1e-6
        self.hc_pre_eps = 0.0
        self.hc_sinkhorn_eps = 1e-6
        self.hc_post_mult_value = 2.0
        # Extra per-stream RMS normalization at the end of post_block. It is a
        # stabilization workaround, not part of the mHC update itself; it is
        # on by default to keep existing behaviour.
        self.emergency_post_norm = emergency_post_norm

    def load_from_checkpoint(self, fn, base, scale):
        """Store the mixing projection, biases and scales as float32 on ``self.device``."""
        self.fn = fn.to(device=self.device, dtype=torch.float32).contiguous()
        self.hc_base = base.to(device=self.device, dtype=torch.float32).contiguous()
        self.hc_scale = scale.to(device=self.device, dtype=torch.float32).contiguous()

    def pre_block(self, residual):
        """Compute the sub-layer input and the mixing coefficients.

        Args:
            residual: (T, n_hc, hidden_dim) multi-stream residual.

        Returns:
            ``(layer_input, (post_mix, comb_mix))``. ``layer_input`` is
            (T, hidden_dim) BF16. ``post_mix`` is (T, n_hc) float32 and
            ``comb_mix`` is (T, n_hc, n_hc) float32.
        """
        n = self.n_hc
        K = self.K
        T = residual.shape[0]
        res_flat = residual.reshape(T, K).float()
        # A single projection yields every coefficient: n pre, n post, n*n comb.
        mixes = torch.matmul(res_flat, self.fn.t())
        # Scale by the inverse RMS of the flattened residual.
        sqrsum = res_flat.square().sum(dim=-1, keepdim=True)
        mixes = mixes * torch.rsqrt(sqrsum / K + self.rms_eps)
        pre_logits = mixes[:, :n] * self.hc_scale[0] + self.hc_base[:n]
        pre_mix = torch.sigmoid(pre_logits) + self.hc_pre_eps
        post_logits = mixes[:, n:2*n] * self.hc_scale[1] + self.hc_base[n:2*n]
        post_mix = torch.sigmoid(post_logits) * self.hc_post_mult_value
        comb_logits = (mixes[:, 2*n:].reshape(T, n, n) * self.hc_scale[2]
                       + self.hc_base[2*n:].reshape(1, n, n))
        # Sinkhorn-Knopp: row softmax, then alternate column/row normalization
        # so comb_mix approaches a doubly-stochastic matrix.
        comb_mix = torch.softmax(comb_logits, dim=-1) + self.hc_sinkhorn_eps
        comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + self.hc_sinkhorn_eps)
        for _ in range(self.sinkhorn_repeat - 1):
            comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + self.hc_sinkhorn_eps)
            comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + self.hc_sinkhorn_eps)
        # The sub-layer input is the pre_mix-weighted sum of the streams.
        layer_input = (pre_mix.unsqueeze(-1) * residual.float()).sum(dim=1).bfloat16()
        return layer_input, (post_mix, comb_mix)

    def post_block(self, residual, F_out, ctx):
        """Mix the residual streams and add the gated sub-layer output.

        ``new[j] = sum_i comb_mix[i, j] * residual[i] + post_mix[j] * F_out``

        Args:
            residual: (T, n_hc, hidden_dim) residual that entered ``pre_block``.
            F_out: (T, hidden_dim) sub-layer output.
            ctx: ``(post_mix, comb_mix)`` as returned by ``pre_block``.

        Returns:
            (T, n_hc, hidden_dim) BF16 residual for the next sub-layer.
        """
        post_mix, comb_mix = ctx
        # Contract over the source stream i. Indexing residual by the output
        # stream j instead would reduce comb to its column sums (~1 after
        # Sinkhorn) and never mix information between streams.
        mixed_residual = torch.einsum('tij,tih->tjh', comb_mix, residual.float())
        post_term = post_mix.unsqueeze(-1) * F_out.unsqueeze(1).float()
        residual_next = mixed_residual + post_term
        if not self.emergency_post_norm:
            return residual_next.bfloat16()
        # Emergency RMSNorm (remove once MoE provides balance)
        T = residual_next.shape[0]
        rn_f = residual_next.reshape(T, self.n_hc, -1)
        rms = rn_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
        return (rn_f * rms).bfloat16()
# =====================================================================
# RoPE
# =====================================================================
def build_rope_cache(max_pos, head_dim, rope_dim, device, theta=10000.0):
    """Precompute RoPE cos/sin tables.

    Args:
        max_pos: number of positions to tabulate.
        head_dim: accepted for call-site compatibility; not used.
        rope_dim: number of rotated dims; one frequency per (even, odd) pair.
        device: device for the returned tables.
        theta: RoPE frequency base.

    Returns:
        ``(cos, sin)``, each float32 of shape (max_pos, ceil(rope_dim / 2)).
    """
    exponents = torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim
    inv_freq = 1.0 / (theta ** exponents)
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
    return torch.cos(angles).to(device), torch.sin(angles).to(device)
def apply_rope(x, positions, cos_cache, sin_cache, rope_dim):
    """Rotate the last ``rope_dim`` features of ``x`` with interleaved RoPE.

    ``x`` is a 3-D tensor whose first dim is indexed by ``positions``. Within
    the trailing ``rope_dim`` features, each (even, odd) pair is rotated by the
    angle tabulated for its position. The leading features pass through
    unchanged. Returns a new tensor; ``x`` is not modified.
    """
    _, _, head_dim = x.shape
    split = head_dim - rope_dim
    cos = cos_cache[positions].unsqueeze(1).to(x.dtype)
    sin = sin_cache[positions].unsqueeze(1).to(x.dtype)
    rotated = x.clone()
    tail_in = x[:, :, split:]
    tail_out = rotated[:, :, split:]  # view into `rotated`
    even = tail_in[..., 0::2]
    odd = tail_in[..., 1::2]
    tail_out[..., 0::2] = even * cos - odd * sin
    tail_out[..., 1::2] = even * sin + odd * cos
    return rotated
# =====================================================================
# Simple KV cache — BF16 flat, one tensor per layer
# =====================================================================
class SimpleKVCache:
    """Per-layer BF16 key/value cache for incremental decoding.

    Holds one preallocated (1, max_seq, head_dim) buffer per layer for K and
    one for V, i.e. a single shared KV head (MQA). Positions are filled front
    to back, and ``get`` returns views of the filled prefix.
    """

    def __init__(self, n_layers, head_dim, max_seq=8192, device='cuda:0'):
        self.hd = head_dim
        self.max_seq = max_seq
        self.device = device
        self.k = [torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
                  for _ in range(n_layers)]
        self.v = [torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
                  for _ in range(n_layers)]
        self.len = [0] * n_layers  # number of filled positions per layer

    def append(self, layer_idx, k_new, v_new):
        """Append T new positions to layer ``layer_idx``.

        Args:
            layer_idx: which layer's buffers to extend.
            k_new, v_new: (1, T, head_dim) tensors. T is 1 for single-token
                decode and may be larger for a batch of tokens; all T
                positions are stored.

        Raises:
            IndexError: if the new positions do not fit in ``max_seq``.
        """
        start = self.len[layer_idx]
        end = start + k_new.shape[1]
        if end > self.max_seq:
            raise IndexError(
                f"KV cache overflow: layer {layer_idx} holds {start} positions, "
                f"cannot append {k_new.shape[1]} more (max_seq={self.max_seq})")
        self.k[layer_idx][0, start:end] = k_new[0]
        self.v[layer_idx][0, start:end] = v_new[0]
        self.len[layer_idx] = end

    def get(self, layer_idx):
        """Return views ``(K, V)`` of the filled prefix, each (1, seq_len, head_dim)."""
        filled = self.len[layer_idx]
        return self.k[layer_idx][:, :filled], self.v[layer_idx][:, :filled]
# =====================================================================
# Routed MoE forward
# =====================================================================
def _swiglu_ffn(x, w, prefix, swiglu_limit):
    """Apply the NVFP4 SwiGLU FFN stored under ``prefix`` (gate/up/down_proj) to ``x``."""
    def proj(name, inp):
        return nvfp4_linear(inp, w[f"{prefix}.{name}.weight"],
                            w[f"{prefix}.{name}.weight_scale"],
                            w[f"{prefix}.{name}.weight_scale_2"])

    gate = proj("gate_proj", x).float()
    up = proj("up_proj", x).float()
    if swiglu_limit > 0:
        # Clamped SwiGLU: cap the gate pre-activation from above and clamp the
        # linear (up) branch symmetrically. Clamping only silu(gate) would
        # leave `up` unbounded, so the product could still blow past the limit.
        gate = gate.clamp(max=swiglu_limit)
        up = up.clamp(-swiglu_limit, swiglu_limit)
    hidden = (torch.nn.functional.silu(gate) * up).bfloat16()
    return proj("down_proj", hidden)


def _route(x, w, li, token_id, top_k, is_hash):
    """Choose experts for the single token in ``x``.

    Returns:
        ``(expert_ids, weights)``: int64 ids and float32 weights summing to 1.
    """
    gpre = f"model.layers.{li}.mlp.gate"
    # Dense routing requires the router weight; hash routing uses it for the
    # combine weights when the checkpoint provides it.
    gate_w = w.get(f"{gpre}.weight") if is_hash else w[f"{gpre}.weight"]
    scores = None
    if gate_w is not None:
        logits = torch.nn.functional.linear(x.float(), gate_w.float())[0]  # (n_experts,)
        # Router activation: sqrt(softplus(logits)).
        scores = torch.sqrt(torch.nn.functional.softplus(logits) + 1e-6)
    if is_hash:
        # Hash routing: expert ids come from a per-token-id lookup table.
        tid2eid = w[f"{gpre}.tid2eid"]
        tid = token_id.reshape(-1)[0].item()
        expert_ids = tid2eid[tid].long()
        if scores is None:
            # No router weight available: split evenly across the hashed experts.
            n = expert_ids.numel()
            return expert_ids, torch.full((n,), 1.0 / n, dtype=torch.float32, device=x.device)
        weights = scores[expert_ids.to(scores.device)]
    else:
        # Optional selection bias: changes which experts are picked, not their
        # weights. NOTE(review): key name assumed (HF DeepSeek convention);
        # it is simply skipped when absent.
        bias = w.get(f"{gpre}.e_score_correction_bias")
        selection = scores if bias is None else scores + bias.float().to(scores.device)
        expert_ids = selection.topk(top_k).indices
        weights = scores[expert_ids]
    return expert_ids, weights / weights.sum()


def moe_forward(x, w, li, cfg, token_id, n_hash_layers=3):
    """Run the routed MoE plus the shared expert for one token.

    Args:
        x: (1, H) BF16 input after the FFN mHC pre_block.
        w: this layer's weight dict.
        li: layer index.
        cfg: model config dict.
        token_id: tensor holding the current token id (used by hash routing).
        n_hash_layers: layers with ``li < n_hash_layers`` use hash routing;
            the rest use the dense router.

    Returns:
        (1, H) BF16 combined expert output.
    """
    top_k = cfg.get("num_experts_per_tok", 6)
    routed_scaling = cfg.get("routed_scaling_factor", 2.5)
    swiglu_limit = cfg.get("swiglu_limit", 10.0)

    expert_ids, expert_weights = _route(x, w, li, token_id, top_k, li < n_hash_layers)

    # Accumulate in float32 and round to BF16 once at the end, rather than
    # rounding after every expert.
    routed = torch.zeros_like(x, dtype=torch.float32)
    for eid, wt in zip(expert_ids.tolist(), expert_weights):
        out = _swiglu_ffn(x, w, f"model.layers.{li}.mlp.experts.{eid}", swiglu_limit)
        routed = routed + out.float() * wt
    routed = routed * routed_scaling

    se_pre = f"model.layers.{li}.mlp.shared_experts"
    if w.get(f"{se_pre}.gate_proj.weight") is not None:
        routed = routed + _swiglu_ffn(x, w, se_pre, swiglu_limit).float()
    return routed.bfloat16()
# =====================================================================
# Weight loading
# =====================================================================
def load_all_weights(checkpoint_dir, num_layers):
    """Load the safetensors checkpoint and distribute it across GPUs.

    Tensors under ``model.layers.{li}.`` go to ``cuda:{li % NUM_GPUS}``, and
    ``model.embed_tokens*``, ``model.norm*`` and ``lm_head*`` go to ``cuda:0``.
    Every other key is dropped. Layers with ``li >= num_layers`` are dropped
    too, since the forward pass only visits ``range(num_layers)``; pass
    ``num_layers=None`` to keep them.

    Each shard is placed on the GPUs right after it is read, so host memory
    holds at most one shard at a time instead of the whole checkpoint.

    Returns:
        ``(layer_weights, global_weights)``. ``layer_weights[li]`` maps
        key -> tensor, plus the bookkeeping entries ``"_device"`` and
        ``"_gpu"``. ``global_weights`` maps key -> tensor.
    """
    from safetensors.torch import load_file
    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"
    weight_map = {}
    if index_path.exists():
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})
    shard_names = set(weight_map.values()) if weight_map else {
        f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
    }
    print(f"Loading {len(shard_names)} shards...")
    layer_weights = {}
    global_weights = {}
    global_prefixes = ("model.embed_tokens", "model.norm", "lm_head")
    n_tensors = 0
    loaded = 0
    for shard_name in sorted(shard_names):
        shard_path = cdir / shard_name
        if not shard_path.exists():
            print(f" WARNING: shard {shard_name} not found, skipping")
            continue
        shard = load_file(str(shard_path))
        for key, tensor in shard.items():
            n_tensors += 1
            if key.startswith("model.layers."):
                li = int(key.split(".")[2])
                if num_layers is not None and li >= num_layers:
                    continue
                target_gpu = li % NUM_GPUS
                target_device = f"cuda:{target_gpu}"
                bucket = layer_weights.setdefault(
                    li, {"_device": target_device, "_gpu": target_gpu})
                bucket[key] = tensor.to(target_device)
            elif key.startswith(global_prefixes):
                global_weights[key] = tensor.to("cuda:0")
        # Drop the host copy before reading the next shard.
        del shard
        loaded += 1
        if loaded % 20 == 0:
            print(f" {loaded}/{len(shard_names)} shards, {n_tensors} tensors")
    print(f" Done: {n_tensors} tensors")
    for gpu in range(NUM_GPUS):
        alloc = torch.cuda.memory_allocated(gpu) / 1e9
        print(f" GPU {gpu}: {alloc:.1f}GB")
    return layer_weights, global_weights
# =====================================================================
# Single layer forward
# =====================================================================
def _rms_norm(x, weight, eps):
    """RMS-normalize ``x`` over its last dim in float32, scale by ``weight``, keep ``x.dtype``."""
    xf = x.float()
    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (xf * weight.float()).to(x.dtype)


def _maybe_rms_norm(x, w, keys, eps):
    """Apply ``_rms_norm`` with the first of ``keys`` present in ``w``; otherwise return ``x``."""
    for key in keys:
        weight = w.get(key)
        if weight is not None:
            return _rms_norm(x, weight, eps)
    return x


def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc,
                  kv_cache, token_id, decode_pos):
    """Forward one decoder layer (attention + MoE, each wrapped in mHC).

    Args:
        X_l: (T, n_hc, hidden) multi-stream residual on this layer's device.
        w: this layer's weight dict.
        li: layer index.
        cfg: model config dict.
        rope_cos, rope_sin: RoPE tables on this layer's device.
        attn_mhc, ffn_mhc: mHCBlock instances for the two sub-layers.
        kv_cache: SimpleKVCache whose slot 0 holds this layer's K/V.
        token_id: current token id tensor (passed to moe_forward).
        decode_pos: absolute position of the first token in ``X_l``.

    Returns:
        The updated (T, n_hc, hidden) residual.

    Raises:
        ValueError: if either mHC block is missing.
    """
    if attn_mhc is None or ffn_mhc is None:
        raise ValueError(f"layer {li}: mHC parameters (attn_hc / ffn_hc) were not loaded")
    from dsv4.kernels.attention.production import dsv4_attention
    device = X_l.device
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg["qk_rope_head_dim"]
    o_rank = cfg["o_lora_rank"]
    o_groups = cfg["o_groups"]
    eps = cfg.get("rms_norm_eps", 1e-6)
    lpre = f"model.layers.{li}"
    pre = f"{lpre}.self_attn"
    T = X_l.shape[0]
    group_input_dim = (n_h // o_groups) * hd

    def attn_proj(name, inp):
        # NVFP4 projection stored under self_attn.<name>.
        return nvfp4_linear(inp, w[f"{pre}.{name}.weight"],
                            w[f"{pre}.{name}.weight_scale"],
                            w[f"{pre}.{name}.weight_scale_2"])

    # ==== mHC pre_block (attention) ====
    x_in, attn_ctx = attn_mhc.pre_block(X_l)
    # NOTE(review): every RMSNorm in this function is applied only when the
    # checkpoint provides a weight under one of the listed (HF-style) names.
    # These names are assumptions — confirm them against the checkpoint index.
    x_in = _maybe_rms_norm(x_in, w, (f"{lpre}.input_layernorm.weight",), eps)

    # ==== Q low-rank projection and KV projection ====
    c_Q = _maybe_rms_norm(attn_proj("q_a_proj", x_in), w,
                          (f"{pre}.q_a_layernorm.weight",), eps)
    q = attn_proj("q_b_proj", c_Q)
    kv = _maybe_rms_norm(attn_proj("kv_proj", x_in), w,
                         (f"{pre}.kv_a_layernorm.weight", f"{pre}.kv_norm.weight"), eps)

    # ==== RoPE on Q ====
    # Rotate in (T, n_h, hd) layout so token t gets position decode_pos + t,
    # then put heads first for the kernel: (n_h, T, hd).
    q_pos = decode_pos + torch.arange(T, dtype=torch.long, device=device)
    q_heads = apply_rope(q.reshape(T, n_h, hd), q_pos, rope_cos, rope_sin, rd)
    q_heads = q_heads.permute(1, 0, 2).contiguous()

    # ==== K/V (single shared head) and KV cache ====
    # K is rotated once, at its cache slot, when it is appended. This gives the
    # same values as re-rotating the whole cache every step, without the
    # O(seq_len) rework. V is the un-rotated projection.
    k_3d = kv.reshape(T, 1, hd)
    v_new = k_3d.permute(1, 0, 2).clone()  # (1, T, hd)
    k_pos = kv_cache.len[0] + torch.arange(T, dtype=torch.long, device=device)
    k_new = apply_rope(k_3d, k_pos, rope_cos, rope_sin, rd).permute(1, 0, 2)  # (1, T, hd)
    kv_cache.append(0, k_new, v_new)
    k_full, v_full = kv_cache.get(0)  # (1, seq_len, hd); K already rotated

    # ==== FMHA ====
    attn_out = dsv4_attention(q_heads, k_full, v_full)
    attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd)

    # ==== Grouped output projection ====
    # o_a: one (o_rank x group_input_dim) matrix per head group, applied as a
    # BMM over groups; o_b mixes the concatenated group outputs.
    attn_grouped = attn_out.reshape(T, o_groups, group_input_dim)
    oa_3d = w[f"{pre}.o_a_proj.weight"].bfloat16().reshape(o_groups, o_rank, group_input_dim)
    grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
    F_attn = attn_proj("o_b_proj", grouped_flat)

    # ==== mHC post_block (attention) ====
    X_l = attn_mhc.post_block(X_l, F_attn, attn_ctx)

    # ==== FFN: mHC pre_block -> MoE + shared expert -> mHC post_block ====
    x_ffn, ffn_ctx = ffn_mhc.pre_block(X_l)
    x_ffn = _maybe_rms_norm(x_ffn, w, (f"{lpre}.post_attention_layernorm.weight",), eps)
    F_ffn = moe_forward(x_ffn, w, li, cfg, token_id)
    return ffn_mhc.post_block(X_l, F_ffn, ffn_ctx)
# =====================================================================
# Main
# =====================================================================
def main():
    """Load the model across NUM_GPUS GPUs and greedily decode PROMPT."""
    t_start = time.time()
    print("=" * 70)
    print("DSV4 Single-Shot Inference — Full Pipeline (mHC+MoE+KV)")
    print("=" * 70)
    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg["qk_rope_head_dim"]
    n_hc = 4
    norm_eps = cfg.get("rms_norm_eps", 1e-6)
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
    print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")

    # ==== Phase 1: Load weights ====
    print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}")
    layer_weights, global_weights = load_all_weights(CHECKPOINT_DIR, n_layers)
    t_loaded = time.time()
    print(f"Weight loading: {t_loaded - t_start:.1f}s")

    # ==== Build mHC blocks ====
    print("Building mHC blocks...")
    attn_mhc_blocks = {}
    ffn_mhc_blocks = {}
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"
        lw = layer_weights[li]
        for prefix, blocks in [(f"model.layers.{li}.attn_hc", attn_mhc_blocks),
                               (f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)]:
            fn = lw.get(f"{prefix}.fn")
            base = lw.get(f"{prefix}.base")
            scale = lw.get(f"{prefix}.scale")
            if fn is not None and base is not None and scale is not None:
                mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
                mhc.load_from_checkpoint(fn, base, scale)
                blocks[li] = mhc
    print(f" attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}")

    # ==== Global weights ====
    torch.cuda.set_device(0)
    embed_w = global_weights.get("model.embed_tokens.weight")
    if embed_w is None:
        raise KeyError("model.embed_tokens.weight not found in checkpoint")
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16())
    lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16()
    final_norm_w = global_weights.get("model.norm.weight")
    rope_caches = {g: build_rope_cache(8192, hd, rd, f"cuda:{g}") for g in range(NUM_GPUS)}

    # ==== KV cache: one per layer, allocated directly on that layer's GPU ====
    kv_caches = {li: SimpleKVCache(n_layers=1, head_dim=hd, max_seq=8192,
                                   device=f"cuda:{li % NUM_GPUS}")
                 for li in range(n_layers)}

    # ==== Phase 2: Compile ====
    print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
    from dsv4.kernels.attention.production import dsv4_attention
    dummy_q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    dummy_k = torch.randn(1, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    try:
        _ = dsv4_attention(dummy_q, dummy_k, dummy_k.clone())
        print(" FMHA: compiled OK")
    except Exception as e:
        # Best-effort warm-up: report and continue; a real kernel failure
        # resurfaces on the first forward pass.
        print(f" FMHA error: {e}")
    t_compiled = time.time()
    print(f"Compile: {t_compiled - t_loaded:.1f}s")

    def run_token(tid_val, pos):
        """Embed token ``tid_val`` and run it through every layer at position ``pos``.

        Appends the token's K/V to every layer's cache and returns the
        (1, n_hc, H) residual on cuda:0.
        """
        tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
        X = embed(tid).unsqueeze(1).expand(-1, n_hc, -1).clone()  # (1, n_hc, H)
        for li in range(n_layers):
            gpu = li % NUM_GPUS
            target_device = torch.device(f"cuda:{gpu}")
            if X.device != target_device:
                X = X.to(target_device)
            torch.cuda.set_device(gpu)
            rc, rs = rope_caches[gpu]
            X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
                              attn_mhc_blocks.get(li), ffn_mhc_blocks.get(li),
                              kv_caches[li], tid, pos)
        X = X.to('cuda:0')
        torch.cuda.set_device(0)
        return X

    # ==== Phase 3: Inference ====
    print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
    print(f"Prompt: '{PROMPT}'{input_ids.tolist()}")
    generated = input_ids[0].tolist()
    if not generated:
        raise ValueError("tokenizer produced an empty prompt")

    # ==== Prefill ====
    # Prefill every prompt token except the last. The decode loop feeds the
    # last prompt token itself (its logits produce the first new token), so
    # prefilling it as well would put it in the KV cache twice and shift the
    # positions of everything after it.
    n_prefill = len(generated) - 1
    print(f"Prefilling {n_prefill} prompt tokens...")
    for pos, tid_val in enumerate(generated[:n_prefill]):
        run_token(tid_val, pos)
    print(f" Prefill done ({n_prefill} tokens, {time.time()-t_compiled:.1f}s)")

    # ==== Decode: generate new tokens ====
    print(f"\nDecoding (max {MAX_NEW_TOKENS} new tokens)...")
    for step in range(MAX_NEW_TOKENS):
        t0 = time.time()
        # Feed the newest token at its absolute position.
        X = run_token(generated[-1], len(generated) - 1)
        # NOTE(review): only mHC stream 0 is read out; confirm against the
        # reference model how the n_hc streams should be reduced here.
        x_out = X[:, 0, :]
        if final_norm_w is not None:
            xf = x_out.float()
            rms = xf.pow(2).mean(-1, keepdim=True).add(norm_eps).rsqrt()
            x_out = (xf * rms * final_norm_w.float()).bfloat16()
        logits = torch.nn.functional.linear(x_out, lm_w)
        next_id = torch.argmax(logits, dim=-1).item()
        generated.append(next_id)
        tok_str = tokenizer.decode([next_id])
        dt = time.time() - t0
        has_nan = torch.isnan(logits.float()).any().item()
        lmin, lmax = logits.float().min().item(), logits.float().max().item()
        print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan}")
        if has_nan:
            print(" NaN — stopping")
            break
        if next_id == tokenizer.eos_token_id:
            break

    out = tokenizer.decode(generated, skip_special_tokens=True)
    total = time.time() - t_start
    print(f"\n{'='*70}")
    print(f"Input: '{PROMPT}'")
    print(f"Output: '{out}'")
    print(f"Total: {total:.1f}s")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()