Fix dense router BF16 dispatch for CUDA graph capture

- Run GEMM in BF16 (not FP32) during graph capture — Blackwell tensor cores
  handle BF16 natively; FP32 GEMM triggers cudaErrorStreamCaptureUnsupported
- Pre-transpose W_gate to [E, H] at load time — avoids .T view during capture
- Convert only logits output to FP32 for sqrt(softplus) numerical stability
- This fixes the graph capture failure at layer 0 Graph B
This commit is contained in:
2026-06-04 05:50:13 +00:00
parent b4a59d0940
commit e46b615873
2 changed files with 11 additions and 3 deletions

View File

@@ -18,7 +18,7 @@ import torch
def dense_router_dispatch(
hidden_states: torch.Tensor, # [N, hidden_size] BF16
W_gate: torch.Tensor, # [hidden_size, num_experts] BF16
W_gate: torch.Tensor, # [num_experts, hidden_size] BF16 (pre-transposed for F.linear)
e_bias: torch.Tensor, # [num_experts] FP32
routed_scaling_factor: float,
top_k: int,
@@ -29,8 +29,15 @@ def dense_router_dispatch(
BF16 GEMM via torch.nn.functional.linear (cuBLAS, SM100 tensor cores),
then fused activation + top-k via the CUDA kernel.
CUDA-graph-compatible: W_gate must be pre-transposed to [E, H] BF16
so no .T or .float() calls happen during capture. The GEMM runs in BF16
(Blackwell tensor cores handle BF16 natively). Only the output logits
are cast to FP32 for the sqrt(softplus) activation.
"""
logits = torch.nn.functional.linear(hidden_states.float(), W_gate.T.float())
# BF16 GEMM — W_gate is pre-transposed to [E, H] for F.linear
logits_bf16 = torch.nn.functional.linear(hidden_states, W_gate)
logits = logits_bf16.float() # BF16 → FP32 for sqrt(softplus) numerical stability
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
run_fused_activation_topk(
logits, e_bias, routed_scaling_factor, top_k,

View File

@@ -141,7 +141,8 @@ class Router:
f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)"
self.e_bias = e_bias.to(device=self.device, dtype=torch.float32)
if W_gate is not None:
self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16)
# Pre-transpose to [E, H] for F.linear — avoids .T during graph capture
self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16).T.contiguous()
# gate_lin is set separately via load_nvfp4_gate()
else: # hash
if hash_lut is None: