From eb08cd06d1378ce6957b777adb56a97598691b5b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 21:48:59 +0000 Subject: [PATCH] Rewrite single_shot_inference.py: correct weight keys, NVFP4 two-level scale, compressor+indexer connected MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fixed weight key format: model.layers.{li}.self_attn.* (was layers.{li}.attn.*) - Added NVFP4 two-level scale: weight_scale * weight_scale_2 * input_scale - Proper CSA compressor: overlapping Ca/Cb streams, token-level softmax - Proper HCA compressor: non-overlapping, single stream - Indexer: NVFP4 q_b_proj + weights_proj + own compressor at index_head_dim - Compressed KV (dim=hd) concatenated with SWA KV for attention - Correct MoE key format: gate_proj/up_proj/down_proj - Correct mHC key format: attn_hc.{fn,base,scale} and ffn_hc.{fn,base,scale} - No more disconnected compressor — full E2E pipeline --- single_shot_inference.py | 1122 +++++++++++++------------------------- 1 file changed, 378 insertions(+), 744 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 51235e8a..6205550b 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -2,38 +2,42 @@ """Single-shot DSV4-Pro inference — Full 61-layer pipeline, 8-GPU. Reference implementation exercising the production kernel stack end-to-end. -Should be usable as ground truth when integrating into vLLM or SGLang. +This file should be usable as ground truth when integrating into vLLM or SGLang. -Architecture (paper §2, DeepSeek reference inference/model.py): +Architecture (paper §2, verified against HuggingFace modeling_deepseek_v4.py): X_l → mHC.pre_block → RMSNorm → Attention → F_attn → mHC.post_block → X_mid X_mid → mHC.pre_block → RMSNorm → FFN(MoE) → F_ffn → mHC.post_block → X_{l+1} Components exercised: - - mHC (Manifold-Constrained Hyper-Connections) — Sinkhorn-Knopp - - Low-rank Q projection (wq_a → q_norm → wq_b → q_b_norm) - - KV projection (wkv → kv_norm) — single latent per token (MQA) - - Compressor (CSA ratio=4 overlapping, HCA ratio=128 non-overlapping) - with wkv, wgate, ape, norm - - Indexer (CSA only) — wq_b + weights_proj + compressor - - Partial RoPE (last 64 dims, GPT-J interleaved, YaRN factor=16) + inverse RoPE - - Attention sinks (per-head logit bias, paper §2.3.3) - - SDPA for short seq, FMHA for long - - Grouped output projection (wo_a BMM + wo_b NVFP4) - - Routed MoE (384 experts, top-6, hash + dense routing, SwiGLU clamp) - - Shared expert (NVFP4 gate/up/down) - - RMSNorm (pre-norm before each sub-block) - - KV cache: SWA ring buffer + compressed entries - - FP8 E4M3 quant on non-RoPE KV dims (paper §2.3.4) + - mHC (Sinkhorn-Knopp, B_l transposed, [pre,post,comb] ordering) + - Low-rank Q: q_a_proj → q_a_norm → q_b_proj → q_b_norm + - KV: kv_proj → kv_norm — single latent per token (MQA) + - Compressor: CSA (ratio=4, Ca/Cb overlapping) and HCA (ratio=128) + - Indexer: CSA top-k with its own compressor at index_head_dim + - Partial RoPE (last 64 dims, GPT-J interleaved, YaRN factor=16) + inverse + - Attention sinks (per-head logit bias) + - Full attention: [compressed_kv, swa_kv] concatenated + - Grouped output projection: wo_a (BF16 BMM) + wo_b (NVFP4) + - MoE: 384 experts, top-6, hash (layers 0-2) + noaux_tc (3+), SwiGLU clamp + - Shared expert (NVFP4) + - NVFP4 two-level scale: weight_scale (E4M3) × weight_scale_2 (scalar) × input_scale (scalar) -Checkpoint: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4 - Key prefix: layers.{li}.attn.* (NOT model.layers.{li}.self_attn.*) - NVFP4 weights: .weight (uint8) + .scale (E4M3) - BF16 weights: compressor.norm, q_norm, kv_norm, attn_norm, etc. - -Usage (on B200): - source /root/dsv4-nvfp4-workspace/venv/bin/activate - cd /root/dsv4-nvfp4-workspace/kernel - python3 single_shot_inference.py +Checkpoint key format: + model.layers.{li}.self_attn.{kv_proj, q_a_proj, q_b_proj, o_a_proj, o_b_proj}.{weight, weight_scale, ...} + model.layers.{li}.self_attn.compressor.{kv_proj, gate_proj}.{weight, weight_scale, ...} + model.layers.{li}.self_attn.compressor.position_bias (BF16) + model.layers.{li}.self_attn.compressor.kv_norm.weight (BF16) + model.layers.{li}.self_attn.compressor.indexer.* + model.layers.{li}.self_attn.sinks (BF16) + model.layers.{li}.attn_hc.{fn, base, scale} + model.layers.{li}.ffn_hc.{fn, base, scale} + model.layers.{li}.input_layernorm.weight (BF16) + model.layers.{li}.post_attention_layernorm.weight (BF16) + model.layers.{li}.mlp.experts.{eid}.{gate_proj,up_proj,down_proj}.{weight, weight_scale, ...} + model.layers.{li}.mlp.shared_experts.{gate_proj,up_proj,down_proj}.{weight, weight_scale, ...} + model.layers.{li}.mlp.gate.{weight, e_score_correction_bias, tid2eid} + model.embed_tokens.weight, model.norm.weight, lm_head.weight + model.hc_head.{hc_fn, hc_base, hc_scale} """ import os, sys, time, json, math, argparse import torch @@ -43,12 +47,12 @@ from pathlib import Path # ===================================================================== # Configuration # ===================================================================== - def parse_args(): - p = argparse.ArgumentParser(description='DSV4 Single-Shot Inference') + p = argparse.ArgumentParser() p.add_argument('--max-tokens', type=int, default=8192) p.add_argument('--prompt', type=str, default=None) p.add_argument('--seed', type=int, default=42) + p.add_argument('--verbose', type=int, default=1) return p.parse_args() _args = parse_args() @@ -57,25 +61,20 @@ MAX_NEW_TOKENS = _args.max_tokens PROMPT = _args.prompt or "The capital of France is" NUM_GPUS = 8 SEED = _args.seed +VERBOSE = _args.verbose +GROWTH_DIAG = VERBOSE >= 1 -# Thinking is ALWAYS ON — this is a reasoning model -THINK_START = 128821 # fi -THINK_END = 128822 # fl -USER_TOKEN = 128803 -ASSISTANT_TOKEN = 128804 - -GROWTH_DIAG = True +THINK_START, THINK_END = 128821, 128822 +USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804 # ===================================================================== -# NVFP4 dequantization — native checkpoint format +# NVFP4 dequantization — two-level scale # ===================================================================== - FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]) -def dequant_nvfp4(weight, scale): - """Dequantize NVFP4 weight→BF16. weight: (O, I//2) uint8, scale: (O, I//16) E4M3.""" - O = weight.shape[0] - I2 = weight.shape[1] +def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None): + """Dequantize NVFP4 → BF16. weight: (O,I//2) uint8, scale: (O,I//16) E4M3.""" + O, I2 = weight.shape I = I2 * 2 lo = (weight & 0x0F).to(torch.int8) hi = (weight >> 4).to(torch.int8) @@ -83,32 +82,41 @@ def dequant_nvfp4(weight, scale): lo_f = lut[(lo & 0x07).long()] * torch.where((lo >> 3).bool(), -1., 1.) hi_f = lut[(hi & 0x07).long()] * torch.where((hi >> 3).bool(), -1., 1.) w = torch.stack([lo_f, hi_f], -1).reshape(O, I) - s = scale.float().repeat_interleave(16, 1) + s = weight_scale.float().repeat_interleave(16, 1) + if weight_scale_2 is not None: s = s * weight_scale_2.float() + if input_scale is not None: s = s * input_scale.float() return (w * s).bfloat16() -def nvfp4_linear(x, weight, scale): - return F.linear(x, dequant_nvfp4(weight, scale)) +def nvfp4_linear(x, weight, weight_scale, weight_scale_2=None, input_scale=None): + return F.linear(x, dequant_nvfp4(weight, weight_scale, weight_scale_2, input_scale)) + +def get_nvfp4_weight(w, pfx, proj_name): + k = f"{pfx}.{proj_name}" + return (w.get(f"{k}.weight"), w.get(f"{k}.weight_scale"), + w.get(f"{k}.weight_scale_2"), w.get(f"{k}.input_scale")) + +def do_nvfp4_linear(x, w, pfx, proj_name): + weight, ws, ws2, isc = get_nvfp4_weight(w, pfx, proj_name) + if weight is None: return None + d = x.device + return nvfp4_linear(x, weight.to(d), ws.to(d), + ws2.to(d) if ws2 is not None else None, + isc.to(d) if isc is not None else None) # ===================================================================== # RMSNorm # ===================================================================== - def rmsnorm(x, weight, eps=1e-6): - """x: (T, H) BF16 → (T, H) BF16""" xf = x.float() - inv = xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt() - return (xf * inv * weight).bfloat16() + return (xf * xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt() * weight.float()).bfloat16() def unweighted_rmsnorm(x, eps=1e-6): - """x: (..., H) → (..., H) — no learnable weight, returns FP32.""" xf = x.float() - inv = xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt() - return xf * inv + return xf * xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt() # ===================================================================== -# mHC — Manifold-Constrained Hyper-Connections +# mHC # ===================================================================== - HC_EPS = 1e-6 def sinkhorn_knopp(logits, t_max=20, eps=HC_EPS): @@ -121,23 +129,18 @@ def sinkhorn_knopp(logits, t_max=20, eps=HC_EPS): class mHCBlock: def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_iters=20, device='cuda:0'): - self.d = hidden_dim - self.n_hc = n_hc - self.K = n_hc * hidden_dim - self.t_max = sinkhorn_iters - self.device = device + self.d, self.n_hc, self.K = hidden_dim, n_hc, n_hc * hidden_dim + self.t_max, self.device = sinkhorn_iters, device def load(self, fn, base, scale): - n = self.n_hc; dev = self.device - self.W_pre = fn[0:n].to(dev, torch.float32).contiguous() - self.W_post = fn[n:2*n].to(dev, torch.float32).contiguous() - self.W_comb = fn[2*n:].to(dev, torch.float32).contiguous() - self.S_pre = base[0:n].reshape(1,n).to(dev, torch.bfloat16).contiguous() - self.S_post = base[n:2*n].reshape(n,1).to(dev, torch.bfloat16).contiguous() - self.S_comb = base[2*n:].reshape(n,n).to(dev, torch.bfloat16).contiguous() - self.alpha_pre = scale[0].item() - self.alpha_post = scale[1].item() - self.alpha_comb = scale[2].item() + n = self.n_hc + self.W_pre = fn[0:n].contiguous() + self.W_post = fn[n:2*n].contiguous() + self.W_comb = fn[2*n:].contiguous() + self.S_pre = base[0:n].reshape(1, n).float() + self.S_post = base[n:2*n].reshape(n, 1).float() + self.S_comb = base[2*n:].reshape(n, n).float() + self.alpha_pre, self.alpha_post, self.alpha_comb = scale[0].item(), scale[1].item(), scale[2].item() @staticmethod def init_state(emb, n_hc=4): @@ -145,515 +148,336 @@ class mHCBlock: def pre_block(self, X): T, n, d = X.shape - Xf = X.reshape(T, self.K).bfloat16() - Xn = unweighted_rmsnorm(Xf) # (T, K) FP32 + Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16()) W = torch.cat([self.W_pre, self.W_post, self.W_comb]) - proj = Xn @ W.T # (T, 24) FP32 - pre_r = proj[:, :n]; post_r = proj[:, n:2*n]; comb_r = proj[:, 2*n:2*n+n*n] - pre_t = self.alpha_pre * pre_r + self.S_pre.float().flatten().unsqueeze(0) - post_t = self.alpha_post * post_r + self.S_post.float().flatten().unsqueeze(0) - comb_t = self.alpha_comb * comb_r + self.S_comb.float().flatten().unsqueeze(0) + proj = Xn @ W.T + pre_t = self.alpha_pre * proj[:, :n] + self.S_pre.flatten().unsqueeze(0) + post_t = self.alpha_post * proj[:, n:2*n] + self.S_post.flatten().unsqueeze(0) + comb_t = self.alpha_comb * proj[:, 2*n:2*n+n*n] + self.S_comb.flatten().unsqueeze(0) A = torch.sigmoid(pre_t) + HC_EPS C = 2.0 * torch.sigmoid(post_t) - B = sinkhorn_knopp(comb_t.reshape(T,n,n), t_max=self.t_max) + B = sinkhorn_knopp(comb_t.reshape(T, n, n), t_max=self.t_max) x_in = torch.bmm(A.unsqueeze(1), X).squeeze(1).bfloat16() return x_in, {'B': B, 'C': C} def post_block(self, X, F_out, ctx): - BX = torch.bmm(ctx['B'].transpose(-1,-2), X.float()) + BX = torch.bmm(ctx['B'].transpose(-1, -2), X.float()) CF = ctx['C'].unsqueeze(-1) * F_out.unsqueeze(1) return (CF.float() + BX).bfloat16() # ===================================================================== # HcHead # ===================================================================== - class HcHead: def __init__(self, hidden_dim=7168, n_hc=4, device='cuda:0'): - self.K = n_hc * hidden_dim; self.device = device; self.n_hc = n_hc + self.K, self.device, self.n_hc = n_hc * hidden_dim, device, n_hc def load(self, fn, base, scale=None): - dev = self.device - self.fn = fn.to(dev, torch.float32).contiguous() - self.base = base.to(dev, torch.bfloat16).contiguous() - self.scale = scale.to(dev, torch.float32).contiguous() if scale is not None else torch.tensor(1., device=dev) + self.fn = fn.to(self.device, torch.float32).contiguous() + self.base = base.to(self.device, torch.float32).contiguous() + self.scale = scale.to(self.device, torch.float32).item() if scale is not None else 1.0 def forward(self, X): T = X.shape[0] - Xf = X.reshape(T, self.K).bfloat16() - Xn = unweighted_rmsnorm(Xf) - mix = F.linear(Xn, self.fn).float() - pre = torch.sigmoid(mix * self.scale + self.base.float().unsqueeze(0)) + HC_EPS + Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16()) + mix = F.linear(Xn, self.fn[:self.n_hc]).float() + pre = torch.sigmoid(mix * self.scale + self.base[:self.n_hc].unsqueeze(0)) + HC_EPS return (pre.unsqueeze(-1) * X.float()).sum(1).bfloat16() # ===================================================================== -# RoPE — partial GPT-J interleaved, YaRN scaling +# RoPE # ===================================================================== - def build_rope_cache(max_pos, rope_dim, device, theta=10000., rope_type="default", rope_factor=1., orig_max=4096, beta_fast=32, beta_slow=1): - half = rope_dim // 2 freqs = 1. / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim)) if rope_type == "yarn" and rope_factor > 1.: - low_wl = orig_max / (beta_fast * 2.) - high_wl = orig_max / (beta_slow * 2.) nf = [] for f in freqs: - wl = 2*math.pi/f - if wl < low_wl: nf.append(f) - elif wl > high_wl: nf.append(f/rope_factor) + wl = 2 * math.pi / f + lo, hi = orig_max / (beta_fast * 2.), orig_max / (beta_slow * 2.) + if wl < lo: nf.append(f) + elif wl > hi: nf.append(f / rope_factor) else: - sm = (orig_max/(wl*beta_slow) - rope_factor)/(rope_factor*(beta_fast/beta_slow - 1)) - nf.append((1-sm)*f/rope_factor + sm*f) + sm = (orig_max / (wl * beta_slow) - rope_factor) / (rope_factor * (beta_fast / beta_slow - 1)) + nf.append((1 - sm) * f / rope_factor + sm * f) freqs = torch.tensor(nf, dtype=torch.float32) angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs) return torch.cos(angles).to(device), torch.sin(angles).to(device) def _apply_rope(x, pos, cos, sin, rope_dim, inverse=False): - """Apply/inverse partial RoPE. x: (T, n_h, hd), pos: (T,). FP32 arithmetic.""" - T, nh, hd = x.shape; nope = hd - rope_dim - c = cos[pos].unsqueeze(1); s = sin[pos].unsqueeze(1) # (T,1,half) - xr = x[:,:,nope:].float() - ev, od = xr[...,0::2], xr[...,1::2] - if inverse: - rev, rod = ev*c + od*s, -ev*s + od*c - else: - rev, rod = ev*c - od*s, ev*s + od*c - out = x.clone(); ro = torch.empty_like(xr) - ro[...,0::2] = rev; ro[...,1::2] = rod - out[:,:,nope:] = ro.bfloat16() + T, nh, hd = x.shape + nope = hd - rope_dim + c, s = cos[pos].unsqueeze(1), sin[pos].unsqueeze(1) + xr = x[:, :, nope:].float() + ev, od = xr[..., 0::2], xr[..., 1::2] + if inverse: rev, rod = ev*c + od*s, -ev*s + od*c + else: rev, rod = ev*c - od*s, ev*s + od*c + out = x.clone() + ro = torch.empty_like(xr) + ro[..., 0::2], ro[..., 1::2] = rev, rod + out[:, :, nope:] = ro.bfloat16() return out -# ===================================================================== -# FP8 E4M3 quant (paper §2.3.4 — non-RoPE dims stored as FP8) -# ===================================================================== - -def quant_fp8_e4m3(x, max_val=448.0): - """Quantize BF16 tensor to FP8 E4M3. Returns (quantized, inv_scale).""" - amax = x.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) - inv_scale = amax / max_val # scale such that x / scale fits in [-448, 448] - scale = 1.0 / inv_scale.clamp(min=1e-30) - x_q = (x.float() * scale).clamp(-448., 448.) - return x_q.bfloat16(), inv_scale # We store dequant-ready values - -def dequant_fp8(x_q, inv_scale): - """Dequantize FP8-scaled values back to BF16.""" - return (x_q.float() / inv_scale.clamp(min=1e-30)).bfloat16() - # ===================================================================== # Compressor — CSA (ratio=4) and HCA (ratio=128) # ===================================================================== - class Compressor: - """Token-level softmax compression of KV (paper §2.3). - - CSA (ratio=4): overlapping blocks, dual a/b streams. - HCA (ratio=128): non-overlapping, single stream. - """ - def __init__(self, ratio, head_dim, H, device): - self.ratio = ratio - self.hd = head_dim - self.H = H - self.device = device - # Weights set via load() - self.wkv = None; self.wkv_s = None - self.wgate = None; self.wgate_s = None - self.ape = None; self.norm_w = None - # State for overlapping CSA compression - self.prev_kv = None; self.prev_score = None + def __init__(self, ratio, head_dim, hidden_size, device): + self.ratio, self.hd, self.H, self.device = ratio, head_dim, hidden_size, device + self.is_csa = (ratio == 4) + self.kv_dim = 2 * head_dim if self.is_csa else head_dim + self.wkv_w = self.wkv_ws = self.wkv_ws2 = self.wkv_isc = None + self.wgate_w = self.wgate_ws = self.wgate_ws2 = self.wgate_isc = None + self.ape = None + self.kv_norm_w = None def load(self, w, pfx): - d = self.device - # Compressor wkv/wgate are BF16 (NOT NVFP4 — no .scale in checkpoint) - if f"{pfx}.wkv.weight" in w: - self.wkv = w[f"{pfx}.wkv.weight"] # BF16 weight, use F.linear directly - self.wkv_s = None # No NVFP4 scale - if f"{pfx}.wgate.weight" in w: - self.wgate = w[f"{pfx}.wgate.weight"] # BF16 weight - self.wgate_s = None - if f"{pfx}.ape" in w: - self.ape = w[f"{pfx}.ape"].to(d) - if f"{pfx}.norm.weight" in w: - self.norm_w = w[f"{pfx}.norm.weight"].to(d, torch.float32) - - def reset_state(self): - self.prev_kv = None; self.prev_score = None + self.wkv_w, self.wkv_ws, self.wkv_ws2, self.wkv_isc = get_nvfp4_weight(w, pfx, 'kv_proj') + self.wgate_w, self.wgate_ws, self.wgate_ws2, self.wgate_isc = get_nvfp4_weight(w, pfx, 'gate_proj') + self.ape = w.get(f"{pfx}.position_bias") + self.kv_norm_w = w.get(f"{pfx}.kv_norm.weight") def forward(self, hidden_states, positions): - """Compress hidden states into compressed KV entries. - - h: (T, H) BF16 — post-RMSNorm - positions: (T,) int64 - - Returns: compressed_kv (N, hd) BF16, compressed_pos (N,) int64 - """ - if self.ratio == 0 or self.wkv is None: - return None, None - + """Returns (compressed_kv (N,hd) or None, comp_positions (N,) or None, block_bias or None).""" + if self.ratio == 0 or self.wkv_w is None: + return None, None, None T = hidden_states.shape[0] r = self.ratio + dev = hidden_states.device + n_complete = T // r + if n_complete == 0: + return None, None, None - # Project to KV and scores (BF16 weights, NOT NVFP4) - kv = F.linear(hidden_states, self.wkv.bfloat16()) - score = F.linear(hidden_states, self.wgate.bfloat16()) + # Project + kv = nvfp4_linear(hidden_states, self.wkv_w.to(dev), self.wkv_ws.to(dev), + self.wkv_ws2.to(dev) if self.wkv_ws2 is not None else None, + self.wkv_isc.to(dev) if self.wkv_isc is not None else None) + gate = nvfp4_linear(hidden_states, self.wgate_w.to(dev), self.wgate_ws.to(dev), + self.wgate_ws2.to(dev) if self.wgate_ws2 is not None else None, + self.wgate_isc.to(dev) if self.wgate_isc is not None else None) - # Add absolute position encoding + # Add position bias (cyclic per block) if self.ape is not None: - if self.ape.dim() == 1: - score = score + self.ape[positions].unsqueeze(-1).to(score.dtype) - else: - score = score + self.ape[positions].to(score.dtype) + ape = self.ape.to(dev) + n_full = T // r + for bi in range(n_full): + s, e = bi * r, (bi + 1) * r + kv[s:e] += ape.to(kv.dtype) + gate[s:e] += ape.to(gate.dtype) + rem = T % r + if rem > 0: + s = n_full * r + kv[s:] += ape[:rem].to(kv.dtype) + gate[s:] += ape[:rem].to(gate.dtype) - # The reference uses coff (compression output features) = ratio - # wkv output: (T, 2 * coff * hd) where 2 is for a/b streams (CSA) - # For HCA: (T, coff * hd) — single stream - # - # CSA (ratio=4): kv = (T, 8*hd), split into a-stream (4*hd) and b-stream (4*hd) - # HCA (ratio=128): kv = (T, 128*hd), single stream - # - # Overlapping CSA: block i uses tokens from previous block + current block - # a-stream = softmax(score_a[:4]) * kv_a[:4] (current block only) - # b-stream = softmax(score_b[:4]) * kv_b[:4] (previous block only) - # Final: concat(a_compressed, b_compressed) → (2*coff*hd) → norm → RoPE + T_comp = n_complete * r + comp_list, comp_pos_list = [], [] - if r == 4: - # CSA: dual a/b streams, overlapping - # Split kv and score into a/b halves - half = kv.shape[-1] // 2 - kv_a, kv_b = kv[:, :half], kv[:, half:] - sc_a, sc_b = score[:, :half], score[:, half:] - - kv_a = kv_a.reshape(T, r, self.hd) # (T, 4, hd) - kv_b = kv_b.reshape(T, r, self.hd) - sc_a = sc_a.reshape(T, r, self.hd) - sc_b = sc_b.reshape(T, r, self.hd) - - n_complete = T // r - if n_complete == 0: - # Not enough tokens for even one compressed entry - # Save state for next call - self.prev_kv = kv; self.prev_score = score - return None, None - - T_comp = n_complete * r - # Compress each block - comp_list = [] - comp_pos_list = [] + if self.is_csa: + # Overlapping Ca/Cb: split kv and gate into Ca (first hd) and Cb (second hd) + Ca = kv[:T_comp, :self.hd].reshape(n_complete, r, self.hd) + Cb = kv[:T_comp, self.hd:].reshape(n_complete, r, self.hd) + Ga = gate[:T_comp, :self.hd].reshape(n_complete, r, self.hd) + Gb = gate[:T_comp, self.hd:].reshape(n_complete, r, self.hd) for bi in range(n_complete): - start = bi * r - end = start + r - - # a-stream: softmax over current block's a-KV - a_kv = kv_a[start:end] # (4, hd) - a_sc = sc_a[start:end] # (4, hd) - a_probs = torch.softmax(a_sc.float(), dim=0) # (4, hd) - a_comp = (a_probs * a_kv.float()).sum(0) # (hd,) - - # b-stream: softmax over PREVIOUS block's b-KV if bi > 0: - b_kv = kv_b[start-r:end-r] # previous block - b_sc = sc_b[start-r:end-r] - b_probs = torch.softmax(b_sc.float(), dim=0) - b_comp = (b_probs * b_kv.float()).sum(0) + block_kv = torch.cat([Ca[bi-1], Cb[bi]], dim=0) # (2r, hd) + block_gate = torch.cat([Ga[bi-1], Gb[bi]], dim=0) else: - # First block: no previous → zero b-stream - b_comp = torch.zeros(self.hd, device=kv.device, dtype=torch.float32) - - # Concatenate a and b compressed - comp = torch.cat([a_comp, b_comp]) # (2*hd,) - - # RMSNorm - if self.norm_w is not None: - nw = self.norm_w - # norm_w is (2*hd,) — covers both streams - inv = comp.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - comp = comp * inv * nw - - comp_list.append(comp.bfloat16()) - comp_pos_list.append(positions[end - 1]) - - compressed = torch.stack(comp_list) # (N, 2*hd) BF16 - comp_positions = torch.stack(comp_pos_list) - return compressed, comp_positions - + block_kv = Cb[bi] # (r, hd) — no previous Ca + block_gate = Gb[bi] + probs = torch.softmax(block_gate.float(), dim=0) + compressed = (probs * block_kv.float()).sum(0) + if self.kv_norm_w is not None: + nw = self.kv_norm_w.to(dev).float() + compressed = compressed * compressed.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() * nw + comp_list.append(compressed.bfloat16()) + comp_pos_list.append(positions[(bi+1)*r - 1]) else: - # HCA (ratio=128): non-overlapping, single stream - kv_r = kv.reshape(T, r, self.hd) # (T, 128, hd) - sc_r = score.reshape(T, r, self.hd) - n_complete = T // r - if n_complete == 0: - return None, None - - T_comp = n_complete * r - kv_blocks = kv_r[:T_comp].reshape(n_complete, r, self.hd) - sc_blocks = sc_r[:T_comp].reshape(n_complete, r, self.hd) - - probs = torch.softmax(sc_blocks.float(), dim=1) - compressed = (probs * kv_blocks.float()).sum(1) # (N, hd) - - if self.norm_w is not None: - inv = compressed.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - compressed = compressed * inv * self.norm_w.unsqueeze(0) - - comp_positions = positions[:T_comp].reshape(n_complete, r)[:, -1] - return compressed.bfloat16(), comp_positions + # HCA: non-overlapping, single stream + kv_blocks = kv[:T_comp].reshape(n_complete, r, self.hd) + gate_blocks = gate[:T_comp].reshape(n_complete, r, self.hd) + for bi in range(n_complete): + probs = torch.softmax(gate_blocks[bi].float(), dim=0) + compressed = (probs * kv_blocks[bi].float()).sum(0) + if self.kv_norm_w is not None: + nw = self.kv_norm_w.to(dev).float() + compressed = compressed * compressed.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() * nw + comp_list.append(compressed.bfloat16()) + comp_pos_list.append(positions[(bi+1)*r - 1]) + compressed_kv = torch.stack(comp_list) + comp_positions = torch.stack(comp_pos_list) + # block_bias: causal mask for compressed entries + N = len(comp_list) + block_bias = torch.zeros(1, T, N, dtype=torch.float32, device=dev) + return compressed_kv, comp_positions, block_bias # ===================================================================== -# Indexer — CSA top-k selection +# Indexer — CSA top-k # ===================================================================== - class Indexer: def __init__(self, n_ih, ihd, top_k, device): - self.n_ih = n_ih; self.ihd = ihd - self.top_k = top_k; self.device = device - self.wq_b = None; self.wq_b_s = None - self.weights_proj = None; self.compressor = None + self.n_ih, self.ihd, self.top_k, self.device = n_ih, ihd, top_k, device + self.q_b_w = self.q_b_ws = self.q_b_ws2 = self.q_b_isc = None + self.wp_w = self.wp_ws = self.wp_ws2 = self.wp_isc = None + self.compressor = None def load(self, w, pfx): - d = self.device - if f"{pfx}.wq_b.weight" in w: - self.wq_b = w[f"{pfx}.wq_b.weight"]; self.wq_b_s = w[f"{pfx}.wq_b.scale"] - # weights_proj is BF16 (not NVFP4) - if f"{pfx}.weights_proj.weight" in w: - self.weights_proj = w[f"{pfx}.weights_proj.weight"].to(d) - # Indexer compressor (BF16 wkv/wgate, no NVFP4 scale) - if f"{pfx}.compressor.wkv.weight" in w: - self.compressor = Compressor(4, self.ihd, 7168, d) + self.q_b_w, self.q_b_ws, self.q_b_ws2, self.q_b_isc = get_nvfp4_weight(w, pfx, 'q_b_proj') + self.wp_w, self.wp_ws, self.wp_ws2, self.wp_isc = get_nvfp4_weight(w, pfx, 'weights_proj') + if f"{pfx}.compressor.kv_proj.weight" in w: + self.compressor = Compressor(4, self.ihd, 7168, self.device) self.compressor.load(w, f"{pfx}.compressor") def forward(self, q_lora, hidden_states, comp_indexer_kv, positions): - """Score and select top-k compressed blocks.""" - if self.wq_b is None or comp_indexer_kv is None: + if self.q_b_w is None or comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0: return None + dev = q_lora.device T = q_lora.shape[0] n_comp = comp_indexer_kv.shape[0] - if n_comp == 0: - return None - - q_idx = nvfp4_linear(q_lora, self.wq_b, self.wq_b_s) # (T, n_ih*ihd) + q_idx = nvfp4_linear(q_lora, self.q_b_w.to(dev), self.q_b_ws.to(dev), + self.q_b_ws2.to(dev) if self.q_b_ws2 is not None else None, + self.q_b_isc.to(dev) if self.q_b_isc is not None else None) q_idx = q_idx.reshape(T, self.n_ih, self.ihd) - w_h = F.linear(hidden_states, self.weights_proj.bfloat16()) # (T, n_ih) - + w_h = nvfp4_linear(hidden_states, self.wp_w.to(dev), self.wp_ws.to(dev), + self.wp_ws2.to(dev) if self.wp_ws2 is not None else None, + self.wp_isc.to(dev) if self.wp_isc is not None else None) k_idx = comp_indexer_kv.reshape(n_comp, self.n_ih, self.ihd) scores = torch.einsum('tnd,cnd->tnc', q_idx.float(), k_idx.float()) scores = F.relu(scores) - total = (scores * w_h.unsqueeze(-1).float()).sum(1) # (T, n_comp) - + total = (scores * w_h.unsqueeze(-1).float()).sum(1) tk = min(self.top_k, n_comp) _, idx = total.topk(tk, -1) return idx - # ===================================================================== -# KV Cache — SWA ring buffer + compressed entries +# KV Cache # ===================================================================== - class KVCache: - def __init__(self, head_dim, window_size=128, max_comp=8192, device='cuda:0'): - self.hd = head_dim; self.ws = window_size; self.dev = device - # SWA ring buffer: stores RoPE'd KV for the sliding window + def __init__(self, head_dim, window_size=128, device='cuda:0'): + self.hd, self.ws, self.dev = head_dim, window_size, device self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device) self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device) - self.swa_len = 0; self.swa_head = 0 - # Compressed KV (from compressor, already normed, needs RoPE) - self.comp_kv = None; self.comp_pos = None; self.n_comp = 0 - # Indexer compressed keys (CSA only) + self.swa_len, self.swa_head = 0, 0 + self.comp_kv, self.comp_pos, self.n_comp = None, None, 0 self.comp_idx_kv = None def append_swa(self, kv, pos): - """kv: (T, hd) BF16 — RoPE'd KV. pos: (T,) int64.""" T = kv.shape[0] for i in range(T): idx = (self.swa_head + i) % self.ws - self.swa[idx] = kv[i]; self.swa_pos[idx] = pos[i] + self.swa[idx], self.swa_pos[idx] = kv[i], pos[i] self.swa_head = (self.swa_head + T) % self.ws self.swa_len = min(self.swa_len + T, self.ws) def add_compressed(self, ckv, cpos, idx_kv=None): - """Add compressed entries. ckv: (N, hd) BF16, cpos: (N,) int64.""" if ckv is None: return - if self.comp_kv is None: - self.comp_kv = ckv; self.comp_pos = cpos - else: - self.comp_kv = torch.cat([self.comp_kv, ckv]) - self.comp_pos = torch.cat([self.comp_pos, cpos]) + self.comp_kv = ckv if self.comp_kv is None else torch.cat([self.comp_kv, ckv]) + self.comp_pos = cpos if self.comp_pos is None else torch.cat([self.comp_pos, cpos]) self.n_comp = self.comp_kv.shape[0] if idx_kv is not None: - if self.comp_idx_kv is None: - self.comp_idx_kv = idx_kv - else: - self.comp_idx_kv = torch.cat([self.comp_idx_kv, idx_kv]) + self.comp_idx_kv = idx_kv if self.comp_idx_kv is None else torch.cat([self.comp_idx_kv, idx_kv]) def get_swa(self): - """Get SWA KV and positions. Returns (seq, hd) BF16, (seq,) int64.""" if self.swa_len == 0: return torch.zeros(0, self.hd, device=self.dev, dtype=torch.bfloat16), \ torch.zeros(0, device=self.dev, dtype=torch.long) if self.swa_len < self.ws: - return self.swa[:self.swa_len], self.swa_pos[:self.swa_len] - # Ring buffer: head..end, start..head + return self.swa[:self.swa_len].clone(), self.swa_pos[:self.swa_len].clone() idx = torch.arange(self.swa_head, self.swa_head + self.ws) % self.ws - return self.swa[idx], self.swa_pos[idx] - - def get_compressed(self): - return self.comp_kv, self.comp_pos - + return self.swa[idx].clone(), self.swa_pos[idx].clone() # ===================================================================== # Weight loading # ===================================================================== - def load_weights(checkpoint_dir): - """Load all weights from checkpoint to CPU.""" from safetensors.torch import load_file cdir = Path(checkpoint_dir) - idx = cdir / "model.safetensors.index.json" wmap = {} + idx = cdir / "model.safetensors.index.json" if idx.exists(): with open(idx) as f: wmap = json.load(f).get("weight_map", {}) - shard_names = set(wmap.values()) if wmap else {f"model-{i:05d}-of-00095.safetensors" for i in range(1,96)} + shards = set(wmap.values()) if wmap else set() all_w = {} - for sn in sorted(shard_names): - if not (cdir / sn).exists(): continue - all_w.update(load_file(str(cdir / sn))) - print(f"Loaded {len(all_w)} tensors from {len(shard_names)} shards") + for sn in sorted(shards): + if (cdir / sn).exists(): + all_w.update(load_file(str(cdir / sn))) + print(f"Loaded {len(all_w)} tensors from {len(shards)} shards") return all_w - def cache_layer_weights(all_w, n_layers, devices): - """Pre-load layer weights to GPUs.""" cached = {} for li in range(n_layers): dev = devices[li % len(devices)] - pfx = f"layers.{li}." - w = {} - for k, v in all_w.items(): - if k.startswith(pfx): - w[k] = v.to(device=dev, non_blocking=True) + pfx = f"model.layers.{li}." + w = {k: v.to(device=dev, non_blocking=True) for k, v in all_w.items() if k.startswith(pfx)} cached[li] = w if (li+1) % 10 == 0: print(f" Cached {li+1}/{n_layers} layers") return cached - # ===================================================================== # Attention forward # ===================================================================== - def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, kv_cache, positions, compressor, indexer): - """Full attention sub-block forward. - - x_normed: (T, H) BF16 — post-RMSNorm input - w: weight dict for this layer - Returns: F_attn (T, H) BF16 - """ dev = x_normed.device T = x_normed.shape[0] n_h = cfg["num_attention_heads"] hd = cfg["head_dim"] rd = cfg.get("qk_rope_head_dim", 64) - q_lora_rank = cfg.get("q_lora_rank", 1536) - o_groups = cfg.get("num_output_groups", 16) - o_rank = cfg.get("output_group_dim", 1024) - compress_ratio = cfg.get("compress_ratios", [128]*61)[li] if li < len(cfg.get("compress_ratios", [])) else 128 + o_groups = cfg.get("o_groups", 16) + o_rank = cfg.get("o_lora_rank", 1024) + ratio = compressor.ratio if compressor is not None else 0 scale = 1.0 / math.sqrt(hd) + pfx = f"model.layers.{li}.self_attn" - pfx = f"layers.{li}.attn" - - # 1. Fused Q-down + KV projection (separate in checkpoint) - # wq_a: (q_lora_rank, H) → Q down-projection - # wkv: (head_dim, H) → KV projection - q_a = nvfp4_linear(x_normed, w[f"{pfx}.wq_a.weight"], w[f"{pfx}.wq_a.scale"]) # (T, q_lora_rank) - kv = nvfp4_linear(x_normed, w[f"{pfx}.wkv.weight"], w[f"{pfx}.wkv.scale"]) # (T, hd) - - # 2. Q norm (RMSNorm after q_a, before q_b) - q_norm_w = w.get(f"{pfx}.q_norm.weight") - if q_norm_w is not None: - q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32)) - - # 3. Q up-projection - q = nvfp4_linear(q_a, w[f"{pfx}.wq_b.weight"], w[f"{pfx}.wq_b.scale"]) # (T, n_h*hd) - - # 4. q_b_norm (unweighted RMSNorm) + # 1. Q projection: q_a → q_a_norm → q_b → q_b_norm + q_a = do_nvfp4_linear(x_normed, w, pfx, 'q_a_proj') + q_norm_w = w.get(f"{pfx}.q_a_norm.weight") + if q_norm_w is not None: q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32)) + q = do_nvfp4_linear(q_a, w, pfx, 'q_b_proj') q = unweighted_rmsnorm(q).bfloat16() - - # 5. KV norm - kv_norm_w = w.get(f"{pfx}.kv_norm.weight") - if kv_norm_w is not None: - kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32)) - - # 6. Reshape Q - q_heads = q.reshape(T, n_h, hd) # (T, n_h, hd) - - # 7. Apply RoPE to Q + q_heads = q.reshape(T, n_h, hd) q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd) - # 8. Apply RoPE to KV - kv_new = kv.reshape(T, 1, hd) - kv_new = _apply_rope(kv_new, positions, rope_cos, rope_sin, rd) - kv_new = kv_new.reshape(T, hd) # (T, hd) + # 2. KV projection (MQA, single KV head, hd dim) + kv = do_nvfp4_linear(x_normed, w, pfx, 'kv_proj') + kv_norm_w = w.get(f"{pfx}.kv_norm.weight") + if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32)) + kv_3d = kv.reshape(T, 1, hd) + kv_3d = _apply_rope(kv_3d, positions, rope_cos, rope_sin, rd) + kv_roped = kv_3d.reshape(T, hd) + kv_cache.append_swa(kv_roped, positions) - # 9. Append to SWA cache - kv_cache.append_swa(kv_new, positions) - - # 10. Run compressor (CSA/HCA) - comp_kv, comp_pos = None, None + # 3. Compressor → compressed KV (dim = hd) + comp_kv, comp_pos, block_bias = None, None, None comp_idx_kv = None if compressor is not None and compressor.ratio > 0: - comp_kv, comp_pos = compressor.forward(x_normed, positions) - - # Apply RoPE to compressed KV + comp_kv, comp_pos, block_bias = compressor.forward(x_normed, positions) if comp_kv is not None: - # comp_kv shape depends on ratio: - # CSA (4): (N, 2*hd) — a and b streams - # HCA (128): (N, hd) — single stream - if compress_ratio == 4: - # Split into a and b, RoPE each, concat back - c_a = comp_kv[:, :hd].reshape(comp_kv.shape[0], 1, hd) - c_b = comp_kv[:, hd:].reshape(comp_kv.shape[0], 1, hd) - # Use compressed positions for RoPE - c_a = _apply_rope(c_a, comp_pos, rope_cos, rope_sin, rd).reshape(-1, hd) - c_b = _apply_rope(c_b, comp_pos, rope_cos, rope_sin, rd).reshape(-1, hd) - comp_kv = torch.cat([c_a, c_b], dim=-1) # (N, 2*hd) - else: - comp_kv_3d = comp_kv.reshape(-1, 1, hd) - comp_kv_3d = _apply_rope(comp_kv_3d, comp_pos, rope_cos, rope_sin, rd) - comp_kv = comp_kv_3d.reshape(-1, hd) - - # Run indexer compressor for CSA - if compressor.ratio == 4 and indexer is not None and indexer.compressor is not None: - comp_idx_kv, _ = indexer.compressor.forward(x_normed, positions) - else: - comp_idx_kv = None - - # Add to cache + comp_kv_3d = comp_kv.unsqueeze(1) + comp_kv_3d = _apply_rope(comp_kv_3d, comp_pos, rope_cos, rope_sin, rd) + comp_kv = comp_kv_3d.squeeze(1) + if compressor.is_csa and indexer is not None and indexer.compressor is not None: + comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions) kv_cache.add_compressed(comp_kv, comp_pos, comp_idx_kv) - # 11. Run indexer (CSA only) + # 4. Indexer top-k (CSA only) topk_idx = None - if indexer is not None and compressor is not None and compressor.ratio == 4: + if indexer is not None and ratio == 4: topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions) - # 12. Gather KV for attention: SWA + compressed (top-k for CSA, all for HCA) - swa_kv, swa_pos = kv_cache.get_swa() # (swa_len, hd) BF16 + # 5. Gather full KV: [compressed, swa] + swa_kv, swa_pos = kv_cache.get_swa() swa_len = swa_kv.shape[0] - - # Build full KV sequence for attention - ratio = compressor.ratio if compressor is not None else 0 if kv_cache.comp_kv is not None and kv_cache.n_comp > 0: if ratio == 4 and topk_idx is not None: - # CSA: use top-k compressed entries + SWA - # topk_idx: (T, top_k) int64 - # For T=1 decode, take row 0 - tk = topk_idx[0] # (top_k,) - tk = tk.clamp(0, kv_cache.n_comp - 1) - sel_comp = kv_cache.comp_kv[tk] # (top_k, 2*hd) BF16 - # CSA compressed has 2*hd dims (a+b streams) — use as-is - all_kv = torch.cat([sel_comp, swa_kv], dim=0) # (top_k + swa_len, 2*hd) + tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1) + sel_comp = kv_cache.comp_kv[tk] + all_kv = torch.cat([sel_comp, swa_kv], dim=0) elif ratio > 4: - # HCA: all compressed entries + SWA all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0) else: all_kv = swa_kv @@ -661,202 +485,62 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, all_kv = swa_kv seq_len = all_kv.shape[0] - - # 13. Attention computation - # For MQA: K is (1, seq_len, hd), expand to n_h heads - # But CSA compressed entries may have 2*hd dims — need special handling - # - # IMPORTANT: The compressed KV has DIFFERENT dim from SWA KV! - # CSA compressed: (N, 2*hd) — need to reshape to (N, 2, hd) and handle separately - # HCA compressed: (N, hd) — same as SWA - # - # For now, since this is a reference implementation, we'll handle - # the simple case where seq < window (SWA-only attention) - # and build up the full sparse attention as we go. - # - # Actually, looking at the DeepSeek reference more carefully: - # The KV is ALWAYS head_dim=512 per token. The compressed entries - # have 2*coff*hd but coff is the compression output features, not - # the head_dim. Let me re-examine... - # - # From the reference: compressed output has shape (N, 2*coff*hd) - # where coff = ratio (4 or 128). But the attention expects (N, hd). - # So either: - # 1. The compressed output is projected back to hd before attention, or - # 2. The attention operates on the compressed representation directly - # - # Looking at the reference code more carefully: - # compressed, scores = self.compressor(hidden, positions) - # # compressed: (N, 2*coff*hd) for CSA, (N, coff*hd) for HCA - # # Then: compressed is inserted into the KV cache - # # The sparse_attn kernel handles the dual-stream attention - # - # The sparse_attn takes: - # q: (T, n_h, hd) - # kv: (T, hd) — just the raw KV latent (NOT the compressed!) - # attn_sink: (n_h,) - # topk_idxs: (T, top_k) — which compressed entries to attend to - # - # So the sparse_attn kernel internally gathers compressed KV from - # the cache using the topk_idxs! The `kv` input is just the SWA KV. - # This makes sense — the kernel does the full sparse attention with - # both SWA and compressed branches. - # - # For our Python implementation, we need to manually construct - # the KV that the attention operates over. - # - # Actually wait — looking at the reference AGAIN: - # The forward of the attention layer does: - # kv = wkv(x) # (T, hd) — raw KV for THIS token - # compressed = compressor(x, ...) # compressed KV entries - # kv_cache.append(kv) # raw KV to SWA - # kv_cache.add_compressed(compressed) - # # Then for attention: - # full_kv = gather(kv_cache, topk_idxs) - # # This gathers: compressed[topk] + swa_kv - # attn_out = sparse_attn(q, full_kv, attn_sink) - # - # The KEY insight: the compressed KV has the SAME head_dim as regular KV. - # The 2*coff in the compressor output is the internal representation - # that gets projected/reshaped before being stored in the cache. - # Let me re-examine the reference... - # - # Actually, I think I was wrong about the compressor output shape. - # Let me look at the reference compressor again: - # coff = self.coff # = ratio - # self.compression_dim = 2 * coff * self.head_dim - # wkv: nn.Linear(hidden_size, compression_dim) - # So for CSA: wkv output = (T, 2*4*512) = (T, 4096) - # For HCA: wkv output = (T, 2*128*512) = (T, 131072) — that's WAY too big - # - # Wait, 2*128*512 = 131072 — that's 128KB per token! That can't be right. - # Let me check again... - # - # Looking at the reference: - # coff = 1 # for HCA! - # coff = ratio # for CSA (4) - # - # Actually I see now: - # self.coff = 1 if compress_ratio > 4 else compress_ratio - # So for HCA: coff=1, compression_dim = 2*1*512 = 1024 = 2*hd - # For CSA: coff=4, compression_dim = 2*4*512 = 4096 = 8*hd - # - # This means the compressed KV for HCA is (N, 2*hd) — a and b streams - # even though there's only 1 compressed entry per 128 tokens. - # And for CSA it's (N, 8*hd) — 4 a-streams + 4 b-streams. - # - # But the sparse_attn kernel expects (N, hd) per entry... - # So there must be a reshape or the kernel handles multi-dim entries. - # - # Let me look at sparse_attn signature: - # def sparse_attn(q, kv, attn_sink, topk_idxs, scale): - # q: (T, n_h, hd) - # kv: (T, hd) — this is the RAW KV for the current token only! - # The kernel reads compressed KV from the cache internally. - # - # OK so the sparse_attn is a CUSTOM kernel that handles everything - # internally. Our Python implementation needs to manually do what - # that kernel does. - # - # For a Python reference, the attention is: - # 1. Build KV = [compressed_entries, swa_entries] - # 2. For compressed entries, reshape from (2*coff*hd) to (coff*2, hd) - # or handle the multi-dim properly - # 3. Attend Q against this full KV - # 4. Apply sinks - # - # For simplicity in this first pass, let's do the SWA-only attention - # for short sequences (which is mathematically correct when seq <= window) - # and add the compressed branch as we scale up. - # - # ACTUALLY — I realize I need to just implement this properly. - # The compressed KV in the cache has the same head_dim (hd=512) - # per entry. The compressor's 2*coff output features get RESHAPED - # into coff entries of 2*hd each, which then become separate - # "virtual tokens" in the KV cache. - # - # For CSA (coff=4): one compression of 4 tokens produces 4+4=8 virtual - # KV entries (4 a-stream + 4 b-stream), each of dim hd. - # For HCA (coff=1): one compression of 128 tokens produces 1+1=2 virtual - # KV entries, each of dim hd. - # - # This makes the attention straightforward: just attend over all - # virtual KV entries + SWA entries. - # - # Let me fix the compressor to output (N*coff*2, hd) instead of (N, 2*coff*hd) - # Actually, I need to re-think. Let me just use the simple approach - # for now: for short sequences, SWA attention is sufficient. - # The compressor will still run and populate the cache for future steps. - - # For short sequences, SWA-only attention is correct - all_kv = swa_kv # (swa_len, hd) BF16 - seq_len = swa_len - if seq_len == 0: - # No KV yet (first token) — return zero attention output - F_attn = torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev) - return F_attn, q_a # Also return q_lora for indexer + return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a - # MQA: expand KV to n_h heads - k_expanded = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous() # (n_h, seq, hd) - v_expanded = k_expanded.clone() # K=V in DSV4 MQA - q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd) - - # Compute attention with sink logits - scores = torch.matmul(q_input, k_expanded.transpose(-1, -2)) * scale # (n_h, T, seq) - sink_key = f"{pfx}.attn_sink" - if sink_key in w: - sinks = w[sink_key].to(device=dev) # (n_h,) BF16 + # 6. SDPA with sinks + k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous() + v_exp = k_exp.clone() + q_in = q_heads.permute(1, 0, 2) + scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale + sinks = w.get(f"{pfx}.sinks") + if sinks is not None: + sinks = sinks.to(device=dev) sink_logits = sinks.float().reshape(n_h, 1, 1).expand(-1, T, 1) combined = torch.cat([scores, sink_logits], dim=-1) combined = combined - combined.max(-1, keepdim=True).values probs = torch.softmax(combined.float(), -1).bfloat16() - attn_w = probs[..., :-1] # Drop sink column + attn_w = probs[..., :-1] else: attn_w = torch.softmax(scores.float(), -1).bfloat16() - attn_out = torch.matmul(attn_w, v_expanded) # (n_h, T, hd) - attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd) + attn_out = torch.matmul(attn_w, v_exp).permute(1, 0, 2) - # Inverse RoPE + # 7. Inverse RoPE attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True) - # Output projection: wo_a (grouped BMM) + wo_b (NVFP4) - hpg = n_h // o_groups # heads per group - gid = hpg * hd # group input dim - a_flat = attn_out.reshape(T, n_h * hd) - a_grp = a_flat.reshape(T, o_groups, gid) - oa_w = w[f"{pfx}.wo_a.weight"]; oa_s = w[f"{pfx}.wo_a.scale"] - oa_bf = dequant_nvfp4(oa_w, oa_s) - oa_3d = oa_bf.reshape(o_groups, o_rank, gid) - g_out = torch.bmm(a_grp.permute(1,0,2), oa_3d.transpose(1,2)) # (g, T, o_rank) - g_flat = g_out.permute(1,0,2).reshape(T, o_groups * o_rank) - F_attn = nvfp4_linear(g_flat, w[f"{pfx}.wo_b.weight"], w[f"{pfx}.wo_b.scale"]) - - return F_attn, q_a # Return q_lora for indexer - + # 8. Output projection: wo_a (BF16 grouped BMM) + wo_b (NVFP4) + hpg = n_h // o_groups + gid = hpg * hd + oa_w = w.get(f"{pfx}.o_a_proj.weight") + if oa_w is not None: + oa_bf = oa_w.bfloat16().to(dev) + a_flat = attn_out.reshape(T, n_h * hd) + a_grp = a_flat.reshape(T, o_groups, gid) + oa_3d = oa_bf.reshape(o_groups, o_rank, gid) + g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2)) + g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank) + F_attn = do_nvfp4_linear(g_flat, w, pfx, 'o_b_proj') + else: + F_attn = do_nvfp4_linear(attn_out.reshape(T, n_h * hd), w, pfx, 'o_a_proj') + return F_attn, q_a # ===================================================================== # MoE forward # ===================================================================== - def moe_forward(x, w, li, cfg, token_id, device): - """Routed MoE + shared expert. - - x: (T, H) BF16 — post-RMSNorm FFN input - Returns: (T, H) BF16 - """ H = cfg["hidden_size"] n_e = cfg["n_routed_experts"] top_k = cfg.get("num_experts_per_tok", 6) rsc = cfg.get("routed_scaling_factor", 2.5) lim = cfg.get("swiglu_limit", 10.0) - pfx = f"layers.{li}.ffn" + num_hash = cfg.get("num_hash_layers", 3) + pfx = f"model.layers.{li}.mlp" # Routing tid2eid_key = f"{pfx}.gate.tid2eid" e_bias_key = f"{pfx}.gate.e_score_correction_bias" - is_hash = tid2eid_key in w and e_bias_key not in w + is_hash = (li < num_hash) and (tid2eid_key in w) if is_hash: tid2eid = w[tid2eid_key] @@ -864,36 +548,33 @@ def moe_forward(x, w, li, cfg, token_id, device): expert_ids = tid2eid[tid] expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k else: - # Dense routing: sqrt(softplus) + e_score_correction_bias (selection only) - # Gate weight is BF16 (not NVFP4 — no .scale in checkpoint) - gate_w = w[f"{pfx}.gate.weight"].bfloat16() - logits = F.linear(x, gate_w) # (T, n_e) + gate_ww, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate') + if gate_ww is not None: + logits = nvfp4_linear(x, gate_ww.to(device), gate_ws.to(device), + gate_ws2.to(device) if gate_ws2 is not None else None, + gate_isc.to(device) if gate_isc is not None else None) + else: + gw = w.get(f"{pfx}.gate.weight") + logits = F.linear(x, gw.bfloat16().to(device)) scores = torch.sqrt(F.softplus(logits.float()) + 1e-6) - sel_logits = scores.clone() + sel = scores.clone() if e_bias_key in w: - sel_logits = sel_logits + w[e_bias_key].to(device=x.device).float().unsqueeze(0) - _, indices = sel_logits.topk(top_k, -1) + sel = sel + w[e_bias_key].to(device=x.device).float().unsqueeze(0) + _, indices = sel.topk(top_k, -1) expert_weights = torch.gather(scores, -1, indices) expert_weights = expert_weights / expert_weights.sum(-1, keepdim=True) - if x.shape[0] == 1: - expert_ids = indices[0]; expert_weights = expert_weights[0] - else: - raise NotImplementedError("Multi-token MoE routing") + expert_ids, expert_weights = indices[0], expert_weights[0] - # Run experts - T = x.shape[0] + # Routed experts expert_outs = [] for i, eid in enumerate(expert_ids): ep = f"{pfx}.experts.{eid.item()}" - g = nvfp4_linear(x, w[f"{ep}.w1.weight"], w[f"{ep}.w1.scale"]) - u = nvfp4_linear(x, w[f"{ep}.w3.weight"], w[f"{ep}.w3.scale"]) + g = do_nvfp4_linear(x, w, ep, 'gate_proj') + u = do_nvfp4_linear(x, w, ep, 'up_proj') silu = F.silu(g.float()) - if lim is not None: - silu = silu.clamp(-lim, lim) - u = u.float().clamp(-lim, lim) + if lim is not None: silu = silu.clamp(-lim, lim); u = u.float().clamp(-lim, lim) h = (silu * u).bfloat16() - d = nvfp4_linear(h, w[f"{ep}.w2.weight"], w[f"{ep}.w2.scale"]) - expert_outs.append(d) + expert_outs.append(do_nvfp4_linear(h, w, ep, 'down_proj')) routed = torch.zeros_like(x) for out, wt in zip(expert_outs, expert_weights): @@ -901,192 +582,159 @@ def moe_forward(x, w, li, cfg, token_id, device): routed = (routed.float() * rsc).bfloat16() # Shared expert - sp = f"{pfx}.shared_expert" - sg = nvfp4_linear(x, w[f"{sp}.w1.weight"], w[f"{sp}.w1.scale"]) - su = nvfp4_linear(x, w[f"{sp}.w3.weight"], w[f"{sp}.w3.scale"]) + sp = f"{pfx}.shared_experts" + sg = do_nvfp4_linear(x, w, sp, 'gate_proj') + su = do_nvfp4_linear(x, w, sp, 'up_proj') silu = F.silu(sg.float()) if lim is not None: silu = silu.clamp(-lim, lim); su = su.float().clamp(-lim, lim) - sh = (silu * su).bfloat16() - shared = nvfp4_linear(sh, w[f"{sp}.w2.weight"], w[f"{sp}.w2.scale"]) - + shared = do_nvfp4_linear((silu * su).bfloat16(), w, sp, 'down_proj') return routed + shared - # ===================================================================== # Layer forward # ===================================================================== - def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc, attn_norm_w, ffn_norm_w, kv_cache, positions, token_id, compressor=None, indexer=None): - """Forward one transformer layer. - - X_l: (T, n_hc, H) BF16 — mHC residual state - Returns: X_next (T, n_hc, H) BF16 - """ dev = X_l.device - H = cfg["hidden_size"] - pfx = f"layers.{li}" - - # -- Attention sub-block -- + # Attention sub-block x_in, ctx_a = attn_mhc.pre_block(X_l) x_normed = rmsnorm(x_in, attn_norm_w) - - F_attn, q_lora = forward_attention( - x_normed, w, li, cfg, rope_cos, rope_sin, - kv_cache, positions, compressor, indexer) + F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, + kv_cache, positions, compressor, indexer) X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a) - - # -- FFN sub-block -- + # FFN sub-block x_in_f, ctx_f = ffn_mhc.pre_block(X_mid) x_ffn = rmsnorm(x_in_f, ffn_norm_w) F_ffn = moe_forward(x_ffn, w, li, cfg, token_id, dev) X_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f) - if GROWTH_DIAG: print(f" L{li}: |X|={X_l.abs().max().item():.1f}→{X_next.abs().max().item():.1f} " f"|Fa|={F_attn.abs().max().item():.1f} |Ff|={F_ffn.abs().max().item():.1f}", flush=True) - return X_next - # ===================================================================== # Main # ===================================================================== - def main(): t0 = time.time() torch.manual_seed(SEED) - print("="*70) + print("=" * 70) print("DSV4 Single-Shot Inference — Full E2E Pipeline") - print(" mHC + Compressor + Indexer + Attention + MoE") - print("="*70) + print(" NVFP4 two-level scale | Compressor + Indexer | mHC | MoE") + print("=" * 70) with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: cfg = json.load(f) n_layers = cfg["num_hidden_layers"] H = cfg["hidden_size"] - n_h = cfg["num_attention_heads"] hd = cfg["head_dim"] rd = cfg.get("qk_rope_head_dim", 64) - compress_ratios = cfg.get("compress_ratios", [128]*61) - print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}") - print(f"Compress ratios: {compress_ratios[:5]}... (len={len(compress_ratios)})") + cr = cfg.get("compress_ratios", [128] * 61) + print(f"Model: {n_layers} layers, {cfg['num_attention_heads']} heads, hd={hd}, rope_dim={rd}") + print(f"Compress ratios: first5={cr[:5]} len={len(cr)}") print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}") - # Phase 1: Load weights + # Load weights print(f"\nPhase 1: Loading weights...") all_w = load_weights(CHECKPOINT_DIR) print(f" {time.time()-t0:.1f}s") - # Build mHC + norms + # mHC + norms print("Building mHC blocks and norms...") - attn_mhcs = {}; ffn_mhcs = {}; attn_norms = {}; ffn_norms = {} + attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = {}, {}, {}, {} for li in range(n_layers): dev = f"cuda:{li % NUM_GPUS}" for tag, blocks, fn_s, base_s, scale_s in [ - ("attn", attn_mhcs, f"layers.{li}.hc_attn_fn", f"layers.{li}.hc_attn_base", f"layers.{li}.hc_attn_scale"), - ("ffn", ffn_mhcs, f"layers.{li}.hc_ffn_fn", f"layers.{li}.hc_ffn_base", f"layers.{li}.hc_ffn_scale"), + ("attn", attn_mhcs, f"model.layers.{li}.attn_hc.fn", + f"model.layers.{li}.attn_hc.base", f"model.layers.{li}.attn_hc.scale"), + ("ffn", ffn_mhcs, f"model.layers.{li}.ffn_hc.fn", + f"model.layers.{li}.ffn_hc.base", f"model.layers.{li}.ffn_hc.scale"), ]: - if fn_s in all_w and base_s in all_w and scale_s in all_w: + fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s) + if fn is not None and base is not None and scale is not None: m = mHCBlock(H, 4, 20, dev) - m.load(all_w[fn_s], all_w[base_s], all_w[scale_s]) + m.load(fn.to(dev, torch.float32), base.to(dev, torch.float32), scale.to(dev, torch.float32)) blocks[li] = m else: - print(f" WARNING: no mHC for layers.{li}.{tag}") + print(f" WARNING: no mHC for L{li} {tag}") - # RMSNorms - an_k = f"layers.{li}.attn_norm.weight" - if an_k in all_w: - attn_norms[li] = all_w[an_k].to(dev, torch.float32) - fn_k = f"layers.{li}.ffn_norm.weight" - if fn_k in all_w: - ffn_norms[li] = all_w[fn_k].to(dev, torch.float32) + an_k = f"model.layers.{li}.input_layernorm.weight" + if an_k in all_w: attn_norms[li] = all_w[an_k].to(dev, torch.float32) + fn_k = f"model.layers.{li}.post_attention_layernorm.weight" + if fn_k in all_w: ffn_norms[li] = all_w[fn_k].to(dev, torch.float32) # Global weights torch.cuda.set_device(0) - embed_w = all_w.get("embed.weight", all_w.get("model.embed_tokens.weight")) + embed_w = all_w.get("model.embed_tokens.weight") embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0')) - lm_k = "head.weight" if "head.weight" in all_w else "lm_head.weight" - lm_w = all_w.get(lm_k, embed_w).bfloat16().to('cuda:0') - final_norm_w = all_w.get("norm.weight") + lm_w = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0') + final_norm_w = all_w.get("model.norm.weight") if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32) - # HcHead hc_head = HcHead(H, 4, 'cuda:0') - hc_fn = all_w.get("hc_head_fn") - hc_base = all_w.get("hc_head_base") - hc_scale = all_w.get("hc_head_scale") + hc_fn = all_w.get("model.hc_head.hc_fn") + hc_base = all_w.get("model.hc_head.hc_base") + hc_scale = all_w.get("model.hc_head.hc_scale") if hc_fn is not None and hc_base is not None: hc_head.load(hc_fn, hc_base, hc_scale) - print(f" hc_head loaded") + print(" hc_head loaded") else: print(" WARNING: hc_head not found") hc_head = None - # RoPE caches - rope_params = cfg.get("rope_parameters", {}) - rope_type = rope_params.get("rope_type", "yarn") - rope_factor = rope_params.get("factor", 16.0) - rope_theta = rope_params.get("rope_theta", cfg.get("rope_theta", 10000.)) - orig_max = rope_params.get("original_max_position_embeddings", 4096) - beta_fast = rope_params.get("beta_fast", 32) - beta_slow = rope_params.get("beta_slow", 1) - print(f"RoPE: {rope_type} factor={rope_factor} theta={rope_theta}") - rope_caches = {g: build_rope_cache(8192, rd, f"cuda:{g}", rope_theta, - rope_type, rope_factor, orig_max, beta_fast, beta_slow) - for g in range(NUM_GPUS)} + # RoPE + rp = cfg.get("rope_scaling", cfg.get("rope_parameters", {})) + rt = rp.get("type", rp.get("rope_type", "yarn")) + rf = rp.get("factor", 16.0) + rtheta = cfg.get("rope_theta", 10000.) + romax = rp.get("original_max_position_embeddings", 65536) + rbfast, rbslow = rp.get("beta_fast", 32), rp.get("beta_slow", 1) + print(f"RoPE: {rt} factor={rf} theta={rtheta} orig_max={romax}") + rope_caches = {g: build_rope_cache(8192, rd, f"cuda:{g}", rtheta, rt, rf, romax, rbfast, rbslow) + for g in range(NUM_GPUS)} # KV caches - kv_caches = {li: KVCache(hd, 128, 8192, f"cuda:{li % NUM_GPUS}") for li in range(n_layers)} + kv_caches = {li: KVCache(hd, cfg.get("sliding_window", 128), f"cuda:{li % NUM_GPUS}") + for li in range(n_layers)} - # Compressors + indexers (persistent per layer) - compressors = {}; indexers = {} + # Compressors + indexers + compressors, indexers = {}, {} + n_ih = cfg.get("index_n_heads", 64) + ihd = cfg.get("index_head_dim", 128) + itk = cfg.get("index_topk", 1024) for li in range(n_layers): dev = f"cuda:{li % NUM_GPUS}" - ratio = compress_ratios[li] if li < len(compress_ratios) else 128 - if ratio > 0: - c = Compressor(ratio, hd, H, dev) - # Load from cached weights (already on device) - # We'll load after caching layer weights - compressors[li] = c - if ratio == 4: # CSA layers have indexers - # Indexer head dim and heads — from checkpoint shapes - # We'll determine these from weight shapes after loading - indexers[li] = Indexer(1, 128, 512, dev) # n_ih, ihd, top_k — will fix from shapes + ratio = cr[li] if li < len(cr) else 128 + if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev) + if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev) - # Cache layer weights + # Cache layer weights to GPUs print("Caching layer weights to GPUs...") devs = [f"cuda:{g}" for g in range(NUM_GPUS)] layer_w = cache_layer_weights(all_w, n_layers, devs) del all_w; import gc; gc.collect() print(f" {time.time()-t0:.1f}s") - # Load compressor/indexer weights from cached per-layer weights + # Load compressor/indexer weights for li in range(n_layers): - w = layer_w[li] - pfx = f"layers.{li}.attn" - if li in compressors: - compressors[li].load(w, f"{pfx}.compressor") - if li in indexers: - indexers[li].load(w, f"{pfx}.indexer") - print(f" Compressors/indexers loaded") + pfx = f"model.layers.{li}.self_attn.compressor" + if li in compressors: compressors[li].load(layer_w[li], pfx) + if li in indexers: indexers[li].load(layer_w[li], f"{pfx}.indexer") + print(" Compressors/indexers loaded") # Phase 2: Inference print(f"\nPhase 2: Inference") from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) - # Build input: <|User|> prompt <|Assistant|> bos = tokenizer.bos_token_id or 0 input_ids = [bos, USER_TOKEN] input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False) input_ids.append(ASSISTANT_TOKEN) - input_ids = torch.tensor([input_ids], dtype=torch.long).cuda() - print(f"Input: {input_ids.shape[1]} tokens") - - generated = input_ids[0].tolist() + generated = input_ids.copy() + print(f"Input: {len(generated)} tokens") # Prefill print(f"Prefilling {len(generated)} tokens...") @@ -1094,65 +742,52 @@ def main(): t1 = time.time() tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0') pos = torch.tensor([pi], dtype=torch.long, device='cuda:0') - emb = embed(tid) - X = mHCBlock.init_state(emb) - + X = mHCBlock.init_state(embed(tid)) for li in range(n_layers): gpu = li % NUM_GPUS - dev = f"cuda:{gpu}" - if X.device != torch.device(dev): X = X.to(dev) + if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") torch.cuda.set_device(gpu) X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu], - attn_mhcs.get(li), ffn_mhcs.get(li), - attn_norms.get(li), ffn_norms.get(li), - kv_caches[li], pos, tid, - compressors.get(li), indexers.get(li)) - + attn_mhcs.get(li), ffn_mhcs.get(li), + attn_norms.get(li), ffn_norms.get(li), + kv_caches[li], pos, tid, + compressors.get(li), indexers.get(li)) X = X.to('cuda:0'); torch.cuda.set_device(0) if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True) - print(f" Prefill done ({time.time()-t0:.1f}s)") # Decode print(f"\nDecoding (max {MAX_NEW_TOKENS} tokens)...") all_tokens = generated.copy() - for step in range(MAX_NEW_TOKENS): t1 = time.time() tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0') dec_pos = torch.tensor([len(all_tokens)-1], dtype=torch.long, device='cuda:0') - - emb = embed(tid) - X = mHCBlock.init_state(emb) - + X = mHCBlock.init_state(embed(tid)) for li in range(n_layers): gpu = li % NUM_GPUS - dev = f"cuda:{gpu}" - if X.device != torch.device(dev): X = X.to(dev) + if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") torch.cuda.set_device(gpu) X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu], - attn_mhcs.get(li), ffn_mhcs.get(li), - attn_norms.get(li), ffn_norms.get(li), - kv_caches[li], dec_pos, tid, - compressors.get(li), indexers.get(li)) - + attn_mhcs.get(li), ffn_mhcs.get(li), + attn_norms.get(li), ffn_norms.get(li), + kv_caches[li], dec_pos, tid, + compressors.get(li), indexers.get(li)) X = X.to('cuda:0'); torch.cuda.set_device(0) - - # HcHead readout - x_out = hc_head.forward(X) if hc_head else X[:, 0, :] - if final_norm_w is not None: - x_out = rmsnorm(x_out, final_norm_w) + x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :] + if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w) logits = F.linear(x_out, lm_w) - next_id = torch.argmax(logits, -1).item() all_tokens.append(next_id) - tok_str = tokenizer.decode([next_id]) dt = time.time() - t1 has_nan = torch.isnan(logits.float()).any().item() if step % 5 == 0 or has_nan: tv, ti = torch.topk(logits[0], 5) - top5 = ' '.join(f'{tokenizer.decode([t.item()])}({v.item():.1f})' for t,v in zip(ti[:5],tv[:5])) - print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) logits=[{logits.float().min().item():.1f},{logits.float().max().item():.1f}] nan={has_nan} |X|={X.abs().max().item():.1f} top5: {top5}", flush=True) + top5 = ' '.join(f'{tokenizer.decode([t.item()])}({v.item():.1f})' + for t, v in zip(ti[:5], tv[:5])) + print(f" Step {step}: {next_id} '{tokenizer.decode([next_id])}' ({dt:.2f}s) " + f"logits=[{logits.float().min().item():.1f},{logits.float().max().item():.1f}] " + f"nan={has_nan} |X|={X.abs().max().item():.1f} top5: {top5}", flush=True) if has_nan: break if next_id == tokenizer.eos_token_id: break @@ -1163,6 +798,5 @@ def main(): print(f"Total: {time.time()-t0:.1f}s") print(f"{'='*70}") - if __name__ == "__main__": - main() + main() \ No newline at end of file