[DP/EP][GPTOSS] Use triton matmul-ogs kernels for GPTOSS DP/EP (#24588)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
commit e8db44f883 (parent fafbe11af4)
Author: Varun Sundar Rabindranath
Date:   2025-09-23 00:01:09 -04:00
Committed-by: GitHub

6 changed files with 275 additions and 76 deletions


@@ -13,7 +13,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase)
 from vllm.model_executor.layers.fused_moe import modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config)
+    FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
+    mxfp4_w4a16_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    OAITritonExperts)
 from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
@@ -578,9 +581,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer.w13_bias = Parameter(w13_bias, requires_grad=False)
         layer.w2_bias = Parameter(w2_bias, requires_grad=False)
         # FIXME warp need to be adjusted based on batch size
         # only apply to batched mode
-        if self.moe.use_ep:
+        # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
+        # (stored in self.fused_experts) to determine if the MoE has a
+        # batched activation format. As self.fused_experts is not
+        # initialized at this point, we resort to checking the MoE config
+        # directly.
+        is_batched_moe = (self.moe.use_pplx_kernels
+                          or self.moe.use_deepep_ll_kernels)
+        if is_batched_moe:
             num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
         else:
             num_warps = 8
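
For reference, a minimal standalone sketch of the warp-count heuristic above; pick_num_warps is an illustrative name, only the config flags and the VLLM_MOE_DP_CHUNK_SIZE threshold come from the hunk.

def pick_num_warps(use_pplx_kernels: bool, use_deepep_ll_kernels: bool,
                   dp_chunk_size: int) -> int:
    # pplx and DeepEP low-latency kernels imply a batched activation
    # format, where small per-rank chunks get fewer warps.
    is_batched_moe = use_pplx_kernels or use_deepep_ll_kernels
    if is_batched_moe:
        return 4 if dp_chunk_size <= 512 else 8
    return 8

# e.g. pick_num_warps(True, False, 256) -> 4; pick_num_warps(False, False, 256) -> 8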
@@ -640,16 +648,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if self.mxfp4_backend == Mxfp4Backend.TRITON:
             w1_scale = self.w13_precision_config
             w2_scale = self.w2_precision_config
+            return mxfp4_w4a16_moe_quant_config(
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+            )
         else:
             w1_scale = layer.w13_weight_scale
             w2_scale = layer.w2_weight_scale
-        return mxfp4_w4a4_moe_quant_config(
-            w1_bias=layer.w13_bias,
-            w2_bias=layer.w2_bias,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-        )
+            return mxfp4_w4a4_moe_quant_config(
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+            )

     def select_gemm_impl(
         self,
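
Net effect of the hunk above: the Triton matmul_ogs backend now reports a weight-only (w4a16) quant config, while the other backends keep the activation-quantized (w4a4) config. A minimal sketch of that selection using the two helpers imported in the first hunk; the wrapper function and its flat argument list are illustrative, not the actual Mxfp4MoEMethod method signature.

from vllm.model_executor.layers.fused_moe.config import (
    mxfp4_w4a4_moe_quant_config, mxfp4_w4a16_moe_quant_config)


def make_mxfp4_quant_config(is_triton_backend, w13_bias, w2_bias,
                            w1_scale, w2_scale):
    if is_triton_backend:
        # matmul_ogs path: mxfp4 weights, 16-bit activations.
        return mxfp4_w4a16_moe_quant_config(
            w1_bias=w13_bias, w2_bias=w2_bias,
            w1_scale=w1_scale, w2_scale=w2_scale)
    # Other backends keep the mxfp4-activation config.
    return mxfp4_w4a4_moe_quant_config(
        w1_bias=w13_bias, w2_bias=w2_bias,
        w1_scale=w1_scale, w2_scale=w2_scale)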
@@ -661,6 +674,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             raise NotImplementedError(
                 "Mxfp4 does not support batched experts format for EP")
         else:
+            assert self.moe_quant_config is not None
             if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
                     or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
                 # B200 code-path
@@ -671,13 +685,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     # TODO(bnell): part of quant_config
                     "max_capture_size": self.max_capture_size,
                 }
-                assert self.moe_quant_config is not None
                 return TrtLlmGenExperts(self.moe, self.moe_quant_config,
                                         **kwargs)
             else:
-                # Use matmul_ogs from triton_kernels here!
-                raise NotImplementedError(
-                    "Mxfp4 does not support non-batched experts format for EP")
+                return OAITritonExperts(self.moe_quant_config)

     def _route_and_experts(
         self,
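
With this change the non-batched expert-parallel format no longer raises NotImplementedError; the quant config is handed to OAITritonExperts, the triton_kernels matmul_ogs implementation. A rough sketch of the resulting dispatch; the imports are the ones used in this diff, while select_ep_experts and its simplified arguments are illustrative stand-ins for select_gemm_impl.

from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
    OAITritonExperts)
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts


def select_ep_experts(use_flashinfer_trtllm: bool, moe, moe_quant_config,
                      trtllm_kwargs):
    assert moe_quant_config is not None
    if use_flashinfer_trtllm:
        # B200 code-path (SM100_FI_MXFP4_* backends), unchanged here.
        return TrtLlmGenExperts(moe, moe_quant_config, **trtllm_kwargs)
    # New in this commit: triton_kernels matmul_ogs experts handle the
    # non-batched EP format.
    return OAITritonExperts(moe_quant_config)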
@@ -722,10 +733,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             logical_to_physical_map=logical_to_physical_map,
             logical_replica_count=logical_replica_count)

+        w13_weight = (self.w13_weight_triton_tensor
+                      if layer.w13_weight is None else layer.w13_weight)
+        w2_weight = (self.w2_weight_triton_tensor
+                     if layer.w2_weight is None else layer.w2_weight)
+        assert all([w is not None for w in [w13_weight, w2_weight]])
+
         return self.fused_experts(
             hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
+            w1=w13_weight,
+            w2=w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
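
The apply path above now falls back to the triton_kernels weight tensors kept on the quant method whenever the layer's torch parameters are None (presumably after the weights have been repacked for the matmul_ogs path elsewhere in this PR). A small stand-alone sketch of that fallback; resolve_moe_weights is illustrative, the attribute names are the ones used in the hunk.

def resolve_moe_weights(method, layer):
    # Prefer the layer's torch parameters; fall back to the triton_kernels
    # tensors stored on the quant method when the parameters are None.
    w13_weight = (method.w13_weight_triton_tensor
                  if layer.w13_weight is None else layer.w13_weight)
    w2_weight = (method.w2_weight_triton_tensor
                 if layer.w2_weight is None else layer.w2_weight)
    assert all(w is not None for w in (w13_weight, w2_weight))
    return w13_weight, w2_weight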