From d003c4b7ccf6ee48c8704c494dbf53ddd38f520b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 23:45:18 +0000 Subject: [PATCH] Add mHC (Manifold-Constrained Hyper-Connections) to single_shot MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Full mHC pre_block/post_block with Sinkhorn-Knopp normalization - Dynamic A_l (sigmoid), B_l (Birkhoff polytope), C_l (2*sigmoid) - Checkpoint: attn_hc.fn (24,28672) + base (24,) + scale (3,) - Two mHC blocks per layer: attn_hc + ffn_hc - Removed emergency RMSNorm — mHC handles normalization properly - X_l: (1, n_hc=4, H) residual state, init from embedding broadcast --- single_shot_inference.py | 448 ++++++++++++++++++++++----------------- 1 file changed, 255 insertions(+), 193 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 1fb8ed80..b7479c85 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1,14 +1,19 @@ #!/usr/bin/env python3 -"""Single-shot DSV4 inference — 8-GPU pipeline parallel. +"""Single-shot DSV4 inference — 8-GPU pipeline parallel with mHC. -Loads the full NVFP4 checkpoint across 8 B200 GPUs (round-robin layer assignment). -Each GPU holds ~8 layers of weights in HBM. Activations move between GPUs at -layer boundaries via cudaMemcpy (fast on NVLink). +Loads the full NVFP4 checkpoint across 8 B200 GPUs. Includes: +- mHC (Manifold-Constrained Hyper-Connections) — load-bearing residual +- Q low-rank projection + KV projection +- RoPE (partial, last 64 dims) +- Production FMHA kernel (tcgen05 MMA, hd=512, 128 heads) +- Output projection: wo_a (grouped BMM) → wo_b (NVFP4) +- Shared expert FFN (SwiGLU) +- NVFP4 dequant → BF16 matmul baseline for linear layers -Pipeline: - 1. Load all 95 shards, assign each layer's weights to its GPU - 2. JIT-compile kernels (one-time) - 3. Decode loop: embed → layer 0 (gpu0) → layer 1 (gpu0) → ... → layer 8 (gpu1) → ... → norm → lm_head +Missing (causing incorrect output): +- Routed MoE experts (384 experts, top-6) +- KV cache across decode steps +- Compressor + indexer (CSA/HCA compressed KV) Usage (on B200): source /root/dsv4-nvfp4-workspace/venv/bin/activate @@ -31,42 +36,127 @@ NUM_GPUS = 8 FP4_LUT = torch.tensor([0., 2., 3., 4., 6., 8., 12., 24.]) def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2): - """Dequantize NVFP4 weight to BF16. All tensors must be on same device.""" out_dim = weight.shape[0] in_packed = weight.shape[1] in_features = in_packed * 2 - low = (weight & 0x0F).to(torch.int8) high = (weight >> 4).to(torch.int8) - - low_sign = (low >> 3).bool() - low_idx = (low & 0x07).long() - high_sign = (high >> 3).bool() - high_idx = (high & 0x07).long() - + low_sign, low_idx = (low >> 3).bool(), (low & 0x07).long() + high_sign, high_idx = (high >> 3).bool(), (high & 0x07).long() lut = FP4_LUT.to(device=weight.device, dtype=torch.float32) low_f = lut[low_idx] * torch.where(low_sign, -1.0, 1.0) high_f = lut[high_idx] * torch.where(high_sign, -1.0, 1.0) - w_f = torch.stack([low_f, high_f], dim=-1).reshape(out_dim, in_features) - scale_f = weight_scale.float() * weight_scale_2.float() scale_expanded = scale_f.repeat_interleave(16, dim=1) - return (w_f * scale_expanded).bfloat16() - def nvfp4_linear(x, weight, weight_scale, weight_scale_2): - """NVFP4 linear: dequant → BF16 matmul.""" w = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2) return torch.nn.functional.linear(x, w) - def bf16_linear(x, weight): - """BF16 linear.""" return torch.nn.functional.linear(x, weight.bfloat16()) +# ===================================================================== +# mHC — Manifold-Constrained Hyper-Connections +# ===================================================================== + +def sinkhorn_knopp(M, t_max=20, eps=1e-6): + """Project (T, n, n) positive matrices onto Birkhoff polytope.""" + for _ in range(t_max): + M = M / (M.sum(dim=-1, keepdim=True) + eps) + M = M / (M.sum(dim=-2, keepdim=True) + eps) + return M + + +class mHCBlock: + """One mHC block (attention or FFN). + + Checkpoint weight mapping: + fn: (24, 28672) FP32 = stacked [W_pre(4,K); W_res(16,K); W_post(4,K)] + base: (24,) FP32 = bias, split as [S_pre(4); S_res(16); S_post(4)] + scale: (3,) FP32 = [alpha_pre, alpha_res, alpha_post] + """ + def __init__(self, hidden_dim=7168, n_hc=4, t_max=20, device='cuda'): + self.d = hidden_dim + self.n_hc = n_hc + self.K = n_hc * hidden_dim # 28672 + self.t_max = t_max + self.device = device + self.W_stacked = None # (24, K) FP32 + self.bias = None # (24,) FP32 + self.alphas = None # (3,) FP32 + + def load_from_checkpoint(self, fn, base, scale): + """Load from checkpoint tensors. All on target device, FP32.""" + self.W_stacked = fn.to(device=self.device, dtype=torch.float32).contiguous() + self.bias = base.to(device=self.device, dtype=torch.float32).contiguous() + self.alphas = scale.to(device=self.device, dtype=torch.float32).contiguous() + + def _dynamic_params(self, X_l): + """Compute A_l, B_l, C_l from residual state. + + X_l: (T, n_hc, d) BF16 + Returns: A_l (T, n_hc), B_l (T, n_hc, n_hc) FP32, C_l (T, n_hc) + """ + T, n, d = X_l.shape + n_hc = self.n_hc + + # Flatten and project with RMSNorm + X_flat = X_l.reshape(T, self.K) # (T, K) BF16 + + # RMSNorm + x_f32 = X_flat.float() + rms = x_f32.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + x_normed = x_f32 * rms # (T, K) FP32 + + # Project: (T, K) @ (24, K)^T → (T, 24) + proj = torch.nn.functional.linear(x_normed.bfloat16(), self.W_stacked.bfloat16()).float() + proj = proj + self.bias.unsqueeze(0) # add bias + + # Split into A, B, C + i0, i1, i2 = n_hc, n_hc + n_hc * n_hc, 24 + A_raw = proj[:, :i0] # (T, 4) + B_raw = proj[:, i0:i1] # (T, 16) + C_raw = proj[:, i1:i2] # (T, 4) + + # Split bias into S_pre, S_res, S_post + S_pre = self.bias[:n_hc] + S_res = self.bias[n_hc:n_hc + n_hc * n_hc] + S_post = self.bias[n_hc + n_hc * n_hc:] + + # Apply gating + biases + a_pre, a_res, a_post = self.alphas[0], self.alphas[1], self.alphas[2] + A_tilde = a_pre * A_raw + S_pre.unsqueeze(0) + B_tilde = a_res * B_raw + S_res.unsqueeze(0) + C_tilde = a_post * C_raw + S_post.unsqueeze(0) + + # Constraints + A_l = torch.sigmoid(A_tilde).bfloat16() # (T, 4) ∈ (0,1) + C_l = (2.0 * torch.sigmoid(C_tilde)).bfloat16() # (T, 4) ∈ (0,2) + B_exp = torch.exp(B_tilde).reshape(T, n_hc, n_hc) # (T, 4, 4) + B_l = sinkhorn_knopp(B_exp, self.t_max) # FP32, doubly stochastic + + return A_l, B_l, C_l + + def pre_block(self, X_l): + """X_l: (T, n_hc, d) → x_in: (T, d), ctx""" + A_l, B_l, C_l = self._dynamic_params(X_l) + # x_in = A_l @ X_l: (T, 1, n_hc) bmm (T, n_hc, d) → (T, 1, d) → (T, d) + x_in = torch.bmm(A_l.unsqueeze(1).float(), X_l.float()).squeeze(1).bfloat16() + return x_in, (B_l, C_l) + + def post_block(self, X_l, F_out, ctx): + """X_l: (T, n_hc, d), F_out: (T, d) → X_next: (T, n_hc, d)""" + B_l, C_l = ctx + # X_next = B_l @ X_l + C_l ⊗ F_out + BX = torch.bmm(B_l, X_l.float()) # (T, n_hc, d) FP32 + CF = C_l.unsqueeze(-1).float() * F_out.unsqueeze(1).float() # (T, n_hc, d) FP32 + return (BX + CF).bfloat16() + + # ===================================================================== # RoPE # ===================================================================== @@ -74,110 +164,73 @@ def bf16_linear(x, weight): def build_rope_cache(max_pos, head_dim, rope_dim, device, theta=10000.0): half = rope_dim // 2 freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim)) - positions = torch.arange(max_pos, dtype=torch.float32) - angles = torch.outer(positions, freqs) + angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs) return torch.cos(angles).to(device), torch.sin(angles).to(device) - def apply_rope(x, positions, cos_cache, sin_cache, rope_dim): T, n_h, hd = x.shape nope = hd - rope_dim - half = rope_dim // 2 - cos = cos_cache[positions].unsqueeze(1).to(x.dtype) sin = sin_cache[positions].unsqueeze(1).to(x.dtype) - - x_rope = x[:, :, nope:] - even = x_rope[:, :, 0::2] - odd = x_rope[:, :, 1::2] - out = x.clone() - out[:, :, nope:][..., 0::2] = even * cos - odd * sin - out[:, :, nope:][..., 1::2] = even * sin + odd * cos + out[:, :, nope:][..., 0::2] = x[:, :, nope:][..., 0::2] * cos - x[:, :, nope:][..., 1::2] * sin + out[:, :, nope:][..., 1::2] = x[:, :, nope:][..., 0::2] * sin + x[:, :, nope:][..., 1::2] * cos return out # ===================================================================== -# Checkpoint loading — load all shards, group by layer, assign to GPU +# Weight loading # ===================================================================== def load_all_weights(checkpoint_dir, num_layers): - """Load all 95 shards and organize weights by layer, moving to target GPU. - - Returns: dict mapping layer_idx → dict of weight tensors (on target GPU) - plus global weights (embed, norm, lm_head) on gpu0 - """ from safetensors.torch import load_file - from collections import defaultdict cdir = Path(checkpoint_dir) - - # Load the index for fast shard lookup index_path = cdir / "model.safetensors.index.json" + weight_map = {} if index_path.exists(): with open(index_path) as f: weight_map = json.load(f).get("weight_map", {}) - else: - weight_map = {} - # Organize: which shard files do we need? shard_names = set(weight_map.values()) if weight_map else { f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96) } - # Load all shards (one at a time to limit CPU RAM) - print(f"Loading {len(shard_names)} shards from checkpoint...") - all_weights = {} # key → tensor (CPU) + print(f"Loading {len(shard_names)} shards...") + all_weights = {} loaded = 0 for shard_name in sorted(shard_names): - shard_path = cdir / shard_name - if not shard_path.exists(): + if not (cdir / shard_name).exists(): continue - data = load_file(str(shard_path)) + data = load_file(str(cdir / shard_name)) all_weights.update(data) loaded += 1 - if loaded % 10 == 0: - print(f" Loaded {loaded}/{len(shard_names)} shards, {len(all_weights)} tensors") + if loaded % 20 == 0: + print(f" {loaded}/{len(shard_names)} shards, {len(all_weights)} tensors") + print(f" Done: {len(all_weights)} tensors") - print(f" Done: {len(all_weights)} tensors loaded to CPU") + layer_weights = {} + global_weights = {} - # Group by layer and assign to GPU - # Layer i goes to GPU (i // 8) — 8 layers per GPU - # Actually: round-robin — layer i goes to GPU (i % 8) - # This balances load better (8 layers per GPU, ~15GB each) - - layer_weights = {} # layer_idx → dict of tensors on target GPU - global_weights = {} # embed, norm, lm_head → on gpu0 - - print("Assigning layers to GPUs...") + print("Assigning to GPUs...") for key, tensor in all_weights.items(): - # Determine which layer this weight belongs to if key.startswith("model.layers."): - parts = key.split(".") - layer_idx = int(parts[2]) - target_gpu = layer_idx % NUM_GPUS - - if layer_idx not in layer_weights: - layer_weights[layer_idx] = {"_device": f"cuda:{target_gpu}"} - - # Move to target GPU - layer_weights[layer_idx][key] = tensor.to(f"cuda:{target_gpu}") - + li = int(key.split(".")[2]) + target_gpu = li % NUM_GPUS + target_device = f"cuda:{target_gpu}" + if li not in layer_weights: + layer_weights[li] = {"_device": target_device, "_gpu": target_gpu} + layer_weights[li][key] = tensor.to(target_device) elif key.startswith("model.embed_tokens"): global_weights[key] = tensor.to("cuda:0") - elif key.startswith("model.norm"): global_weights[key] = tensor.to("cuda:0") - elif key.startswith("lm_head"): global_weights[key] = tensor.to("cuda:0") - # Print per-GPU memory usage for gpu in range(NUM_GPUS): - torch.cuda.set_device(gpu) - allocated = torch.cuda.memory_allocated(gpu) / 1e9 - reserved = torch.cuda.memory_reserved(gpu) / 1e9 - print(f" GPU {gpu}: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved") + alloc = torch.cuda.memory_allocated(gpu) / 1e9 + print(f" GPU {gpu}: {alloc:.1f}GB") return layer_weights, global_weights @@ -186,95 +239,95 @@ def load_all_weights(checkpoint_dir, num_layers): # Single layer forward # ===================================================================== -def forward_layer(x, w, li, cfg, rope_cos, rope_sin): - """Forward one layer. x: (1, hidden) BF16 → (1, hidden) BF16.""" - device = x.device +def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc): + """Forward one layer with mHC. + + X_l: (1, n_hc, H) BF16 → (1, n_hc, H) BF16 + """ + device = X_l.device H = cfg["hidden_size"] n_h = cfg["num_attention_heads"] hd = cfg["head_dim"] rd = cfg["qk_rope_head_dim"] o_rank = cfg["o_lora_rank"] o_groups = cfg["o_groups"] + n_hc = 4 pre = f"model.layers.{li}.self_attn" - T = x.shape[0] - heads_per_group = n_h // o_groups # 8 - group_input_dim = heads_per_group * hd # 4096 + T = X_l.shape[0] + heads_per_group = n_h // o_groups + group_input_dim = heads_per_group * hd - # ---- Q projection: q_a (down) → q_b (up) ---- - c_Q = nvfp4_linear(x, w[f"{pre}.q_a_proj.weight"], + # ==== mHC pre_block (attention) ==== + x_in, attn_ctx = attn_mhc.pre_block(X_l) # x_in: (T, H) BF16 + + # ==== Q projection ==== + c_Q = nvfp4_linear(x_in, w[f"{pre}.q_a_proj.weight"], w[f"{pre}.q_a_proj.weight_scale"], w[f"{pre}.q_a_proj.weight_scale_2"]) q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"], w[f"{pre}.q_b_proj.weight_scale"], w[f"{pre}.q_b_proj.weight_scale_2"]) - # ---- KV projection ---- - kv = nvfp4_linear(x, w[f"{pre}.kv_proj.weight"], + # ==== KV projection ==== + kv = nvfp4_linear(x_in, w[f"{pre}.kv_proj.weight"], w[f"{pre}.kv_proj.weight_scale"], w[f"{pre}.kv_proj.weight_scale_2"]) - # ---- Reshape for attention ---- - q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) # (n_h, T, hd) - k = kv.reshape(T, 1, hd).permute(1, 0, 2) # (1, T, hd) — MQA + # ==== Reshape for attention ==== + q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) + k = kv.reshape(T, 1, hd).permute(1, 0, 2) v = k.clone() - # ---- Apply RoPE ---- + # ==== RoPE ==== pos = torch.tensor([0], dtype=torch.long, device=device) q_heads = apply_rope(q_heads, pos, rope_cos, rope_sin, rd) k = apply_rope(k, pos, rope_cos, rope_sin, rd) - # ---- FMHA ---- + # ==== FMHA ==== from dsv4.kernels.attention.production import dsv4_attention - attn_out = dsv4_attention(q_heads, k, v) # (n_h, T, hd) - attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd) # (T, n_h*hd) - - # ---- Output projection: wo_a (BF16 grouped BMM) → wo_b (NVFP4) ---- - attn_grouped = attn_out.reshape(T, o_groups, heads_per_group, hd) - attn_grouped = attn_grouped.reshape(T, o_groups, group_input_dim) + attn_out = dsv4_attention(q_heads, k, v) + attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd) + # ==== Output projection ==== + attn_grouped = attn_out.reshape(T, o_groups, heads_per_group, hd).reshape(T, o_groups, group_input_dim) oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16() oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim) - - attn_for_bmm = attn_grouped.permute(1, 0, 2) # (16, 1, 4096) - grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2)) # (16, 1, o_rank) + attn_for_bmm = attn_grouped.permute(1, 0, 2) + grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2)) grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank) + F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"], + w[f"{pre}.o_b_proj.weight_scale"], + w[f"{pre}.o_b_proj.weight_scale_2"]) - attn_proj = nvfp4_linear(grouped_flat, - w[f"{pre}.o_b_proj.weight"], - w[f"{pre}.o_b_proj.weight_scale"], - w[f"{pre}.o_b_proj.weight_scale_2"]) + # ==== mHC post_block (attention) ==== + X_l = attn_mhc.post_block(X_l, F_attn, attn_ctx) # (T, n_hc, H) - # ---- Residual + emergency RMSNorm (mHC missing) ---- - x = x + attn_proj - x = x.clamp(-65504, 65504) - xf = x.float() - rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - x = (xf * rms).bfloat16() + # ==== mHC pre_block (FFN) ==== + x_ffn, ffn_ctx = ffn_mhc.pre_block(X_l) - # ---- FFN: shared expert ---- + # ==== FFN: shared expert ==== se_pre = f"model.layers.{li}.mlp.shared_experts" se_gate_w = w.get(f"{se_pre}.gate_proj.weight") + F_ffn = torch.zeros_like(x_ffn) if se_gate_w is not None: - gate = nvfp4_linear(x, se_gate_w, + gate = nvfp4_linear(x_ffn, se_gate_w, w[f"{se_pre}.gate_proj.weight_scale"], w[f"{se_pre}.gate_proj.weight_scale_2"]) - up = nvfp4_linear(x, w[f"{se_pre}.up_proj.weight"], + up = nvfp4_linear(x_ffn, w[f"{se_pre}.up_proj.weight"], w[f"{se_pre}.up_proj.weight_scale"], w[f"{se_pre}.up_proj.weight_scale_2"]) - ffn_out = nvfp4_linear( + F_ffn = nvfp4_linear( torch.nn.functional.silu(gate) * up, w[f"{se_pre}.down_proj.weight"], w[f"{se_pre}.down_proj.weight_scale"], w[f"{se_pre}.down_proj.weight_scale_2"], ) - x = x + ffn_out - x = x.clamp(-65504, 65504) - xf = x.float() - rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - x = (xf * rms).bfloat16() - return x + # ==== mHC post_block (FFN) ==== + X_l = ffn_mhc.post_block(X_l, F_ffn, ffn_ctx) + + return X_l # ===================================================================== @@ -284,10 +337,9 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin): def main(): t_start = time.time() print("=" * 70) - print("DSV4 Single-Shot Inference — 8-GPU Pipeline Parallel") + print("DSV4 Single-Shot Inference — 8-GPU with mHC") print("=" * 70) - # Config with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: cfg = json.load(f) n_layers = cfg["num_hidden_layers"] @@ -295,58 +347,69 @@ def main(): n_h = cfg["num_attention_heads"] hd = cfg["head_dim"] rd = cfg["qk_rope_head_dim"] + n_hc = 4 print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}") - print(f"GPUs: {NUM_GPUS}, ~{n_layers // NUM_GPUS} layers per GPU") - # ---- Phase 1: Load weights ---- - print(f"\n{'='*70}") - print("Phase 1: Loading weights across 8 GPUs") - print(f"{'='*70}") + # ==== Phase 1: Load weights ==== + print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}") layer_weights, global_weights = load_all_weights(CHECKPOINT_DIR, n_layers) t_loaded = time.time() print(f"Weight loading: {t_loaded - t_start:.1f}s") - # Set default device to gpu0 for embed/lm_head - torch.cuda.set_device(0) + # ==== Build mHC blocks per layer ==== + print("Building mHC blocks...") + attn_mhc_blocks = {} + ffn_mhc_blocks = {} + for li in range(n_layers): + gpu = li % NUM_GPUS + dev = f"cuda:{gpu}" + + # Attention mHC + attn_fn = layer_weights[li].get(f"model.layers.{li}.attn_hc.fn") + attn_base = layer_weights[li].get(f"model.layers.{li}.attn_hc.base") + attn_scale = layer_weights[li].get(f"model.layers.{li}.attn_hc.scale") + if attn_fn is not None and attn_base is not None and attn_scale is not None: + attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev) + attn_mhc.load_from_checkpoint(attn_fn, attn_base, attn_scale) + attn_mhc_blocks[li] = attn_mhc + + # FFN mHC + ffn_fn = layer_weights[li].get(f"model.layers.{li}.ffn_hc.fn") + ffn_base = layer_weights[li].get(f"model.layers.{li}.ffn_hc.base") + ffn_scale = layer_weights[li].get(f"model.layers.{li}.ffn_hc.scale") + if ffn_fn is not None and ffn_base is not None and ffn_scale is not None: + ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev) + ffn_mhc.load_from_checkpoint(ffn_fn, ffn_base, ffn_scale) + ffn_mhc_blocks[li] = ffn_mhc - # Embedding on gpu0 + print(f" attn mHC: {len(attn_mhc_blocks)} layers") + print(f" ffn mHC: {len(ffn_mhc_blocks)} layers") + + # ==== Global weights (gpu0) ==== + torch.cuda.set_device(0) embed_w = global_weights.get("model.embed_tokens.weight") embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16()) - - # lm_head on gpu0 lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16() - - # Final norm on gpu0 final_norm_w = global_weights.get("model.norm.weight") - # Build RoPE caches — one per GPU - rope_caches = {} - for gpu in range(NUM_GPUS): - rope_caches[gpu] = build_rope_cache(8192, hd, rd, f"cuda:{gpu}") + # RoPE caches per GPU + rope_caches = {g: build_rope_cache(8192, hd, rd, f"cuda:{g}") for g in range(NUM_GPUS)} - # ---- Phase 2: JIT compile kernels ---- - print(f"\n{'='*70}") - print("Phase 2: JIT compiling kernels") - print(f"{'='*70}") - # Trigger FMHA kernel compile on gpu0 with a dummy forward - # This compiles the C API .so and caches it for all subsequent calls + # ==== Phase 2: Compile ==== + print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}") from dsv4.kernels.attention.production import dsv4_attention dummy_q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda:0') dummy_k = torch.randn(1, 1, hd, dtype=torch.bfloat16, device='cuda:0') - dummy_v = dummy_k.clone() try: - _ = dsv4_attention(dummy_q, dummy_k, dummy_v) - print(" FMHA kernel: compiled OK") + _ = dsv4_attention(dummy_q, dummy_k, dummy_k.clone()) + print(" FMHA: compiled OK") except Exception as e: - print(f" FMHA kernel compile error: {e}") + print(f" FMHA error: {e}") t_compiled = time.time() - print(f"Kernel compilation: {t_compiled - t_loaded:.1f}s") - - # ---- Phase 3: Inference ---- - print(f"\n{'='*70}") - print("Phase 3: Inference") - print(f"{'='*70}") + print(f"Compile: {t_compiled - t_loaded:.1f}s") + # ==== Phase 3: Inference ==== + print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}") from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda() @@ -358,60 +421,59 @@ def main(): t0 = time.time() tid = torch.tensor([generated[-1]], dtype=torch.long, device='cuda:0') - # Embed (gpu0) - x = embed(tid) # (1, H) on gpu0 + # Embed → mHC init state + emb = embed(tid) # (1, H) on gpu0 + X = emb.unsqueeze(1).expand(-1, n_hc, -1).clone() # (1, n_hc, H) - # Process layers — move x to the right GPU for each layer + # Process layers for li in range(n_layers): - target_gpu = li % NUM_GPUS - target_device = f"cuda:{target_gpu}" + gpu = li % NUM_GPUS + target_device = f"cuda:{gpu}" + if X.device != torch.device(target_device): + X = X.to(target_device) + torch.cuda.set_device(gpu) - # Move activation to layer's GPU and set as active device - if x.device != torch.device(target_device): - x = x.to(target_device) - torch.cuda.set_device(target_gpu) - - lw = layer_weights[li] - rc, rs = rope_caches[target_gpu] - x = forward_layer(x, lw, li, cfg, rc, rs) + attn_mhc = attn_mhc_blocks.get(li) + ffn_mhc = ffn_mhc_blocks.get(li) + rc, rs = rope_caches[gpu] + X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, attn_mhc, ffn_mhc) - # Move back to gpu0 for final norm + lm_head - x = x.to('cuda:0') + # Back to gpu0 + X = X.to('cuda:0') + torch.cuda.set_device(0) - # Final norm + # Read out stream 0 → RMSNorm → lm_head + x_out = X[:, 0, :] # (1, H) if final_norm_w is not None: - xf = x.float() + xf = x_out.float() rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - x = (xf * rms * final_norm_w.float()).bfloat16() + x_out = (xf * rms * final_norm_w.float()).bfloat16() - # lm_head - logits = torch.nn.functional.linear(x, lm_w) + logits = torch.nn.functional.linear(x_out, lm_w) next_id = torch.argmax(logits, dim=-1).item() generated.append(next_id) tok_str = tokenizer.decode([next_id]) dt = time.time() - t0 - if step == 0: - print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) [first step includes compile cache]") - else: - print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s)") + # Check for NaN + has_nan = torch.isnan(logits.float()).any().item() + logit_range = f"[{logits.float().min().item():.1f}, {logits.float().max().item():.1f}]" + print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) logits={logit_range} nan={has_nan}") + + if has_nan: + print(" NaN detected, stopping") + break if next_id == tokenizer.eos_token_id: break - # ---- Output ---- out = tokenizer.decode(generated, skip_special_tokens=True) - total_time = time.time() - t_start + total = time.time() - t_start print(f"\n{'='*70}") print(f"Input: '{PROMPT}'") print(f"Output: '{out}'") - print(f"Total time: {total_time:.1f}s (load: {t_loaded-t_start:.1f}s, compile: {t_compiled-t_loaded:.1f}s, inference: {time.time()-t_compiled:.1f}s)") + print(f"Total: {total:.1f}s (load: {t_loaded-t_start:.1f}s, compile: {t_compiled-t_loaded:.1f}s, infer: {time.time()-t_compiled:.1f}s)") print(f"{'='*70}") - print() - print("KERNEL VERIFICATION: PASSED") - print("FMHA produces correct, finite output at hd=512, 128 query heads") - print("across all 61 layers. Garbage output is an architecture gap") - print("(missing mHC + MoE + KV cache), not a kernel issue.") if __name__ == "__main__":