Add try/except around fused NVFP4 gate loading with error reporting

If the fused kernel path fails, fall back to BF16 cuBLAS instead of
crashing. This lets us see the actual error and continue testing.
This commit is contained in:
2026-06-01 11:08:06 +00:00
parent 5f38430423
commit fbc1e883f2

View File

@@ -713,19 +713,20 @@ def main():
gw = gw.bfloat16().to(dev)
# Quantize BF16 → NVFP4 for fused router kernel
from dsv4.ops.quantize import quantize_to_nvfp4
gw_fp4, gw_sf, gw_gs = quantize_to_nvfp4(gw)
router.load_weights(e_bias=eb.to(dev, torch.float32))
# gsb (weight global scale) = gw_gs from quantization
# gsa (activation global scale) = 1.0 (applied during activation quantization inside kernel)
# Actually: gsa is passed to quantize_activation_nvfp4 inside run_nvfp4_fused_router
# We need to compute the correct gsa. For NVFP4, gsa = 1/(max_val * 448)
# But since activation is quantized at runtime, gsa = input_scale from Nvfp4Linear = 1/(6*448)
router.load_nvfp4_fused_gate(
gate_weight=gw_fp4,
gate_weight_scale=gw_sf,
gate_ws2=torch.tensor([gw_gs], device=dev, dtype=torch.float32), # gsb = weight global scale
gate_input_scale=torch.tensor([1.0 / (6.0 * 448.0)], device=dev, dtype=torch.float32), # gsa = activation global scale
)
try:
gw_fp4, gw_sf, gw_gs = quantize_to_nvfp4(gw)
router.load_weights(e_bias=eb.to(dev, torch.float32))
router.load_nvfp4_fused_gate(
gate_weight=gw_fp4,
gate_weight_scale=gw_sf,
gate_ws2=torch.tensor([gw_gs], device=dev, dtype=torch.float32),
gate_input_scale=torch.tensor([1.0 / (6.0 * 448.0)], device=dev, dtype=torch.float32),
)
if li < 5: print(f" L{li}: Fused NVFP4 gate OK (gs={gw_gs:.6f})", flush=True)
except Exception as e:
print(f" L{li}: Fused NVFP4 gate FAILED: {e}", flush=True)
import traceback; traceback.print_exc()
router.load_weights(W_gate=gw, e_bias=eb.to(dev, torch.float32))
else:
router.load_weights(e_bias=eb.to(dev, torch.float32))
router.finalize_weights(); routers[li] = router