Fix dense router: run GEMM in BF16, convert to FP32 only for activation
hidden_states.float() and gate_bf16.T.float() each create a new FP32 tensor during CUDA graph capture, which is not graph-capturable.

Fix: run the linear layer in BF16 (Blackwell tensor cores handle BF16 natively), then convert only the output logits to FP32 for numerical stability in sqrt(softplus). The single logits.float() call is graph-capturable because it is a unary op with a pre-existing output buffer.
@@ -97,7 +97,8 @@ def dense_router_dispatch_nvfp4_fused(
     # Decode the gate_weight from NVFP4 to BF16 for cuBLAS
     from dsv4.ops.quantize import dequantize_nvfp4
     gate_bf16 = dequantize_nvfp4(gate_weight, gate_weight_scale, gate_ws2)
-    logits = torch.nn.functional.linear(hidden_states.float(), gate_bf16.T.float())
+    logits = torch.nn.functional.linear(hidden_states, gate_bf16.T)
+    logits = logits.float()  # BF16 → FP32 for numerical stability in sqrt(softplus)
 
     run_fused_activation_topk(
         logits, e_bias, routed_scaling_factor, top_k,
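
Illustration (not part of the commit): a rough, pure-Python sketch of why the activation is worth running in FP32 even though the GEMM output is BF16. It emulates BF16 and FP32 rounding with the standard library and compares sqrt(softplus(x)) when every intermediate is rounded to BF16 against the same computation rounded to FP32. The helper names (round_to_bf16, round_to_fp32, max_rel_error) and the logit values are hypothetical, and the real run_fused_activation_topk kernel is not reproduced here.

```python
import math
import struct


def round_to_bf16(x: float) -> float:
    """Round x to the nearest BF16 value (ties to even), via its FP32 bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the low 16 bits
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def round_to_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def softplus(x: float) -> float:
    """Overflow-safe softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sqrt_softplus(x: float, rnd) -> float:
    """sqrt(softplus(x)), rounding every intermediate with rnd."""
    return rnd(math.sqrt(rnd(softplus(x))))


def max_rel_error(logits, rnd) -> float:
    """Worst relative error against an unrounded FP64 reference."""
    worst = 0.0
    for x in logits:
        ref = math.sqrt(softplus(x))
        worst = max(worst, abs(sqrt_softplus(x, rnd) - ref) / ref)
    return worst


# Hypothetical logits; every value is exactly representable in BF16,
# as the output of a BF16 GEMM would be.
logits = [-4.0, -1.5, 0.0, 0.5, 2.25, 5.0]

err_bf16 = max_rel_error(logits, round_to_bf16)
err_fp32 = max_rel_error(logits, round_to_fp32)
print(f"max relative error, BF16 activation: {err_bf16:.2e}")
print(f"max relative error, FP32 activation: {err_fp32:.2e}")
```

BF16 keeps 8 significant bits, so each rounding step can cost up to about 0.2% relative error, while an FP32 rounding step costs at most about 6e-8. Casting the logits once, after the GEMM, keeps the matrix multiply on the BF16 path and confines the higher-precision work to the small logits tensor.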