[DP/EP][GPTOSS] Use triton matmul-ogs kernels for GPTOSS DP/EP (#24588)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
fafbe11af4
commit
e8db44f883
@@ -13,7 +13,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config)
|
||||
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
|
||||
mxfp4_w4a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
OAITritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
@@ -578,9 +581,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
|
||||
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
|
||||
|
||||
# FIXME: num_warps needs to be adjusted based on batch size
|
||||
# (only applies to batched mode)
|
||||
if self.moe.use_ep:
|
||||
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object
|
||||
# (stored in self.fused_experts) to determine if the MoE has a
|
||||
# batched activation format. As self.fused_experts is not
|
||||
# initialized at this point, we resort to checking the MoE config
|
||||
# directly.
|
||||
is_batched_moe = (self.moe.use_pplx_kernels
|
||||
or self.moe.use_deepep_ll_kernels)
|
||||
if is_batched_moe:
|
||||
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
|
||||
else:
|
||||
num_warps = 8
|
||||
@@ -640,16 +648,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
if self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
w1_scale = self.w13_precision_config
|
||||
w2_scale = self.w2_precision_config
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
else:
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
|
||||
return mxfp4_w4a4_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
return mxfp4_w4a4_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
@@ -661,6 +674,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 does not support batched experts format for EP")
|
||||
else:
|
||||
assert self.moe_quant_config is not None
|
||||
if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
|
||||
# B200 code-path
|
||||
@@ -671,13 +685,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
# TODO(bnell): part of quant_config
|
||||
"max_capture_size": self.max_capture_size,
|
||||
}
|
||||
assert self.moe_quant_config is not None
|
||||
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
|
||||
**kwargs)
|
||||
else:
|
||||
# Use matmul_ogs from triton_kernels here!
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 does not support non-batched experts format for EP")
|
||||
return OAITritonExperts(self.moe_quant_config)
|
||||
|
||||
def _route_and_experts(
|
||||
self,
|
||||
@@ -722,10 +733,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count)
|
||||
|
||||
w13_weight = (self.w13_weight_triton_tensor
|
||||
if layer.w13_weight is None else layer.w13_weight)
|
||||
w2_weight = (self.w2_weight_triton_tensor
|
||||
if layer.w2_weight is None else layer.w2_weight)
|
||||
assert all([w is not None for w in [w13_weight, w2_weight]])
|
||||
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1=w13_weight,
|
||||
w2=w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
|
||||
Reference in New Issue
Block a user