[GPTOSS][DP/EP][Marlin] Enable GPTOSS DP/EP using Marlin kernels (#25488)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
commit 7ef40bb983
parent 767cbb011d
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe import modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
     mxfp4_w4a16_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
     OAITritonExperts)
 from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts

@@ -92,7 +93,7 @@ def get_mxfp4_backend():
             "Please `pip install vllm[flashinfer]` for best results.")

     # If FlashInfer is not available, try either Marlin or Triton
-    if current_platform.get_device_capability(
+    if envs.VLLM_MXFP4_USE_MARLIN or current_platform.get_device_capability(
     )[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
             "2.8.0"):
         logger.info_once("Using Marlin backend")

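Note: the changed condition checks the VLLM_MXFP4_USE_MARLIN override before the hardware/software fallbacks. A minimal standalone sketch of that selection order, with simplified names (illustrative only, not vLLM's actual helper, and it ignores the FlashInfer branch handled above):

    # Illustrative stand-in for the backend choice above, not vLLM code.
    import os

    def pick_mxfp4_backend(device_major: int, has_triton: bool,
                           torch_ge_2_8: bool) -> str:
        force_marlin = os.environ.get("VLLM_MXFP4_USE_MARLIN") == "1"
        if (force_marlin or device_major < 9 or not has_triton
                or not torch_ge_2_8):
            return "MARLIN"
        return "TRITON"

    # Hopper (SM90) with triton_kernels and torch >= 2.8 keeps Triton ...
    assert pick_mxfp4_backend(9, True, True) == "TRITON"
    # ... but the override now forces Marlin, e.g. for DP/EP runs.
    os.environ["VLLM_MXFP4_USE_MARLIN"] = "1"
    assert pick_mxfp4_backend(9, True, True) == "MARLIN"
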
@@ -646,9 +647,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:

         if self.mxfp4_backend == Mxfp4Backend.MARLIN:
-            return None
-
-        if self.mxfp4_backend == Mxfp4Backend.TRITON:
+            return mxfp4_w4a16_moe_quant_config(
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+            )
+        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
             w1_scale = self.w13_precision_config
             w2_scale = self.w2_precision_config
             return mxfp4_w4a16_moe_quant_config(

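This is the substantive change for DP/EP: the Marlin backend previously returned None here, so no FusedMoEQuantConfig existed for the modular-kernel machinery to consume. It now builds a w4a16 config from the layer's Marlin-prepared tensors, mirroring the Triton branch. A rough sketch of what such a config bundles (field names are illustrative, not vLLM's exact dataclass):

    # Sketch of the payload a modular experts implementation needs; the
    # real FusedMoEQuantConfig carries more than this.
    from dataclasses import dataclass
    from typing import Optional
    import torch

    @dataclass
    class W4A16ConfigSketch:
        w1_scale: torch.Tensor           # scales for the fused gate/up proj
        w2_scale: torch.Tensor           # scales for the down proj
        w1_bias: Optional[torch.Tensor]  # GPT-OSS carries biases on both
        w2_bias: Optional[torch.Tensor]  # MoE projections

Returning a populated config instead of None is what lets select_gemm_impl() below hand the modular kernel a MarlinExperts instance.
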
@@ -690,6 +695,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             }
             return TrtLlmGenExperts(self.moe, self.moe_quant_config,
                                     **kwargs)
+        elif (self.mxfp4_backend == Mxfp4Backend.MARLIN):
+            return MarlinExperts(self.moe_quant_config)
         else:
             return OAITritonExperts(self.moe_quant_config)

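The hunk above wires the new MarlinExperts into the per-backend dispatch that picks an experts implementation for the modular kernel. A toy version of that dispatch (strings stand in for the experts classes):

    # Toy dispatch mirroring the branch structure of the hunk above.
    def select_gemm_impl(backend: str, quant_config: str) -> str:
        if backend == "TRTLLM":
            return f"TrtLlmGenExperts({quant_config})"
        elif backend == "MARLIN":            # new branch in this commit
            return f"MarlinExperts({quant_config})"
        else:
            return f"OAITritonExperts({quant_config})"

    assert select_gemm_impl("MARLIN", "cfg") == "MarlinExperts(cfg)"
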
@@ -782,6 +789,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if enable_eplb:
             raise NotImplementedError("EPLB is not supported for mxfp4")

+        if self.fused_experts is not None:
+            return self._route_and_experts(
+                layer,
+                x,
+                router_logits,
+                top_k,
+                renormalize,
+                use_grouped_topk,
+                topk_group,
+                num_expert_group,
+                global_num_experts,
+                expert_map,
+                custom_routing_function,
+                scoring_func,
+                e_score_correction_bias,
+                apply_router_weight_on_input,
+                activation,
+                enable_eplb,
+                expert_load_view,
+                logical_to_physical_map,
+                logical_replica_count,
+            )
+
         if self.mxfp4_backend == Mxfp4Backend.MARLIN:
             topk_weights, topk_ids, _ = FusedMoE.select_experts(
                 hidden_states=x,

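The block added here is not new logic: it is the fused_experts early-return hoisted from later in apply() (the next hunk deletes the original copy). With the check first, any configured modular kernel, i.e. the DP/EP path, routes through _route_and_experts() regardless of backend; only single-GPU calls fall through to the Marlin-specific code below. A runnable toy of the reordering (names illustrative):

    # After this commit the modular path wins whenever it is configured,
    # so the Marlin backend also works under DP/EP.
    def apply_after(fused_experts, backend: str) -> str:
        if fused_experts is not None:            # hoisted early-return
            return f"route_and_experts({fused_experts})"
        if backend == "MARLIN":
            return "marlin_direct"               # single-GPU fused path
        return "triton_direct"

    assert apply_after("MarlinExperts", "MARLIN") == \
        "route_and_experts(MarlinExperts)"
    assert apply_after(None, "MARLIN") == "marlin_direct"
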
@@ -815,29 +845,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 activation=activation,
                 expert_map=expert_map)

-        if self.fused_experts is not None:
-            return self._route_and_experts(
-                layer,
-                x,
-                router_logits,
-                top_k,
-                renormalize,
-                use_grouped_topk,
-                topk_group,
-                num_expert_group,
-                global_num_experts,
-                expert_map,
-                custom_routing_function,
-                scoring_func,
-                e_score_correction_bias,
-                apply_router_weight_on_input,
-                activation,
-                enable_eplb,
-                expert_load_view,
-                logical_to_physical_map,
-                logical_replica_count,
-            )
-
         assert _can_support_mxfp4(
             use_grouped_topk, topk_group, num_expert_group, expert_map,
             custom_routing_function, e_score_correction_bias,