[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
bnellnm
2025-08-15 14:46:00 -04:00
committed by GitHub
parent 48b01fd4d4
commit 8ad7285ea2
54 changed files with 2010 additions and 1293 deletions

View File

@@ -241,7 +241,7 @@ class AutoRoundConfig(QuantizationConfig):
if isinstance(layer, FusedMoE):
if use_marlin:
return AWQMoEMethod(quant_args_marlin)
return AWQMoEMethod(quant_args_marlin, layer.moe)
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
@@ -339,7 +339,7 @@ class AutoRoundConfig(QuantizationConfig):
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix)
return GPTQMarlinMoEMethod(quant_args_marlin)
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:

View File

@@ -113,7 +113,7 @@ class AWQConfig(QuantizationConfig):
}
awq_marlin_config = AWQMarlinConfig.from_config(
marlin_compatible_config_dict)
return AWQMoEMethod(awq_marlin_config)
return AWQMoEMethod(awq_marlin_config, layer.moe_config)
return None

View File

@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
@@ -151,7 +151,7 @@ class AWQMarlinConfig(QuantizationConfig):
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
return AWQMoEMethod(self)
return AWQMoEMethod(self, layer.moe_config)
return None
@classmethod
@@ -328,7 +328,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
class AWQMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: AWQMarlinConfig):
def __init__(
self,
quant_config: AWQMarlinConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.")
@@ -500,6 +505,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `AWQMoEMethod` yet.")
@@ -516,7 +523,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe(
x,
@@ -535,4 +543,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace)
workspace=layer.workspace)

View File

@@ -7,6 +7,7 @@ import torch
from packaging import version
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
@@ -132,7 +133,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return UnquantizedLinearMethod()
return BitsAndBytesLinearMethod(self)
elif isinstance(layer, FusedMoE):
return BitsAndBytesMoEMethod(self)
return BitsAndBytesMoEMethod(self, layer.moe_config)
return None
@@ -411,7 +412,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
quant_config: The BitsAndBytes quantization config.
"""
def __init__(self, quant_config: BitsAndBytesConfig):
def __init__(
self,
quant_config: BitsAndBytesConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
try:
import bitsandbytes
if version.parse(
@@ -422,7 +428,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
raise ImportError("Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer.") from err
self.topk_indices_dtype = None
self.quant_config = quant_config
def create_weights(
@@ -470,6 +475,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(

View File

@@ -11,20 +11,21 @@ from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy)
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa
FlashInferCutlassMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_kernel,
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales)
@@ -58,6 +59,9 @@ __all__ = [
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def __init_(self, moe: FusedMoEConfig):
super().__init__(moe)
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
@@ -81,18 +85,22 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"WNA16MoE is not supported with actorder=group/dynamic."
)
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config)
return CompressedTensorsWNA16MoEMethod(quant_config,
layer.moe_config)
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod()
return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
return CompressedTensorsW8A8Fp8MoEMethod(quant_config,
layer.moe_config)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
return CompressedTensorsW8A8Int8MoEMethod(quant_config,
layer.moe_config)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@@ -100,15 +108,16 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self):
def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support)
super().__init__(moe)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.fused_experts = None # type: ignore[assignment]
self.layer = layer
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@@ -265,19 +274,36 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale_quant = torch.nn.Parameter(
(layer.w2_input_global_scale), requires_grad=False)
def maybe_swap_experts_impl(self, moe_parallel_config):
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if not self.allow_flashinfer:
return
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
return super().maybe_make_prepare_finalize(moe)
def select_gemm_impl(self, prepare_finalize, moe):
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe,
a1_gscale=self.layer.w13_input_scale_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return the appropriate GEMM experts implementation."""
assert moe is not None and prepare_finalize is not None
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl)
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
experts = select_nvfp4_gemm_impl(
moe,
g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def apply(
self,
@@ -301,6 +327,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsW4A4MoeMethod` yet.")
@@ -317,6 +345,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
if self.use_marlin:
@@ -340,15 +369,22 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
# FlashInfer fused experts path
if self.fused_experts is not None:
return flashinfer_fp4_cutlass_moe_forward(
self.fused_experts,
layer,
x,
topk_weights,
topk_ids,
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@@ -376,7 +412,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
device=x.device,
apply_router_weight_on_input=apply_router_weight_on_input).to(
x.dtype)
@@ -384,15 +419,16 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
self.topk_indices_dtype = None
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy
@@ -429,7 +465,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.weight_quant, self.input_quant)
self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
self.fused_experts = None # type: ignore[assignment]
self.disable_expert_map = False
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -614,25 +649,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8, CutlassExpertsFp8)
use_batched_format = (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts)
experts: FusedMoEPermuteExpertsUnpermute
num_dispatchers = prepare_finalize.num_dispatchers()
num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts)
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
num_experts,
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
num_dispatchers=num_dispatchers,
use_batched_format=use_batched_format,
)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
logger.debug("CutlassBatchedExpertsFp8(%s)",
self.__class__.__name__)
experts = CutlassBatchedExpertsFp8(
moe.num_local_experts,
num_dispatchers,
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
)
self.disable_expert_map = (num_dispatchers > 1
or not experts.supports_expert_map())
@@ -834,9 +875,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
@@ -934,6 +977,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for "
@@ -951,7 +996,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts(
hidden_states=x,
@@ -975,9 +1021,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
@@ -1233,6 +1281,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for "
@@ -1251,7 +1301,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe(
x,
@@ -1279,9 +1330,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
@@ -1459,6 +1512,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsWNA16MoEMethod` yet.")
@@ -1475,7 +1530,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts(
x,

View File

@@ -6,7 +6,8 @@ from typing import Any, Callable, Optional
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -46,13 +47,18 @@ class ExpertsInt8Config(QuantizationConfig):
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return ExpertsInt8MoEMethod(self)
return ExpertsInt8MoEMethod(self, layer.moe_config)
return None
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: ExpertsInt8Config):
def __init__(
self,
quant_config: ExpertsInt8Config,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -122,6 +128,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ExpertsInt8MoEMethod` yet.")
@@ -138,7 +146,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts(
x,

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import TYPE_CHECKING, Any, Callable, Optional
import torch
@@ -142,7 +141,7 @@ class Fp8Config(QuantizationConfig):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self)
return Fp8MoEMethod(self, layer.moe_config)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
@@ -479,9 +478,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
from vllm.model_executor.layers.fused_moe import fused_experts
def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig):
super().__init__(moe)
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
@@ -529,15 +527,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@@ -1033,7 +1022,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
elif self.fused_experts is not None:
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
@@ -1052,6 +1041,30 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
class Fp8KVCacheMethod(BaseKVCacheMethod):

View File

@@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -58,7 +59,7 @@ class GGUFConfig(QuantizationConfig):
elif isinstance(layer, VocabParallelEmbedding):
return GGUFEmbeddingMethod(self)
elif isinstance(layer, FusedMoE):
return GGUFMoEMethod(self)
return GGUFMoEMethod(self, layer.moe_config)
return None
@@ -445,7 +446,12 @@ class GGUFMoEMethod(FusedMoEMethodBase):
quant_config: The GGUF quantization config.
"""
def __init__(self, quant_config: GGUFConfig):
def __init__(
self,
quant_config: GGUFConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -525,6 +531,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
):
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `GGUFMoEMethod` yet.")
@@ -545,7 +553,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids,
layer.w13_qweight_type.weight_type,

View File

@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
@@ -375,7 +375,12 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"""MoE Marlin method with quantization."""
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
def __init__(
self,
quant_config: GPTQMarlinConfig,
moe: FusedMoEConfig,
) -> None:
super().__init__(moe)
self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8
@@ -646,6 +651,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
@@ -662,7 +669,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe(
x,

View File

@@ -12,7 +12,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@@ -22,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_kernel,
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
@@ -177,7 +179,7 @@ class ModelOptFp8Config(QuantizationConfig):
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self)
return ModelOptFp8MoEMethod(self, layer.moe_config)
return None
@@ -273,7 +275,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptFp8Config) -> None:
def __init__(
self,
quant_config: ModelOptFp8Config,
moe: FusedMoEConfig,
) -> None:
super().__init__(moe)
self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
@@ -454,6 +461,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
@@ -484,6 +493,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
@@ -699,7 +709,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptNvFp4FusedMoE(self)
return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
return None
@@ -923,10 +933,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
quant_config: NVFP4 Quant Config
"""
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
def __init__(
self,
quant_config: ModelOptNvFp4Config,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> None:
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support)
super().__init__(moe)
self.quant_config = quant_config
self.layer = layer
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
@@ -952,27 +969,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.fused_experts: Optional[
mk.FusedMoEModularKernel] = None # type: ignore[assignment]
def maybe_swap_experts_impl(
def maybe_make_prepare_finalize(
self,
moe_parallel_config: FusedMoEParallelConfig,
):
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if not self.allow_flashinfer:
return
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
return super().maybe_make_prepare_finalize(moe)
# This method update self.fused_experts
# only prepare_finalize is not None call select_gemm_impl
# so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert
# when it's not called(TP case), we still have 2 kernels to use.
def select_gemm_impl(self, prepare_finalize,
moe) -> mk.FusedMoEPermuteExpertsUnpermute:
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe,
a1_gscale=self.layer.w13_input_scale_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
assert moe is not None and prepare_finalize is not None
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl)
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> mk.FusedMoEPermuteExpertsUnpermute:
experts = select_nvfp4_gemm_impl(
moe,
g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def uses_weight_scale_2_pattern(self) -> bool:
"""
@@ -1362,7 +1387,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
if self.use_marlin:
return torch.ops.vllm.fused_marlin_moe(
@@ -1404,21 +1430,28 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
device=x.device,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
assert self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
out = flashinfer_fp4_cutlass_moe_forward(
self.fused_experts,
layer,
x,
topk_weights,
topk_ids,
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
out = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input,
)

View File

@@ -7,7 +7,7 @@ import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -160,7 +160,7 @@ class MoeWNA16Config(QuantizationConfig):
else:
raise ValueError("moe_wna16 only support gptq and awq.")
elif isinstance(layer, FusedMoE):
return MoeWNA16Method(self)
return MoeWNA16Method(self, layer.moe_config)
return None
@@ -175,7 +175,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
"""
def __init__(self, quant_config: MoeWNA16Config):
def __init__(
self,
quant_config: MoeWNA16Config,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -302,6 +307,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `MoeWNA16Method` yet.")
@@ -318,7 +325,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp

View File

@@ -82,7 +82,7 @@ class Mxfp4Config(QuantizationConfig):
class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__()
super().__init__(moe)
self.topk_indices_dtype = None
self.moe = moe
self.use_marlin = self._should_use_marlin()

View File

@@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE)
@@ -25,6 +26,9 @@ __all__ = [
class QuarkMoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
@staticmethod
def get_moe_method(
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
@@ -42,17 +46,24 @@ class QuarkMoEMethod(FusedMoEMethodBase):
input_config = layer_quant_config.get("input_tensors")
if quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config)
return QuarkW8A8Fp8MoEMethod(weight_config, input_config,
module.moe_config)
elif quant_config._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFp4MoEMethod(weight_config, input_config)
return QuarkW4A4MXFp4MoEMethod(weight_config, input_config,
module.moe_config)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
Any]):
def __init__(
self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(moe)
self.weight_quant = weight_config
self.input_quant = input_config
@@ -215,6 +226,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
@@ -231,7 +244,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts(
x,
@@ -253,8 +267,13 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
Any]):
def __init__(
self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(moe)
self.weight_quant = weight_config
self.input_quant = input_config
@@ -369,6 +388,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
@@ -386,7 +406,8 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
out = fused_experts(
x,

View File

@@ -10,7 +10,8 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig):
if isinstance(layer, LinearBase):
return RTNLinearMethod(self)
elif isinstance(layer, FusedMoE):
return RTNMoEMethod(self)
return RTNMoEMethod(self, layer.moe_config)
return None
@@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase):
class RTNMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: RTNConfig):
def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
super().__init__(moe)
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -289,6 +291,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `RTNMoEMethod` yet.")
@@ -305,7 +309,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size

View File

@@ -3,33 +3,30 @@
"""Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
from __future__ import annotations
from typing import Optional
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
from vllm.platforms import current_platform
logger = init_logger(__name__)
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
__all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_kernel",
"flashinfer_fp4_cutlass_moe_forward",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda()
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.is_device_capability(100))
@@ -49,105 +46,33 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor,
dim=dim).contiguous())
def build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config: FusedMoEParallelConfig, ) -> mk.FusedMoEModularKernel:
"""Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel"""
experts = FlashInferExperts(
use_nvfp4_w4a4=True,
use_dp=moe_parallel_config.dp_size > 1,
ep_rank=moe_parallel_config.ep_rank,
ep_size=moe_parallel_config.ep_size,
tp_rank=moe_parallel_config.tp_rank,
tp_size=moe_parallel_config.tp_size,
)
logger.debug_once("FlashInferExperts (util)")
return mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8),
experts,
)
def flashinfer_fp4_cutlass_moe_forward(
fused_experts: mk.FusedMoEModularKernel,
layer: torch.nn.Module,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
"""Common forward wrapper for FlashInfer NV-FP4 fused-MoE"""
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight,
layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!")
a1_gscale = layer.w13_input_scale_quant
a2_gscale = layer.w2_input_scale_quant
extra_expert_args = {
"g1_alphas": layer.g1_alphas,
"g2_alphas": layer.g2_alphas,
# Avoid confusion with a1_scale and a2_scale
# where are batch size related.
"a1_gscale": a1_gscale,
"a2_gscale": a2_gscale,
"out_dtype": x.dtype,
}
extra_prepare_args = {
"use_dp": layer.dp_size > 1,
"local_tokens": x.shape[0],
"a1_gscale": a1_gscale,
}
extra_finalize_args = {
"use_dp": layer.dp_size > 1,
"local_tokens": x.shape[0],
}
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args,
extra_prepare_args=extra_prepare_args,
extra_finalize_args=extra_finalize_args,
)
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig,
a1_gscale: torch.Tensor,
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1
return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale)
def select_nvfp4_gemm_impl(
allow_flashinfer: bool,
moe, # FusedMoEConfig
logger):
moe: FusedMoEConfig,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
allow_flashinfer: bool,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
# lazy import
from vllm.distributed import get_ep_group
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
if allow_flashinfer:
flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
if flashinfer_backend != "throughput":
raise ValueError(
f"Only throughput backend is supported for FlashInferExperts, "
f"but got {flashinfer_backend}.")
logger.debug_once(
"Initializing FlashInferExperts with throughput backend.")
return FlashInferExperts(
use_nvfp4_w4a4=True,
use_dp=moe.moe_parallel_config.dp_size > 1,
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
out_dtype=moe.in_dtype,
quant_dtype="nvfp4",
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,