diff --git a/single_shot_inference.py b/single_shot_inference.py index fc19ce4d..86e25217 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -713,19 +713,20 @@ def main(): gw = gw.bfloat16().to(dev) # Quantize BF16 → NVFP4 for fused router kernel from dsv4.ops.quantize import quantize_to_nvfp4 - gw_fp4, gw_sf, gw_gs = quantize_to_nvfp4(gw) - router.load_weights(e_bias=eb.to(dev, torch.float32)) - # gsb (weight global scale) = gw_gs from quantization - # gsa (activation global scale) = 1.0 (applied during activation quantization inside kernel) - # Actually: gsa is passed to quantize_activation_nvfp4 inside run_nvfp4_fused_router - # We need to compute the correct gsa. For NVFP4, gsa = 1/(max_val * 448) - # But since activation is quantized at runtime, gsa = input_scale from Nvfp4Linear = 1/(6*448) - router.load_nvfp4_fused_gate( - gate_weight=gw_fp4, - gate_weight_scale=gw_sf, - gate_ws2=torch.tensor([gw_gs], device=dev, dtype=torch.float32), # gsb = weight global scale - gate_input_scale=torch.tensor([1.0 / (6.0 * 448.0)], device=dev, dtype=torch.float32), # gsa = activation global scale - ) + try: + gw_fp4, gw_sf, gw_gs = quantize_to_nvfp4(gw) + router.load_weights(e_bias=eb.to(dev, torch.float32)) + router.load_nvfp4_fused_gate( + gate_weight=gw_fp4, + gate_weight_scale=gw_sf, + gate_ws2=torch.tensor([gw_gs], device=dev, dtype=torch.float32), + gate_input_scale=torch.tensor([1.0 / (6.0 * 448.0)], device=dev, dtype=torch.float32), + ) + if li < 5: print(f" L{li}: Fused NVFP4 gate OK (gs={gw_gs:.6f})", flush=True) + except Exception as e: + print(f" L{li}: Fused NVFP4 gate FAILED: {e}", flush=True) + import traceback; traceback.print_exc() + router.load_weights(W_gate=gw, e_bias=eb.to(dev, torch.float32)) else: router.load_weights(e_bias=eb.to(dev, torch.float32)) router.finalize_weights(); routers[li] = router