Switch lm_head to BF16 + router gate to FP8_E4M3
lm_head: BF16 F.linear (checkpoint weight is BF16, no quantization).
Router gate: FP8_E4M3 quantize→dequantize round-trip, then F.linear:
- Dequantize NVFP4 checkpoint weights to BF16 first
- Quantize to FP8_E4M3 (scale = amax/448)
- Dequantize back to BF16 for F.linear
- Uses BF16 dispatch path in dense_router_dispatch
- Simpler scale wiring than NVFP4 (single per-tensor scale)
This commit is contained in:
@@ -1306,50 +1306,36 @@ def main():
|
||||
router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32))
|
||||
else:
|
||||
eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
|
||||
# NVFP4 production GEMM for router gate
|
||||
# Custom CuTeDSL fused kernel crashes MLIR optimizer,
|
||||
# so we use Nvfp4Linear (proven production path).
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
|
||||
# FP8_E4M3 router gate — quantize weight to FP8, dequantize to BF16, F.linear
|
||||
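# This avoids NVFP4's multi-scale complexity. Note: the weight is dequantized back to BF16 right after the FP8 round-trip, so it carries FP8_E4M3 precision but is stored and multiplied as BF16.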
|
||||
E = cfg["n_routed_experts"]
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
|
||||
if gate_w is not None and gate_ws is not None:
|
||||
# Checkpoint has NVFP4 gate weight (N_packed, K_packed) — correct layout
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_w_view = gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev)
|
||||
gate_lin.fp4 = [gate_w_view]
|
||||
gate_lin.sf = [gate_ws.to(dev)]
|
||||
# Checkpoint has NVFP4 gate weight — dequantize to BF16 first, then re-quantize to FP8
|
||||
from dsv4.ops.quantize import dequantize_nvfp4
|
||||
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
|
||||
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
|
||||
gate_lin.gs = [1.0]
|
||||
gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = isc_v # placeholder — runtime gsa overrides this
|
||||
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
if li < 5: print(f" L{li}: NVFP4 router gate (checkpoint)", flush=True)
|
||||
gsb = 1.0 * ws2_v # global_scale_b = gs * ws2
|
||||
gsa = torch.tensor([gsb] * gate_w.shape[0], device=dev, dtype=torch.float32)
|
||||
gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa) # (E_packed*2, H)
|
||||
gate_bf16 = gate_bf16.T.contiguous() # (H, E) for W_gate
|
||||
else:
|
||||
# BF16 gate weight: quantize to NVFP4
|
||||
# BF16 gate weight from checkpoint
|
||||
gw = all_w.get(f"{pfx}.gate.weight")
|
||||
if gw is not None:
|
||||
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
|
||||
g_bf16 = g_bf16.bfloat16().to(dev)
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4
|
||||
g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [g_fp4]
|
||||
gate_lin.sf = [g_sf]
|
||||
gate_lin.gs = [g_gs]
|
||||
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder — runtime gsa overrides
|
||||
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
if li < 5: print(f" L{li}: NVFP4 router gate (quantized, gs={g_gs:.6f})", flush=True)
|
||||
else:
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
gate_bf16 = gw.bfloat16().to(dev)
|
||||
if gate_bf16.shape[0] != H:
|
||||
gate_bf16 = gate_bf16.T.contiguous() # ensure (H, E)
|
||||
# Quantize to FP8_E4M3: scale = amax / 448.0
|
||||
gate_amax = gate_bf16.abs().max().float().item()
|
||||
gate_scale = gate_amax / 448.0
|
||||
gate_fp8 = (gate_bf16.float() / gate_scale).to(torch.float8_e4m3fn)
|
||||
# Dequantize back to BF16 for F.linear (FP8 round-trip ~0.9999 cos)
|
||||
gate_dequant = gate_fp8.to(torch.bfloat16) * gate_scale
|
||||
router.W_gate = gate_dequant.contiguous() # (H, E) for F.linear(x, W_gate.T)
|
||||
# No gate_lin — force BF16 dispatch path
|
||||
router.gate_lin = None
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
if li < 5: print(f" L{li}: FP8_E4M3 router gate (scale={gate_scale:.6f}, amax={gate_amax:.4f})", flush=True)
|
||||
router.finalize_weights(); routers[li] = router
|
||||
|
||||
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
|
||||
@@ -1397,21 +1383,11 @@ def main():
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = all_w.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
|
||||
# lm_head: NVFP4 production GEMM
|
||||
# lm_head: BF16 GEMM (checkpoint weight is BF16, no quantization)
|
||||
lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0')
|
||||
from dsv4.ops.quantize import quantize_weight_to_nvfp4
|
||||
lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous())
|
||||
lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()]
|
||||
lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()]
|
||||
lm_head_lin.gs = [lm_gs]
|
||||
lm_head_lin.ws2 = [None]
|
||||
lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
lm_head_lin._use_runtime_gsa = True
|
||||
lm_head_lin.finalize_weights()
|
||||
lm_w = None
|
||||
print(" lm_head: NVFP4 production GEMM")
|
||||
lm_head_lin = None # Use raw BF16 F.linear for lm_head
|
||||
lm_w = lm_w_raw # Keep as (V, H) BF16 for F.linear
|
||||
print(" lm_head: BF16 GEMM (checkpoint weight, no quantization)")
|
||||
final_norm_w = all_w.get("model.norm.weight")
|
||||
if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32)
|
||||
|
||||
@@ -1660,8 +1636,8 @@ def main():
|
||||
gl._activation_global_scale = fixed_gsa
|
||||
gl._use_runtime_gsa = False
|
||||
n_fixed += 1
|
||||
# lm_head
|
||||
if hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa:
|
||||
# lm_head (BF16 — no gsa needed)
|
||||
if lm_head_lin is not None and hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa:
|
||||
fixed_gsa = lm_head_lin._gsa_buf.item()
|
||||
lm_head_lin._activation_global_scale = fixed_gsa
|
||||
lm_head_lin._use_runtime_gsa = False
|
||||
@@ -1669,7 +1645,7 @@ def main():
|
||||
print(f" Warmup gsa: fixed {n_fixed} projection gsa values from step 0 (MoE/SE keep runtime gsa)", flush=True)
|
||||
x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
|
||||
if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w)
|
||||
logits = lm_head_lin(x_out)
|
||||
logits = torch.nn.functional.linear(x_out, lm_w) if lm_head_lin is None else lm_head_lin(x_out)
|
||||
if profile: torch.cuda.synchronize()
|
||||
t_lm = time.perf_counter()
|
||||
# Check thinking start token logit on first step
|
||||
|
||||
Reference in New Issue
Block a user