Fix dense router: run GEMM in BF16, convert to FP32 only for activation

hidden_states.float() and gate_bf16.T.float() create new FP32 tensors
during CUDA graph capture, which is not graph-capturable.

Fix: run the linear in BF16 (Blackwell tensor cores handle BF16 natively),
then convert only the output logits to FP32 for numerical stability
in sqrt(softplus). The single logits.float() is graph-capturable
because it's a unary op with a pre-existing output buffer.
This commit is contained in:
2026-06-04 04:49:08 +00:00
parent 119e6d471e
commit ffa7842b58

View File

@@ -97,7 +97,8 @@ def dense_router_dispatch_nvfp4_fused(
# Decode the gate_weight from NVFP4 to BF16 for cuBLAS
from dsv4.ops.quantize import dequantize_nvfp4
gate_bf16 = dequantize_nvfp4(gate_weight, gate_weight_scale, gate_ws2)
logits = torch.nn.functional.linear(hidden_states.float(), gate_bf16.T.float())
logits = torch.nn.functional.linear(hidden_states, gate_bf16.T)
logits = logits.float() # BF16 → FP32 for numerical stability in sqrt(softplus)
run_fused_activation_topk(
logits, e_bias, routed_scaling_factor, top_k,