DP/EP Support for gpt-oss with deepep-ht comm kernel on SM100 (#23608)
This commit is contained in:
@@ -10,6 +10,8 @@ from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -445,6 +447,91 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
if (prepare_finalize.activation_format ==
|
||||
mk.FusedMoEActivationFormat.BatchedExperts):
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 does not support batched experts format for EP")
|
||||
else:
|
||||
if should_use_flashinfer_mxfp4():
|
||||
# B200 code-path
|
||||
kwargs = {
|
||||
"gemm1_alpha": layer.gemm1_alpha,
|
||||
"gemm1_beta": layer.gemm1_beta,
|
||||
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
|
||||
"w13_bias": layer.w13_bias,
|
||||
"w2_bias": layer.w2_bias,
|
||||
"max_capture_size": self.max_capture_size,
|
||||
}
|
||||
return TrtLlmGenExperts(moe, **kwargs)
|
||||
else:
|
||||
# Use matmul_ogs from triton_kernels here!
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 does not support non-batched experts format for EP")
|
||||
|
||||
def _route_and_experts(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
enable_eplb=enable_eplb,
|
||||
expert_map=expert_map,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count)
|
||||
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -503,6 +590,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
activation=activation,
|
||||
expert_map=expert_map)
|
||||
|
||||
if self.fused_experts is not None:
|
||||
return self._route_and_experts(
|
||||
layer,
|
||||
x,
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
use_grouped_topk,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
e_score_correction_bias,
|
||||
apply_router_weight_on_input,
|
||||
activation,
|
||||
enable_eplb,
|
||||
expert_load_view,
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
|
||||
assert _can_support_mxfp4(
|
||||
use_grouped_topk, topk_group, num_expert_group, expert_map,
|
||||
custom_routing_function, e_score_correction_bias,
|
||||
|
||||
Reference in New Issue
Block a user