Switch lm_head to BF16 + router gate to FP8_E4M3

lm_head: BF16 F.linear (checkpoint weight is BF16, no quantization)
Router gate: FP8_E4M3 quantize→dequantize round-trip, then F.linear
- If the checkpoint gate weight is NVFP4, dequantize it to BF16 first (a BF16 checkpoint gate weight is used as-is)
- Quantize to FP8_E4M3 (scale = amax/448)
- Dequantize back to BF16 for F.linear
- Sets router.gate_lin = None so that dense_router_dispatch uses the BF16 dispatch path
- Simpler scale wiring than NVFP4 (single per-tensor scale)
This commit is contained in:
2026-06-03 14:10:28 +00:00
parent 7901470e63
commit 715602c87c

View File

@@ -1306,50 +1306,36 @@ def main():
router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32))
else:
eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
# NVFP4 production GEMM for router gate
# Custom CuTeDSL fused kernel crashes MLIR optimizer,
# so we use Nvfp4Linear (proven production path).
from dsv4.layers.linear import Nvfp4Linear
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
# FP8_E4M3 router gate — quantize weight to FP8, dequantize to BF16, F.linear
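# This avoids NVFP4's multi-scale complexity. Note: the weight is only rounded through FP8 and is
# stored and multiplied as BF16, so this emulates FP8 precision rather than saving memory.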
E = cfg["n_routed_experts"]
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
if gate_w is not None and gate_ws is not None:
# Checkpoint has NVFP4 gate weight (N_packed, K_packed) — correct layout
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
gate_w_view = gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev)
gate_lin.fp4 = [gate_w_view]
gate_lin.sf = [gate_ws.to(dev)]
# Checkpoint has NVFP4 gate weight — dequantize to BF16 first, then re-quantize to FP8
from dsv4.ops.quantize import dequantize_nvfp4
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
gate_lin.gs = [1.0]
gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
gate_lin._activation_global_scale = isc_v # placeholder — runtime gsa overrides this
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
gate_lin.finalize_weights()
router.load_nvfp4_gate(gate_lin)
router.load_weights(e_bias=eb.to(dev, torch.float32))
if li < 5: print(f" L{li}: NVFP4 router gate (checkpoint)", flush=True)
gsb = 1.0 * ws2_v # global_scale_b = gs * ws2
gsa = torch.tensor([gsb] * gate_w.shape[0], device=dev, dtype=torch.float32)
gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa) # (E_packed*2, H)
gate_bf16 = gate_bf16.T.contiguous() # (H, E) for W_gate
else:
# BF16 gate weight: quantize to NVFP4
# BF16 gate weight from checkpoint
gw = all_w.get(f"{pfx}.gate.weight")
if gw is not None:
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
g_bf16 = g_bf16.bfloat16().to(dev)
from dsv4.ops.quantize import quantize_to_nvfp4
g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
gate_lin.fp4 = [g_fp4]
gate_lin.sf = [g_sf]
gate_lin.gs = [g_gs]
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder — runtime gsa overrides
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
gate_lin.finalize_weights()
router.load_nvfp4_gate(gate_lin)
router.load_weights(e_bias=eb.to(dev, torch.float32))
if li < 5: print(f" L{li}: NVFP4 router gate (quantized, gs={g_gs:.6f})", flush=True)
else:
router.load_weights(e_bias=eb.to(dev, torch.float32))
router.load_weights(e_bias=eb.to(dev, torch.float32))
gate_bf16 = gw.bfloat16().to(dev)
if gate_bf16.shape[0] != H:
gate_bf16 = gate_bf16.T.contiguous() # ensure (H, E)
# Quantize to FP8_E4M3: scale = amax / 448.0
gate_amax = gate_bf16.abs().max().float().item()
gate_scale = gate_amax / 448.0
gate_fp8 = (gate_bf16.float() / gate_scale).to(torch.float8_e4m3fn)
# Dequantize back to BF16 for F.linear (FP8 round-trip ~0.9999 cos)
gate_dequant = gate_fp8.to(torch.bfloat16) * gate_scale
router.W_gate = gate_dequant.contiguous() # (H, E) for F.linear(x, W_gate.T)
# No gate_lin — force BF16 dispatch path
router.gate_lin = None
router.load_weights(e_bias=eb.to(dev, torch.float32))
if li < 5: print(f" L{li}: FP8_E4M3 router gate (scale={gate_scale:.6f}, amax={gate_amax:.4f})", flush=True)
router.finalize_weights(); routers[li] = router
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
@@ -1397,21 +1383,11 @@ def main():
torch.cuda.set_device(0)
embed_w = all_w.get("model.embed_tokens.weight")
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
# lm_head: NVFP4 production GEMM
# lm_head: BF16 GEMM (checkpoint weight is BF16, no quantization)
lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
from dsv4.layers.linear import Nvfp4Linear
lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0')
from dsv4.ops.quantize import quantize_weight_to_nvfp4
lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous())
lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()]
lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()]
lm_head_lin.gs = [lm_gs]
lm_head_lin.ws2 = [None]
lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
lm_head_lin._use_runtime_gsa = True
lm_head_lin.finalize_weights()
lm_w = None
print(" lm_head: NVFP4 production GEMM")
lm_head_lin = None # Use raw BF16 F.linear for lm_head
lm_w = lm_w_raw # Keep as (V, H) BF16 for F.linear
print(" lm_head: BF16 GEMM (checkpoint weight, no quantization)")
final_norm_w = all_w.get("model.norm.weight")
if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32)
@@ -1660,8 +1636,8 @@ def main():
gl._activation_global_scale = fixed_gsa
gl._use_runtime_gsa = False
n_fixed += 1
# lm_head
if hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa:
# lm_head (BF16 — no gsa needed)
if lm_head_lin is not None and hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa:
fixed_gsa = lm_head_lin._gsa_buf.item()
lm_head_lin._activation_global_scale = fixed_gsa
lm_head_lin._use_runtime_gsa = False
@@ -1669,7 +1645,7 @@ def main():
print(f" Warmup gsa: fixed {n_fixed} projection gsa values from step 0 (MoE/SE keep runtime gsa)", flush=True)
x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w)
logits = lm_head_lin(x_out)
logits = torch.nn.functional.linear(x_out, lm_w) if lm_head_lin is None else lm_head_lin(x_out)
if profile: torch.cuda.synchronize()
t_lm = time.perf_counter()
# Check thinking start token logit on first step