Add try/except around fused NVFP4 gate loading with error reporting
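If quantizing or loading the fused NVFP4 gate raises an exception, print the error and its traceback, then fall back to loading the BF16 gate weights (cuBLAS path) instead of crashing. This makes the actual error visible and lets testing continue.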
This commit is contained in:
@@ -713,19 +713,20 @@ def main():
|
||||
gw = gw.bfloat16().to(dev)
|
||||
# Quantize BF16 → NVFP4 for fused router kernel
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4
|
||||
gw_fp4, gw_sf, gw_gs = quantize_to_nvfp4(gw)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
# gsb (weight global scale) = gw_gs from quantization
|
||||
# gsa (activation global scale) = 1.0 (applied during activation quantization inside kernel)
|
||||
# Actually: gsa is passed to quantize_activation_nvfp4 inside run_nvfp4_fused_router
|
||||
# We need to compute the correct gsa. For NVFP4, gsa = 1/(max_val * 448)
|
||||
# But since activation is quantized at runtime, gsa = input_scale from Nvfp4Linear = 1/(6*448)
|
||||
router.load_nvfp4_fused_gate(
|
||||
gate_weight=gw_fp4,
|
||||
gate_weight_scale=gw_sf,
|
||||
gate_ws2=torch.tensor([gw_gs], device=dev, dtype=torch.float32), # gsb = weight global scale
|
||||
gate_input_scale=torch.tensor([1.0 / (6.0 * 448.0)], device=dev, dtype=torch.float32), # gsa = activation global scale
|
||||
)
|
||||
try:
|
||||
gw_fp4, gw_sf, gw_gs = quantize_to_nvfp4(gw)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.load_nvfp4_fused_gate(
|
||||
gate_weight=gw_fp4,
|
||||
gate_weight_scale=gw_sf,
|
||||
gate_ws2=torch.tensor([gw_gs], device=dev, dtype=torch.float32),
|
||||
gate_input_scale=torch.tensor([1.0 / (6.0 * 448.0)], device=dev, dtype=torch.float32),
|
||||
)
|
||||
if li < 5: print(f" L{li}: Fused NVFP4 gate OK (gs={gw_gs:.6f})", flush=True)
|
||||
except Exception as e:
|
||||
print(f" L{li}: Fused NVFP4 gate FAILED: {e}", flush=True)
|
||||
import traceback; traceback.print_exc()
|
||||
router.load_weights(W_gate=gw, e_bias=eb.to(dev, torch.float32))
|
||||
else:
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.finalize_weights(); routers[li] = router
|
||||
|
||||
Reference in New Issue
Block a user