From e46b61587383951bf9971636d2017c68ee2df6fc Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 4 Jun 2026 05:50:13 +0000 Subject: [PATCH] Fix dense router BF16 dispatch for CUDA graph capture MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Run GEMM in BF16 (not FP32) during graph capture — Blackwell tensor cores handle BF16 natively; FP32 GEMM triggers cudaErrorStreamCaptureUnsupported - Pre-transpose W_gate to [E, H] at load time — avoids .T view during capture - Convert only logits output to FP32 for sqrt(softplus) numerical stability - This fixes the graph capture failure at layer 0 Graph B --- dsv4/kernels/router/dense_router_decode.py | 11 +++++++++-- dsv4/layers/router.py | 3 ++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/dsv4/kernels/router/dense_router_decode.py b/dsv4/kernels/router/dense_router_decode.py index b9a6f80a..d4ae5813 100644 --- a/dsv4/kernels/router/dense_router_decode.py +++ b/dsv4/kernels/router/dense_router_decode.py @@ -18,7 +18,7 @@ import torch def dense_router_dispatch( hidden_states: torch.Tensor, # [N, hidden_size] BF16 - W_gate: torch.Tensor, # [hidden_size, num_experts] BF16 + W_gate: torch.Tensor, # [num_experts, hidden_size] BF16 (pre-transposed for F.linear) e_bias: torch.Tensor, # [num_experts] FP32 routed_scaling_factor: float, top_k: int, @@ -29,8 +29,15 @@ def dense_router_dispatch( BF16 GEMM via torch.nn.functional.linear (cuBLAS, SM100 tensor cores), then fused activation + top-k via the CUDA kernel. + + CUDA-graph-compatible: W_gate must be pre-transposed to [E, H] BF16 + so no .T or .float() calls happen during capture. The GEMM runs in BF16 + (Blackwell tensor cores handle BF16 natively). Only the output logits + are cast to FP32 for the sqrt(softplus) activation. """ - logits = torch.nn.functional.linear(hidden_states.float(), W_gate.T.float()) + # BF16 GEMM — W_gate is pre-transposed to [E, H] for F.linear + logits_bf16 = torch.nn.functional.linear(hidden_states, W_gate) + logits = logits_bf16.float() # BF16 → FP32 for sqrt(softplus) numerical stability from dsv4.kernels.router._activation_topk import run_fused_activation_topk run_fused_activation_topk( logits, e_bias, routed_scaling_factor, top_k, diff --git a/dsv4/layers/router.py b/dsv4/layers/router.py index fbdf3db6..798b3c05 100644 --- a/dsv4/layers/router.py +++ b/dsv4/layers/router.py @@ -141,7 +141,8 @@ class Router: f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)" self.e_bias = e_bias.to(device=self.device, dtype=torch.float32) if W_gate is not None: - self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16) + # Pre-transpose to [E, H] for F.linear — avoids .T during graph capture + self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16).T.contiguous() # gate_lin is set separately via load_nvfp4_gate() else: # hash if hash_lut is None: