Fix dense router BF16 dispatch for CUDA graph capture
- Run the GEMM in BF16 (not FP32) during graph capture — Blackwell tensor cores handle BF16 natively; the FP32 GEMM triggers cudaErrorStreamCaptureUnsupported.
- Pre-transpose W_gate to [E, H] at load time — avoids a .T view during capture.
- Convert only the logits output to FP32 for sqrt(softplus) numerical stability.
- This fixes the graph capture failure at layer 0 Graph B.
This commit is contained in:
@@ -18,7 +18,7 @@ import torch
|
||||
|
||||
def dense_router_dispatch(
|
||||
hidden_states: torch.Tensor, # [N, hidden_size] BF16
|
||||
W_gate: torch.Tensor, # [hidden_size, num_experts] BF16
|
||||
W_gate: torch.Tensor, # [num_experts, hidden_size] BF16 (pre-transposed for F.linear)
|
||||
e_bias: torch.Tensor, # [num_experts] FP32
|
||||
routed_scaling_factor: float,
|
||||
top_k: int,
|
||||
@@ -29,8 +29,15 @@ def dense_router_dispatch(
|
||||
|
||||
BF16 GEMM via torch.nn.functional.linear (cuBLAS, SM100 tensor cores),
|
||||
then fused activation + top-k via the CUDA kernel.
|
||||
|
||||
CUDA-graph-compatible: W_gate must be pre-transposed to [E, H] BF16
|
||||
so no .T view or FP32 casts of the inputs/weights happen during capture. The GEMM runs in BF16
|
||||
(Blackwell tensor cores handle BF16 natively). Only the output logits
|
||||
are cast to FP32 for the sqrt(softplus) activation.
|
||||
"""
|
||||
logits = torch.nn.functional.linear(hidden_states.float(), W_gate.T.float())
|
||||
# BF16 GEMM — W_gate is pre-transposed to [E, H] for F.linear
|
||||
logits_bf16 = torch.nn.functional.linear(hidden_states, W_gate)
|
||||
logits = logits_bf16.float() # BF16 → FP32 for sqrt(softplus) numerical stability
|
||||
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
|
||||
run_fused_activation_topk(
|
||||
logits, e_bias, routed_scaling_factor, top_k,
|
||||
|
||||
@@ -141,7 +141,8 @@ class Router:
|
||||
f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)"
|
||||
self.e_bias = e_bias.to(device=self.device, dtype=torch.float32)
|
||||
if W_gate is not None:
|
||||
self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16)
|
||||
# Pre-transpose to [E, H] for F.linear — avoids .T during graph capture
|
||||
self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16).T.contiguous()
|
||||
# gate_lin is set separately via load_nvfp4_gate()
|
||||
else: # hash
|
||||
if hash_lut is None:
|
||||
|
||||
Reference in New Issue
Block a user