[ModelBash][DSV3] Add TRTLLM DSV3 Router GEMM kernel (6% B1 Speedup) (#34302)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
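The 6% batch-1 speedup comes from replacing the gate's generic linear projection with a dedicated router-GEMM kernel at small batch sizes. Below is a minimal latency sketch of that comparison; it assumes a CUDA build where the new op is exposed as vllm._custom_ops.dsv3_router_gemm (the import path and the timing harness are illustrative, not part of this commit).

import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops  # assumed export of the new kernel

def time_ms(fn, iters=100):
    # CUDA events measure device-side time; warm up first to exclude
    # one-time compilation and cache effects.
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Batch-1 router shapes for DSV3: hidden size 7168, 256 routed experts.
x = torch.randn(1, 7168, dtype=torch.bfloat16, device="cuda")
w = torch.randn(256, 7168, dtype=torch.bfloat16, device="cuda")

baseline = time_ms(lambda: F.linear(x, w).float())
fused = time_ms(
    lambda: ops.dsv3_router_gemm(
        hidden_states=x, router_weight=w, output_dtype=torch.float32
    )
)
print(f"F.linear: {baseline:.3f} ms/iter, dsv3_router_gemm: {fused:.3f} ms/iter")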
@@ -221,6 +221,73 @@ class DeepseekV2MLP(nn.Module):
 return x


+class DeepSeekV2Gate(ReplicatedLinear):
+    def __init__(
+        self,
+        hidden_size: int,
+        n_experts: int,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ):
+        assert quant_config is None
+        super().__init__(
+            hidden_size,
+            n_experts,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate",
+        )
+
+        # Unquantized only, so the parameter will be called "weight".
+        assert hasattr(self, "weight")
+        is_hopper_or_blackwell = current_platform.is_device_capability(
+            (9, 0)
+        ) or current_platform.is_device_capability_family(100)
+        SUPPORTED_NUM_EXPERTS = [256, 384]
+        SUPPORTED_HIDDEN_SIZES = [7168]
+
+        self.allow_dsv3_router_gemm = (
+            current_platform.is_cuda()
+            and is_hopper_or_blackwell
+            and n_experts in SUPPORTED_NUM_EXPERTS
+            and hidden_size in SUPPORTED_HIDDEN_SIZES
+        )
+
+        self._out_dtype: torch.dtype | None = None
+
+    def set_out_dtype(self, out_dtype: torch.dtype) -> None:
+        """
+        Set the out dtype for the router logits. This is needed after
+        __init__, because we must check whether the trtllm kernel is
+        selected before deciding between bf16 and fp32.
+        """
+
+        if self._out_dtype is not None:
+            raise ValueError("out_dtype has already been set")
+        else:
+            self._out_dtype = out_dtype
+
+    @property
+    def out_dtype(self) -> torch.dtype:
+        if self._out_dtype is None:
+            raise ValueError("out_dtype has not been set yet")
+        return self._out_dtype
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, None]:
+        """
+        Use the specialized GEMM at low batch sizes for DSV3 and KIMI.
+        """
+        if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
+            return ops.dsv3_router_gemm(
+                hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype
+            ), None
+        else:
+            return super().forward(x)
+
+
 class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
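For orientation, here is a hypothetical standalone use of the class added above: construct the gate, pin the logits dtype once, and let forward pick the path by batch size. The shapes are the supported DSV3 values; driving the class outside a vLLM model context like this is illustrative only.

import torch

# Hypothetical driver for DeepSeekV2Gate (assumes the vLLM model module
# that provides ReplicatedLinear, current_platform, and ops is importable).
gate = DeepSeekV2Gate(hidden_size=7168, n_experts=256)
gate.set_out_dtype(torch.bfloat16)  # a second call would raise ValueError

small = torch.randn(4, 7168, dtype=torch.bfloat16, device="cuda")
logits, _ = gate(small)   # 4 <= 16: ops.dsv3_router_gemm fast path (when allowed)

large = torch.randn(32, 7168, dtype=torch.bfloat16, device="cuda")
logits, _ = gate(large)   # 32 > 16: falls back to ReplicatedLinear.forward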
@@ -249,10 +316,9 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )

-        self.gate = ReplicatedLinear(
+        self.gate = DeepSeekV2Gate(
             config.hidden_size,
             config.n_routed_experts,
-            bias=False,
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
@@ -325,6 +391,13 @@ class DeepseekV2MoE(nn.Module):
             else None,
         )

+        # NOTE(rob): this is a hack until we finish the PR that merges
+        # the TRTLLM kernels into the MK framework. Then we can query
+        # the MonolithicMK for the expected router logits dtype.
+        self.gate.set_out_dtype(
+            torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
+        )
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
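The NOTE above pins down an ordering constraint: the gate's logits dtype depends on which MoE kernel was selected, so it can only be set after the experts are constructed. A small sketch of that handshake; is_monolithic is the flag the diff queries, while the stand-in class and helper are illustrative.

import torch

class _FakeQuantMethod:
    # Stand-in for self.experts.quant_method; only the flag matters here.
    is_monolithic = True

def router_logits_dtype(quant_method) -> torch.dtype:
    # fp32 for the monolithic TRTLLM path, bf16 otherwise (mirrors the diff).
    return torch.float32 if quant_method.is_monolithic else torch.bfloat16

assert router_logits_dtype(_FakeQuantMethod()) is torch.float32
# In the model: self.gate.set_out_dtype(router_logits_dtype(self.experts.quant_method))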