From ffa7842b58147ce9c87c8eb127c0d5f5e92560ae Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 4 Jun 2026 04:49:08 +0000 Subject: [PATCH] Fix dense router: run GEMM in BF16, convert to FP32 only for activation hidden_states.float() and gate_bf16.T.float() create new FP32 tensors during CUDA graph capture, which is not graph-capturable. Fix: run the linear in BF16 (Blackwell tensor cores handle BF16 natively), then convert only the output logits to FP32 for numerical stability in sqrt(softplus). The single logits.float() is graph-capturable because it's a unary op with a pre-existing output buffer. --- dsv4/kernels/router/dense_router_decode.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dsv4/kernels/router/dense_router_decode.py b/dsv4/kernels/router/dense_router_decode.py index df840930..b9a6f80a 100644 --- a/dsv4/kernels/router/dense_router_decode.py +++ b/dsv4/kernels/router/dense_router_decode.py @@ -97,7 +97,8 @@ def dense_router_dispatch_nvfp4_fused( # Decode the gate_weight from NVFP4 to BF16 for cuBLAS from dsv4.ops.quantize import dequantize_nvfp4 gate_bf16 = dequantize_nvfp4(gate_weight, gate_weight_scale, gate_ws2) - logits = torch.nn.functional.linear(hidden_states.float(), gate_bf16.T.float()) + logits = torch.nn.functional.linear(hidden_states, gate_bf16.T) + logits = logits.float() # BF16 → FP32 for numerical stability in sqrt(softplus) run_fused_activation_topk( logits, e_bias, routed_scaling_factor, top_k,