DP/EP Support for gpt-oss with deepep-ht comm kernel on SM100 (#23608)

This commit is contained in:
Yongye Zhu
2025-08-27 17:33:21 -04:00
committed by GitHub
parent 853c371fc3
commit 082cc07ef8
11 changed files with 365 additions and 12 deletions

View File

@@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return the appropriate GEMM experts implementation."""
experts = select_nvfp4_gemm_impl(
@@ -719,10 +720,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
dtype=torch.int64)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import (

View File

@@ -897,6 +897,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

View File

@@ -311,6 +311,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
experts = select_cutlass_fp8_gemm_impl(
moe,
@@ -1032,6 +1033,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
experts = select_nvfp4_gemm_impl(
moe,

View File

@@ -10,6 +10,8 @@ from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -445,6 +447,91 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return tile_tokens_dim
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
if (prepare_finalize.activation_format ==
mk.FusedMoEActivationFormat.BatchedExperts):
raise NotImplementedError(
"Mxfp4 does not support batched experts format for EP")
else:
if should_use_flashinfer_mxfp4():
# B200 code-path
kwargs = {
"gemm1_alpha": layer.gemm1_alpha,
"gemm1_beta": layer.gemm1_beta,
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
"w13_bias": layer.w13_bias,
"w2_bias": layer.w2_bias,
"max_capture_size": self.max_capture_size,
}
return TrtLlmGenExperts(moe, **kwargs)
else:
# Use matmul_ogs from triton_kernels here!
raise NotImplementedError(
"Mxfp4 does not support non-batched experts format for EP")
def _route_and_experts(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None
) -> torch.Tensor:
assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count)
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
def apply(
self,
layer: torch.nn.Module,
@@ -503,6 +590,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation=activation,
expert_map=expert_map)
if self.fused_experts is not None:
return self._route_and_experts(
layer,
x,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
global_num_experts,
expert_map,
custom_routing_function,
scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
expert_load_view,
logical_to_physical_map,
logical_replica_count,
)
assert _can_support_mxfp4(
use_grouped_topk, topk_group, num_expert_group, expert_map,
custom_routing_function, e_score_correction_bias,

View File

@@ -66,11 +66,10 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None):
return not (use_grouped_topk or topk_group or num_expert_group
or expert_map or custom_routing_function
or e_score_correction_bias or apply_router_weight_on_input
or scoring_func != "softmax" or activation != "swigluoai"
or expert_load_view or logical_to_physical_map
or logical_replica_count)
or custom_routing_function or e_score_correction_bias
or apply_router_weight_on_input or scoring_func != "softmax"
or activation != "swigluoai" or expert_load_view
or logical_to_physical_map or logical_replica_count)
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,

View File

@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
try:
from flashinfer import mxfp8_quantize
except ImportError as err:
raise ImportError("The package `flashinfer` is required to do "
"MX-FP8 quantization. Please install it with" \
"`pip install flashinfer`") from err
return mxfp8_quantize(x, is_sf_swizzled_layout=False)