From 715602c87c19a6908e5f7bc474f0309f115ba95c Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 14:10:28 +0000 Subject: [PATCH] Switch lm_head to BF16 + router gate to FP8_E4M3 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit lm_head: BF16 F.linear (checkpoint weight is BF16, no quantization) Router gate: FP8_E4M3 quantize→dequantize round-trip, then F.linear - Dequantize NVFP4 checkpoint weights to BF16 first - Quantize to FP8_E4M3 (scale = amax/448) - Dequantize back to BF16 for F.linear - Uses BF16 dispatch path in dense_router_dispatch - Simpler scale wiring than NVFP4 (single per-tensor scale) --- single_shot_inference.py | 86 +++++++++++++++------------------------- 1 file changed, 31 insertions(+), 55 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 369b1d50..31f2f23c 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1306,50 +1306,36 @@ def main(): router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32)) else: eb = all_w.get(f"{pfx}.gate.e_score_correction_bias") - # NVFP4 production GEMM for router gate - # Custom CuTeDSL fused kernel crashes MLIR optimizer, - # so we use Nvfp4Linear (proven production path). - from dsv4.layers.linear import Nvfp4Linear - gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate') + # FP8_E4M3 router gate — quantize weight to FP8, dequantize to BF16, F.linear + # This avoids NVFP4's multi-scale complexity while still using FP8 compression. E = cfg["n_routed_experts"] + gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate') if gate_w is not None and gate_ws is not None: - # Checkpoint has NVFP4 gate weight (N_packed, K_packed) — correct layout - gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev) - gate_w_view = gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev) - gate_lin.fp4 = [gate_w_view] - gate_lin.sf = [gate_ws.to(dev)] + # Checkpoint has NVFP4 gate weight — dequantize to BF16 first, then re-quantize to FP8 + from dsv4.ops.quantize import dequantize_nvfp4 ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0 isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0) - gate_lin.gs = [1.0] - gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)] - gate_lin._activation_global_scale = isc_v # placeholder — runtime gsa overrides this - gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow - gate_lin.finalize_weights() - router.load_nvfp4_gate(gate_lin) - router.load_weights(e_bias=eb.to(dev, torch.float32)) - if li < 5: print(f" L{li}: NVFP4 router gate (checkpoint)", flush=True) + gsb = 1.0 * ws2_v # global_scale_b = gs * ws2 + gsa = torch.tensor([gsb] * gate_w.shape[0], device=dev, dtype=torch.float32) + gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa) # (E_packed*2, H) + gate_bf16 = gate_bf16.T.contiguous() # (H, E) for W_gate else: - # BF16 gate weight: quantize to NVFP4 + # BF16 gate weight from checkpoint gw = all_w.get(f"{pfx}.gate.weight") - if gw is not None: - g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous() - g_bf16 = g_bf16.bfloat16().to(dev) - from dsv4.ops.quantize import quantize_to_nvfp4 - g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16) - gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev) - gate_lin.fp4 = [g_fp4] - gate_lin.sf = [g_sf] - gate_lin.gs = [g_gs] - gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)] - gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder — runtime gsa overrides - gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow - gate_lin.finalize_weights() - router.load_nvfp4_gate(gate_lin) - router.load_weights(e_bias=eb.to(dev, torch.float32)) - if li < 5: print(f" L{li}: NVFP4 router gate (quantized, gs={g_gs:.6f})", flush=True) - else: - router.load_weights(e_bias=eb.to(dev, torch.float32)) - router.load_weights(e_bias=eb.to(dev, torch.float32)) + gate_bf16 = gw.bfloat16().to(dev) + if gate_bf16.shape[0] != H: + gate_bf16 = gate_bf16.T.contiguous() # ensure (H, E) + # Quantize to FP8_E4M3: scale = amax / 448.0 + gate_amax = gate_bf16.abs().max().float().item() + gate_scale = gate_amax / 448.0 + gate_fp8 = (gate_bf16.float() / gate_scale).to(torch.float8_e4m3fn) + # Dequantize back to BF16 for F.linear (FP8 round-trip ~0.9999 cos) + gate_dequant = gate_fp8.to(torch.bfloat16) * gate_scale + router.W_gate = gate_dequant.contiguous() # (H, E) for F.linear(x, W_gate.T) + # No gate_lin — force BF16 dispatch path + router.gate_lin = None + router.load_weights(e_bias=eb.to(dev, torch.float32)) + if li < 5: print(f" L{li}: FP8_E4M3 router gate (scale={gate_scale:.6f}, amax={gate_amax:.4f})", flush=True) router.finalize_weights(); routers[li] = router moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H, @@ -1397,21 +1383,11 @@ def main(): torch.cuda.set_device(0) embed_w = all_w.get("model.embed_tokens.weight") embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0')) - # lm_head: NVFP4 production GEMM + # lm_head: BF16 GEMM (checkpoint weight is BF16, no quantization) lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0') - from dsv4.layers.linear import Nvfp4Linear - lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0') - from dsv4.ops.quantize import quantize_weight_to_nvfp4 - lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous()) - lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()] - lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()] - lm_head_lin.gs = [lm_gs] - lm_head_lin.ws2 = [None] - lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0) - lm_head_lin._use_runtime_gsa = True - lm_head_lin.finalize_weights() - lm_w = None - print(" lm_head: NVFP4 production GEMM") + lm_head_lin = None # Use raw BF16 F.linear for lm_head + lm_w = lm_w_raw # Keep as (V, H) BF16 for F.linear + print(" lm_head: BF16 GEMM (checkpoint weight, no quantization)") final_norm_w = all_w.get("model.norm.weight") if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32) @@ -1660,8 +1636,8 @@ def main(): gl._activation_global_scale = fixed_gsa gl._use_runtime_gsa = False n_fixed += 1 - # lm_head - if hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa: + # lm_head (BF16 — no gsa needed) + if lm_head_lin is not None and hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa: fixed_gsa = lm_head_lin._gsa_buf.item() lm_head_lin._activation_global_scale = fixed_gsa lm_head_lin._use_runtime_gsa = False @@ -1669,7 +1645,7 @@ def main(): print(f" Warmup gsa: fixed {n_fixed} projection gsa values from step 0 (MoE/SE keep runtime gsa)", flush=True) x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :] if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w) - logits = lm_head_lin(x_out) + logits = torch.nn.functional.linear(x_out, lm_w) if lm_head_lin is None else lm_head_lin(x_out) if profile: torch.cuda.synchronize() t_lm = time.perf_counter() # Check thinking start token logit on first step