[ModelBash][DSV3] Add TRTLLM DSV3 Router GEMM kernel (6% B1 Speedup) (#34302)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Robert Shaw
2026-02-23 09:02:26 -05:00
committed by GitHub
parent b1b5e045df
commit 8435b2e049
9 changed files with 915 additions and 3 deletions

@@ -221,6 +221,73 @@ class DeepseekV2MLP(nn.Module):
         return x
 
 
+class DeepSeekV2Gate(ReplicatedLinear):
+    def __init__(
+        self,
+        hidden_size: int,
+        n_experts: int,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ):
+        assert quant_config is None
+        super().__init__(
+            hidden_size,
+            n_experts,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate",
+        )
+
+        # Unquantized only, so the parameter will be named "weight".
+        assert hasattr(self, "weight")
+
+        is_hopper_or_blackwell = current_platform.is_device_capability(
+            (9, 0)
+        ) or current_platform.is_device_capability_family(100)
+        SUPPORTED_NUM_EXPERTS = [256, 384]
+        SUPPORTED_HIDDEN_SIZES = [7168]
+        self.allow_dsv3_router_gemm = (
+            current_platform.is_cuda()
+            and is_hopper_or_blackwell
+            and n_experts in SUPPORTED_NUM_EXPERTS
+            and hidden_size in SUPPORTED_HIDDEN_SIZES
+        )
+
+        self._out_dtype: torch.dtype | None = None
+
+    def set_out_dtype(self, out_dtype: torch.dtype) -> None:
+        """
+        Set the output dtype for the router logits. This must happen
+        after __init__, because we need to check whether the TRTLLM
+        kernel is selected before deciding between bf16 and fp32.
+        """
+        if self._out_dtype is not None:
+            raise ValueError("out_dtype has already been set")
+        else:
+            self._out_dtype = out_dtype
+
+    @property
+    def out_dtype(self) -> torch.dtype:
+        if self._out_dtype is None:
+            raise ValueError("out_dtype has not been set yet")
+        return self._out_dtype
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, None]:
+        """
+        Use the specialized GEMM at low batch sizes for DSV3 and KIMI.
+        """
+        if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
+            return ops.dsv3_router_gemm(
+                hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype
+            ), None
+        else:
+            return super().forward(x)
+
+
 class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
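
For reference, the fast path above changes how the router logits are computed, not what they are: the gate is a bias-free linear layer, so for num_tokens <= 16 the fused kernel produces the same projection a plain matmul would, up to kernel-internal accumulation differences. A minimal sketch of that reference computation (the function name is ours, not part of this diff):

import torch

def reference_router_logits(
    x: torch.Tensor,              # [num_tokens, hidden_size] activations
    router_weight: torch.Tensor,  # [n_experts, hidden_size] gate weight
    out_dtype: torch.dtype,       # fp32 or bf16; see set_out_dtype above
) -> torch.Tensor:
    # Bias-free linear projection followed by a cast to the requested
    # logits dtype -- the math that ops.dsv3_router_gemm specializes.
    return (x @ router_weight.t()).to(out_dtype)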
@@ -249,10 +316,9 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
-        self.gate = ReplicatedLinear(
+        self.gate = DeepSeekV2Gate(
             config.hidden_size,
             config.n_routed_experts,
-            bias=False,
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
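
The eligibility check baked into DeepSeekV2Gate.__init__ can also be read as a standalone predicate. Below is a hypothetical helper (not in the diff) mirroring that logic, assuming the supported shapes correspond to DSV3 (256 routed experts) and Kimi-K2 (384 routed experts), both with hidden size 7168:

def dsv3_router_gemm_eligible(
    is_cuda: bool,
    device_capability: tuple[int, int],  # e.g. (9, 0) Hopper, (10, 0) Blackwell
    n_experts: int,
    hidden_size: int,
) -> bool:
    # Mirrors self.allow_dsv3_router_gemm: CUDA on Hopper or the
    # Blackwell family, and only the router shapes the kernel supports.
    is_hopper = device_capability == (9, 0)
    is_blackwell = device_capability[0] == 10
    return (
        is_cuda
        and (is_hopper or is_blackwell)
        and n_experts in (256, 384)
        and hidden_size == 7168
    )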
@@ -325,6 +391,13 @@ class DeepseekV2MoE(nn.Module):
                 else None,
             )
 
+        # NOTE(rob): this is a hack until we finish off the PR that merges
+        # the TRTLLM kernels into the MK framework. Then we can query the
+        # MonolithicMK for the expected router logits dtype.
+        self.gate.set_out_dtype(
+            torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
+        )
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
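
The set_out_dtype contract is deliberately one-shot: a second call raises instead of silently switching the logits dtype mid-run. A hedged usage sketch (shapes are illustrative; construction details such as weight loading are omitted):

gate = DeepSeekV2Gate(hidden_size=7168, n_experts=256)

# Reading before setting raises: ValueError("out_dtype has not been set yet")
gate.set_out_dtype(torch.float32)   # monolithic TRTLLM path wants fp32 logits
assert gate.out_dtype is torch.float32

# Setting a second time raises: ValueError("out_dtype has already been set")
# gate.set_out_dtype(torch.bfloat16)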