Revert "[LoRA] Support FusedMoE LoRA Triton kernel for mxfp4 (#28971)" (#29697)

Signed-off-by: Huamin Li <3ericli@gmail.com>
Authored by Huamin Li on 2025-11-28 15:26:52 -08:00; committed by GitHub
parent a51f4186f2
commit 3fd1fb0b60
4 changed files with 12 additions and 441 deletions


@@ -5,7 +5,6 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
@@ -377,148 +376,3 @@ class OAITritonExperts(BaseOAITritonExperts):
intermediate_cache=workspace2,
a1q_scale=a1q_scale,
)
class UnfusedOAITritonExperts(BaseOAITritonExperts):
"""
A Triton based MoE expert class that operates on expert standard
format and explicitly keeps the activation and reduction (moe_sum) steps
unfused from the matmul_ogs kernel. This exposes injection points
for activation and moe_sum.
One use case for it is to inject LoRA modules on the activation and moe_sum.
"""
def __init__(self, quant_config: FusedMoEQuantConfig):
# TODO (varun) : Enable activation quantization
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
super().__init__(quant_config)
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
def supports_chunking(self) -> bool:
return True
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspaces are allocated inside the kernel
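# Sizing rationale (inferred from apply() below, not stated in this commit):
# workspace2 is max(N, K) wide so it can back both the first matmul output of
# shape (M * topk, N) and the pre-reduction output of shape (M * topk, K) via
# _resize_cache, while workspace1 (which apply() receives as workspace13)
# backs the gated-activation output of width N // 2.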
workspace1 = (M * topk, N // 2)
workspace2 = (M * topk, max(N, K))
output = (M, K)
return (workspace1, workspace2, output)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor):
ops.moe_sum(input, output)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
if self.quant_config is None:
self.quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
if expert_map is not None:
topk_ids = expert_map[topk_ids]
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts
routing_data, gather_indx, scatter_indx = self._make_routing_data(
topk_ids, topk_weights, local_num_experts
)
topk = topk_ids.size(1)
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
assert (
self.quant_config.w1_bias is None
or self.quant_config.w1_bias.dtype == torch.float32
)
assert (
self.quant_config.w2_bias is None
or self.quant_config.w2_bias.dtype == torch.float32
)
# Shape check; only the non-mxfp4 tensors are checked
assert hidden_states.ndim == 2
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]
batch_dim = 1
M, K = hidden_states.shape
E, _, N = w1.shape
if global_num_experts == -1:
global_num_experts = E
# Note that the output tensor might be in workspace13
intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N))
intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K))
intermediate_cache2 = _resize_cache(workspace13, (M * topk, N // 2))
gammas = routing_data.gate_scal if routing_data else None
matmul_ogs(
hidden_states,
w1,
self.quant_config.w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=self.quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=None,
y=intermediate_cache1,
)
self.activation(
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
)
# matmul_ogs' grouped reduction fuses the sum across multiple experts:
#   y[dst_ind // n_expts_act, :] += x[src_ind, :]
# n_expts_act must be set to 1 to keep moe_sum unfused from the matmul.
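# Illustrative example (not from this commit): with n_expts_act = 2 (top-2
# routing), the two expert outputs whose dst_ind values are 2*t and 2*t + 1
# both accumulate into y[t], i.e. the per-token sum over experts is fused into
# the matmul. With n_expts_act = 1, each source row keeps its own destination
# row, so the per-token sum is performed separately by moe_sum below, where
# LoRA can be injected.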
routing_data.n_expts_act = 1
matmul_ogs(
intermediate_cache2,
w2,
self.quant_config.w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=self.quant_config.w2_precision,
gammas=None if apply_router_weight_on_input else gammas,
y=intermediate_cache3,
)
self.moe_sum(intermediate_cache3.view(-1, topk, K), output)
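The docstring above describes the unfused activation and moe_sum steps as injection points for LoRA. The following standalone sketch (not code from this commit; the tensor names, shapes, and the plain .sum stand-in for moe_sum are illustrative assumptions) shows the kind of low-rank delta such a seam makes possible:

import torch

def lora_delta(x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor) -> torch.Tensor:
    # Low-rank LoRA update: (x @ A) @ B with A: (d, r), B: (r, d), r << d.
    return (x @ lora_a) @ lora_b

# Tiny demo of injecting a LoRA delta right after an unfused reduction step.
M, topk, K, r = 4, 2, 8, 2
expert_out = torch.randn(M, topk, K)        # per-(token, expert) outputs
lora_a = torch.randn(K, r) * 0.01
lora_b = torch.randn(r, K) * 0.01

y = expert_out.sum(dim=1)                   # stand-in for the unfused moe_sum
y = y + lora_delta(y, lora_a, lora_b)       # injection point exposed by unfusing

In the fused path, the reduction happens inside matmul_ogs, so this seam is not available; the second file below accordingly reverts LoRA setups to the Marlin backend.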


@@ -30,7 +30,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts,
UnfusedOAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
@@ -84,21 +83,8 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
if not current_platform.is_cuda():
return Mxfp4Backend.NONE
# If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported = (
has_triton_kernels()
and is_torch_equal_or_newer("2.8.0")
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
and (9, 0) <= current_platform.get_device_capability() < (11, 0)
)
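# (Clarifying note: the tuple comparison above accepts device capabilities from
# (9, 0) up to, but not including, (11, 0), i.e. SM90 and SM10x, and excludes
# the SM110 / SM120 cases referenced in the two issue links.)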
if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
return Mxfp4Backend.MARLIN
logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
return Mxfp4Backend.TRITON
logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
return Mxfp4Backend.MARLIN
def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
@@ -868,8 +854,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
return MarlinExperts(self.moe_quant_config)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
if self.moe.is_lora_enabled:
return UnfusedOAITritonExperts(self.moe_quant_config)
return OAITritonExperts(self.moe_quant_config)
else:
raise NotImplementedError(