diff --git a/dsv4/kernels/router/dense_router_decode.py b/dsv4/kernels/router/dense_router_decode.py index b9a6f80a..d4ae5813 100644 --- a/dsv4/kernels/router/dense_router_decode.py +++ b/dsv4/kernels/router/dense_router_decode.py @@ -18,7 +18,7 @@ import torch def dense_router_dispatch( hidden_states: torch.Tensor, # [N, hidden_size] BF16 - W_gate: torch.Tensor, # [hidden_size, num_experts] BF16 + W_gate: torch.Tensor, # [num_experts, hidden_size] BF16 (pre-transposed for F.linear) e_bias: torch.Tensor, # [num_experts] FP32 routed_scaling_factor: float, top_k: int, @@ -29,8 +29,15 @@ def dense_router_dispatch( BF16 GEMM via torch.nn.functional.linear (cuBLAS, SM100 tensor cores), then fused activation + top-k via the CUDA kernel. + + CUDA-graph-compatible: W_gate must be pre-transposed to [E, H] BF16 + so no .T or .float() calls happen during capture. The GEMM runs in BF16 + (Blackwell tensor cores handle BF16 natively). Only the output logits + are cast to FP32 for the sqrt(softplus) activation. """ - logits = torch.nn.functional.linear(hidden_states.float(), W_gate.T.float()) + # BF16 GEMM — W_gate is pre-transposed to [E, H] for F.linear + logits_bf16 = torch.nn.functional.linear(hidden_states, W_gate) + logits = logits_bf16.float() # BF16 → FP32 for sqrt(softplus) numerical stability from dsv4.kernels.router._activation_topk import run_fused_activation_topk run_fused_activation_topk( logits, e_bias, routed_scaling_factor, top_k, diff --git a/dsv4/layers/router.py b/dsv4/layers/router.py index fbdf3db6..798b3c05 100644 --- a/dsv4/layers/router.py +++ b/dsv4/layers/router.py @@ -141,7 +141,8 @@ class Router: f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)" self.e_bias = e_bias.to(device=self.device, dtype=torch.float32) if W_gate is not None: - self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16) + # Pre-transpose to [E, H] for F.linear — avoids .T during graph capture + self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16).T.contiguous() # gate_lin is set separately via load_nvfp4_gate() else: # hash if hash_lut is None: