[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -241,7 +241,7 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
if use_marlin:
|
||||
return AWQMoEMethod(quant_args_marlin)
|
||||
return AWQMoEMethod(quant_args_marlin, layer.moe)
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
|
||||
@@ -339,7 +339,7 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(
|
||||
layer, prefix)
|
||||
return GPTQMarlinMoEMethod(quant_args_marlin)
|
||||
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe)
|
||||
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
if use_marlin:
|
||||
|
||||
@@ -113,7 +113,7 @@ class AWQConfig(QuantizationConfig):
|
||||
}
|
||||
awq_marlin_config = AWQMarlinConfig.from_config(
|
||||
marlin_compatible_config_dict)
|
||||
return AWQMoEMethod(awq_marlin_config)
|
||||
return AWQMoEMethod(awq_marlin_config, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||
UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
@@ -151,7 +151,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
"Falling back to Moe WNA16 kernels.")
|
||||
return MoeWNA16Config.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
return AWQMoEMethod(self)
|
||||
return AWQMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -328,7 +328,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
|
||||
class AWQMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: AWQMarlinConfig):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: AWQMarlinConfig,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
if self.quant_config.weight_bits != 4:
|
||||
raise ValueError("AWQMoEMethod only supports 4bit now.")
|
||||
@@ -500,6 +505,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `AWQMoEMethod` yet.")
|
||||
@@ -516,7 +523,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
@@ -535,4 +543,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
expert_map=expert_map,
|
||||
w1_zeros=layer.w13_qzeros,
|
||||
w2_zeros=layer.w2_qzeros,
|
||||
workspace=layer.workspace)
|
||||
workspace=layer.workspace)
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
@@ -132,7 +133,7 @@ class BitsAndBytesConfig(QuantizationConfig):
|
||||
return UnquantizedLinearMethod()
|
||||
return BitsAndBytesLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return BitsAndBytesMoEMethod(self)
|
||||
return BitsAndBytesMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
@@ -411,7 +412,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
quant_config: The BitsAndBytes quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: BitsAndBytesConfig):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: BitsAndBytesConfig,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
try:
|
||||
import bitsandbytes
|
||||
if version.parse(
|
||||
@@ -422,7 +428,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
raise ImportError("Please install bitsandbytes>=0.46.1 via "
|
||||
"`pip install bitsandbytes>=0.46.1` to use "
|
||||
"bitsandbytes quantizer.") from err
|
||||
self.topk_indices_dtype = None
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
@@ -470,6 +475,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -11,20 +11,21 @@ from compressed_tensors.quantization import (ActivationOrdering,
|
||||
QuantizationStrategy)
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
|
||||
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa
|
||||
FlashInferCutlassMoEPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
is_valid_flashinfer_cutlass_fused_moe)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
||||
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
build_flashinfer_fp4_cutlass_moe_kernel,
|
||||
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
|
||||
select_nvfp4_gemm_impl)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_moe_marlin_supports_layer, marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales)
|
||||
@@ -58,6 +59,9 @@ __all__ = [
|
||||
|
||||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init_(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
@@ -81,18 +85,22 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
"WNA16MoE is not supported with actorder=group/dynamic."
|
||||
)
|
||||
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
|
||||
return CompressedTensorsWNA16MoEMethod(quant_config)
|
||||
return CompressedTensorsWNA16MoEMethod(quant_config,
|
||||
layer.moe_config)
|
||||
else:
|
||||
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
|
||||
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
|
||||
return CompressedTensorsWNA16MarlinMoEMethod(
|
||||
quant_config, layer.moe_config)
|
||||
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A4MoeMethod()
|
||||
return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
|
||||
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
||||
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
|
||||
or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
|
||||
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
|
||||
return CompressedTensorsW8A8Fp8MoEMethod(quant_config,
|
||||
layer.moe_config)
|
||||
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config,
|
||||
layer.moe_config)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
||||
@@ -100,15 +108,16 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
||||
detect_nvfp4_moe_support)
|
||||
super().__init__(moe)
|
||||
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
|
||||
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
|
||||
self.allow_flashinfer = _nvfp4.allow_flashinfer
|
||||
self.use_marlin = _nvfp4.use_marlin
|
||||
self.group_size = 16
|
||||
self.fused_experts = None # type: ignore[assignment]
|
||||
self.layer = layer
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
@@ -265,19 +274,36 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_input_scale_quant = torch.nn.Parameter(
|
||||
(layer.w2_input_global_scale), requires_grad=False)
|
||||
|
||||
def maybe_swap_experts_impl(self, moe_parallel_config):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if not self.allow_flashinfer:
|
||||
return
|
||||
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
|
||||
moe_parallel_config)
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
|
||||
def select_gemm_impl(self, prepare_finalize, moe):
|
||||
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
a1_gscale=self.layer.w13_input_scale_quant,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
"""Return the appropriate GEMM experts implementation."""
|
||||
assert moe is not None and prepare_finalize is not None
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
|
||||
select_nvfp4_gemm_impl)
|
||||
|
||||
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
|
||||
experts = select_nvfp4_gemm_impl(
|
||||
moe,
|
||||
g1_alphas=self.layer.g1_alphas,
|
||||
g2_alphas=self.layer.g2_alphas,
|
||||
a1_gscale=self.layer.w13_input_scale_quant,
|
||||
a2_gscale=self.layer.w2_input_scale_quant,
|
||||
allow_flashinfer=self.allow_flashinfer,
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -301,6 +327,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError("EPLB not supported for "
|
||||
"`CompressedTensorsW4A4MoeMethod` yet.")
|
||||
@@ -317,6 +345,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
)
|
||||
|
||||
if self.use_marlin:
|
||||
@@ -340,15 +369,22 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
# FlashInfer fused experts path
|
||||
if self.fused_experts is not None:
|
||||
return flashinfer_fp4_cutlass_moe_forward(
|
||||
self.fused_experts,
|
||||
layer,
|
||||
x,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(
|
||||
x, layer.w13_weight, layer.w2_weight), (
|
||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@@ -376,7 +412,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
n=layer.w2_weight.shape[2] * 2,
|
||||
k=x.shape[1],
|
||||
e=layer.w13_weight.shape[0],
|
||||
device=x.device,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input).to(
|
||||
x.dtype)
|
||||
|
||||
@@ -384,15 +419,16 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||
"weights")
|
||||
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||
"input_activations")
|
||||
self.topk_indices_dtype = None
|
||||
|
||||
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
|
||||
and self.input_quant.strategy
|
||||
@@ -429,7 +465,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
self.weight_quant, self.input_quant)
|
||||
self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
|
||||
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
|
||||
self.fused_experts = None # type: ignore[assignment]
|
||||
self.disable_expert_map = False
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
@@ -614,25 +649,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
# cutlass path
|
||||
if self.use_cutlass:
|
||||
from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
CutlassBatchedExpertsFp8, CutlassExpertsFp8)
|
||||
|
||||
use_batched_format = (prepare_finalize.activation_format ==
|
||||
FusedMoEActivationFormat.BatchedExperts)
|
||||
experts: FusedMoEPermuteExpertsUnpermute
|
||||
|
||||
num_dispatchers = prepare_finalize.num_dispatchers()
|
||||
num_experts = (moe.num_local_experts
|
||||
if use_batched_format else moe.num_experts)
|
||||
|
||||
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
|
||||
|
||||
experts = CutlassExpertsFp8(
|
||||
num_experts,
|
||||
moe.in_dtype,
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||
num_dispatchers=num_dispatchers,
|
||||
use_batched_format=use_batched_format,
|
||||
)
|
||||
if (prepare_finalize.activation_format ==
|
||||
FusedMoEActivationFormat.BatchedExperts):
|
||||
logger.debug("CutlassBatchedExpertsFp8(%s)",
|
||||
self.__class__.__name__)
|
||||
experts = CutlassBatchedExpertsFp8(
|
||||
moe.num_local_experts,
|
||||
num_dispatchers,
|
||||
moe.in_dtype,
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||
)
|
||||
else:
|
||||
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
|
||||
experts = CutlassExpertsFp8(
|
||||
moe.in_dtype,
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||
)
|
||||
|
||||
self.disable_expert_map = (num_dispatchers > 1
|
||||
or not experts.supports_expert_map())
|
||||
@@ -834,9 +875,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||
"weights")
|
||||
@@ -934,6 +977,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for "
|
||||
@@ -951,7 +996,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
@@ -975,9 +1021,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||
# are supported + check if the layer is being ignored.
|
||||
@@ -1233,6 +1281,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for "
|
||||
@@ -1251,7 +1301,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
@@ -1279,9 +1330,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||
# are supported + check if the layer is being ignored.
|
||||
@@ -1459,6 +1512,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError("EPLB not supported for "
|
||||
"`CompressedTensorsWNA16MoEMethod` yet.")
|
||||
@@ -1475,7 +1530,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
|
||||
@@ -6,7 +6,8 @@ from typing import Any, Callable, Optional
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -46,13 +47,18 @@ class ExpertsInt8Config(QuantizationConfig):
|
||||
if isinstance(layer, LinearBase):
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ExpertsInt8MoEMethod(self)
|
||||
return ExpertsInt8MoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: ExpertsInt8Config):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: ExpertsInt8Config,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
@@ -122,6 +128,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ExpertsInt8MoEMethod` yet.")
|
||||
@@ -138,7 +146,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
@@ -142,7 +141,7 @@ class Fp8Config(QuantizationConfig):
|
||||
return UnquantizedLinearMethod()
|
||||
return Fp8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Fp8MoEMethod(self)
|
||||
return Fp8MoEMethod(self, layer.moe_config)
|
||||
elif isinstance(layer, Attention):
|
||||
return Fp8KVCacheMethod(self)
|
||||
return None
|
||||
@@ -479,9 +478,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
|
||||
@@ -529,15 +527,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"CutlassBlockScaledGroupedGemm not supported on the current "
|
||||
"platform.")
|
||||
|
||||
self.topk_indices_dtype = None
|
||||
self.fused_experts = functools.partial( # type: ignore
|
||||
fused_experts,
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
allow_cutlass_block_scaled_grouped_gemm=(
|
||||
self.allow_cutlass_block_scaled_grouped_gemm))
|
||||
|
||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
@@ -1033,7 +1022,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
else:
|
||||
elif self.fused_experts is not None:
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -1052,6 +1041,30 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale_inv
|
||||
if self.block_quant else layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
allow_cutlass_block_scaled_grouped_gemm=(
|
||||
self.allow_cutlass_block_scaled_grouped_gemm))
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
|
||||
@@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -58,7 +59,7 @@ class GGUFConfig(QuantizationConfig):
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
return GGUFEmbeddingMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return GGUFMoEMethod(self)
|
||||
return GGUFMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
@@ -445,7 +446,12 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GGUFConfig):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: GGUFConfig,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
@@ -525,6 +531,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
):
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `GGUFMoEMethod` yet.")
|
||||
@@ -545,7 +553,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
|
||||
topk_weights, topk_ids,
|
||||
layer.w13_qweight_type.weight_type,
|
||||
|
||||
@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||
UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
@@ -375,7 +375,12 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE Marlin method with quantization."""
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: GPTQMarlinConfig,
|
||||
moe: FusedMoEConfig,
|
||||
) -> None:
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
if self.quant_config.quant_type.size_bits == 4:
|
||||
self.quant_type = scalar_types.uint4b8
|
||||
@@ -646,6 +651,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
|
||||
@@ -662,7 +669,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
|
||||
@@ -12,7 +12,9 @@ import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
is_valid_flashinfer_cutlass_fused_moe)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
@@ -22,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
build_flashinfer_fp4_cutlass_moe_kernel,
|
||||
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
|
||||
select_nvfp4_gemm_impl)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
|
||||
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
|
||||
@@ -177,7 +179,7 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
elif isinstance(layer, Attention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ModelOptFp8MoEMethod(self)
|
||||
return ModelOptFp8MoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
@@ -273,7 +275,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
quant_config: The ModelOpt quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: ModelOptFp8Config,
|
||||
moe: FusedMoEConfig,
|
||||
) -> None:
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_fp8_supported)
|
||||
@@ -454,6 +461,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
|
||||
@@ -484,6 +493,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_experts)
|
||||
@@ -699,7 +709,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
elif isinstance(layer, Attention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ModelOptNvFp4FusedMoE(self)
|
||||
return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
|
||||
return None
|
||||
|
||||
|
||||
@@ -923,10 +933,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
quant_config: NVFP4 Quant Config
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: ModelOptNvFp4Config,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> None:
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
||||
detect_nvfp4_moe_support)
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
self.layer = layer
|
||||
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
|
||||
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
|
||||
self.allow_flashinfer = _nvfp4.allow_flashinfer
|
||||
@@ -952,27 +969,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
self.fused_experts: Optional[
|
||||
mk.FusedMoEModularKernel] = None # type: ignore[assignment]
|
||||
|
||||
def maybe_swap_experts_impl(
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
):
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if not self.allow_flashinfer:
|
||||
return
|
||||
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
|
||||
moe_parallel_config)
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
|
||||
# This method update self.fused_experts
|
||||
# only prepare_finalize is not None call select_gemm_impl
|
||||
# so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert
|
||||
# when it's not called(TP case), we still have 2 kernels to use.
|
||||
def select_gemm_impl(self, prepare_finalize,
|
||||
moe) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
a1_gscale=self.layer.w13_input_scale_quant,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
|
||||
assert moe is not None and prepare_finalize is not None
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
|
||||
select_nvfp4_gemm_impl)
|
||||
|
||||
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
experts = select_nvfp4_gemm_impl(
|
||||
moe,
|
||||
g1_alphas=self.layer.g1_alphas,
|
||||
g2_alphas=self.layer.g2_alphas,
|
||||
a1_gscale=self.layer.w13_input_scale_quant,
|
||||
a2_gscale=self.layer.w2_input_scale_quant,
|
||||
allow_flashinfer=self.allow_flashinfer,
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
|
||||
def uses_weight_scale_2_pattern(self) -> bool:
|
||||
"""
|
||||
@@ -1362,7 +1387,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
if self.use_marlin:
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
@@ -1404,21 +1430,28 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
n=layer.w2_weight.shape[2] * 2,
|
||||
k=x.shape[1],
|
||||
e=layer.w13_weight.shape[0],
|
||||
device=x.device,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
else:
|
||||
assert self.allow_flashinfer and \
|
||||
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
||||
out = flashinfer_fp4_cutlass_moe_forward(
|
||||
self.fused_experts,
|
||||
layer,
|
||||
x,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(
|
||||
x, layer.w13_weight, layer.w2_weight), (
|
||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||
|
||||
out = self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -160,7 +160,7 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return MoeWNA16Method(self)
|
||||
return MoeWNA16Method(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
@@ -175,7 +175,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: MoeWNA16Config):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: MoeWNA16Config,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
@@ -302,6 +307,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `MoeWNA16Method` yet.")
|
||||
@@ -318,7 +325,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
has_zp = self.quant_config.has_zp
|
||||
|
||||
@@ -82,7 +82,7 @@ class Mxfp4Config(QuantizationConfig):
|
||||
class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__()
|
||||
super().__init__(moe)
|
||||
self.topk_indices_dtype = None
|
||||
self.moe = moe
|
||||
self.use_marlin = self._should_use_marlin()
|
||||
|
||||
@@ -7,7 +7,8 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
OCP_MX_BLOCK_SIZE)
|
||||
@@ -25,6 +26,9 @@ __all__ = [
|
||||
|
||||
class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
|
||||
@@ -42,17 +46,24 @@ class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
input_config = layer_quant_config.get("input_tensors")
|
||||
|
||||
if quant_config._is_fp8_w8a8(weight_config, input_config):
|
||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config)
|
||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config,
|
||||
module.moe_config)
|
||||
elif quant_config._is_mx_fp4(weight_config, input_config):
|
||||
return QuarkW4A4MXFp4MoEMethod(weight_config, input_config)
|
||||
return QuarkW4A4MXFp4MoEMethod(weight_config, input_config,
|
||||
module.moe_config)
|
||||
else:
|
||||
raise RuntimeError("Unsupported FusedMoe scheme")
|
||||
|
||||
|
||||
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
|
||||
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
|
||||
Any]):
|
||||
def __init__(
|
||||
self,
|
||||
weight_config: dict[str, Any],
|
||||
input_config: dict[str, Any],
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.weight_quant = weight_config
|
||||
self.input_quant = input_config
|
||||
|
||||
@@ -215,6 +226,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
|
||||
@@ -231,7 +244,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
@@ -253,8 +267,13 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
|
||||
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
|
||||
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
|
||||
Any]):
|
||||
def __init__(
|
||||
self,
|
||||
weight_config: dict[str, Any],
|
||||
input_config: dict[str, Any],
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.weight_quant = weight_config
|
||||
self.input_quant = input_config
|
||||
|
||||
@@ -369,6 +388,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
@@ -386,7 +406,8 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
out = fused_experts(
|
||||
x,
|
||||
|
||||
@@ -10,7 +10,8 @@ import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig):
|
||||
if isinstance(layer, LinearBase):
|
||||
return RTNLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return RTNMoEMethod(self)
|
||||
return RTNMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
@@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase):
|
||||
|
||||
class RTNMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: RTNConfig):
|
||||
def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
@@ -289,6 +291,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `RTNMoEMethod` yet.")
|
||||
@@ -305,7 +309,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
|
||||
@@ -3,33 +3,30 @@
|
||||
"""Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
|
||||
FlashInferExperts)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
FlashInferCutlassMoEPrepareAndFinalize)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
__all__ = [
|
||||
"is_flashinfer_fp4_cutlass_moe_available",
|
||||
"reorder_w1w3_to_w3w1",
|
||||
"build_flashinfer_fp4_cutlass_moe_kernel",
|
||||
"flashinfer_fp4_cutlass_moe_forward",
|
||||
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
|
||||
]
|
||||
|
||||
|
||||
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
|
||||
"""Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
|
||||
return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda()
|
||||
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
and has_flashinfer_cutlass_fused_moe()
|
||||
and current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100))
|
||||
|
||||
|
||||
@@ -49,105 +46,33 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor,
|
||||
dim=dim).contiguous())
|
||||
|
||||
|
||||
def build_flashinfer_fp4_cutlass_moe_kernel(
|
||||
moe_parallel_config: FusedMoEParallelConfig, ) -> mk.FusedMoEModularKernel:
|
||||
"""Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel"""
|
||||
experts = FlashInferExperts(
|
||||
use_nvfp4_w4a4=True,
|
||||
use_dp=moe_parallel_config.dp_size > 1,
|
||||
ep_rank=moe_parallel_config.ep_rank,
|
||||
ep_size=moe_parallel_config.ep_size,
|
||||
tp_rank=moe_parallel_config.tp_rank,
|
||||
tp_size=moe_parallel_config.tp_size,
|
||||
)
|
||||
logger.debug_once("FlashInferExperts (util)")
|
||||
return mk.FusedMoEModularKernel(
|
||||
FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8),
|
||||
experts,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fp4_cutlass_moe_forward(
|
||||
fused_experts: mk.FusedMoEModularKernel,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Common forward wrapper for FlashInfer NV-FP4 fused-MoE"""
|
||||
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(
|
||||
x, layer.w13_weight,
|
||||
layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!")
|
||||
|
||||
a1_gscale = layer.w13_input_scale_quant
|
||||
a2_gscale = layer.w2_input_scale_quant
|
||||
|
||||
extra_expert_args = {
|
||||
"g1_alphas": layer.g1_alphas,
|
||||
"g2_alphas": layer.g2_alphas,
|
||||
# Avoid confusion with a1_scale and a2_scale
|
||||
# where are batch size related.
|
||||
"a1_gscale": a1_gscale,
|
||||
"a2_gscale": a2_gscale,
|
||||
"out_dtype": x.dtype,
|
||||
}
|
||||
extra_prepare_args = {
|
||||
"use_dp": layer.dp_size > 1,
|
||||
"local_tokens": x.shape[0],
|
||||
"a1_gscale": a1_gscale,
|
||||
}
|
||||
extra_finalize_args = {
|
||||
"use_dp": layer.dp_size > 1,
|
||||
"local_tokens": x.shape[0],
|
||||
}
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=extra_expert_args,
|
||||
extra_prepare_args=extra_prepare_args,
|
||||
extra_finalize_args=extra_finalize_args,
|
||||
)
|
||||
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||
moe: FusedMoEConfig,
|
||||
a1_gscale: torch.Tensor,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||
use_dp = moe.moe_parallel_config.dp_size > 1
|
||||
return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale)
|
||||
|
||||
|
||||
def select_nvfp4_gemm_impl(
|
||||
allow_flashinfer: bool,
|
||||
moe, # FusedMoEConfig
|
||||
logger):
|
||||
moe: FusedMoEConfig,
|
||||
g1_alphas: torch.Tensor,
|
||||
g2_alphas: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
allow_flashinfer: bool,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
|
||||
|
||||
# lazy import
|
||||
from vllm.distributed import get_ep_group
|
||||
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
if allow_flashinfer:
|
||||
flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
|
||||
if flashinfer_backend != "throughput":
|
||||
raise ValueError(
|
||||
f"Only throughput backend is supported for FlashInferExperts, "
|
||||
f"but got {flashinfer_backend}.")
|
||||
logger.debug_once(
|
||||
"Initializing FlashInferExperts with throughput backend.")
|
||||
return FlashInferExperts(
|
||||
use_nvfp4_w4a4=True,
|
||||
use_dp=moe.moe_parallel_config.dp_size > 1,
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_gscale=a2_gscale,
|
||||
out_dtype=moe.in_dtype,
|
||||
quant_dtype="nvfp4",
|
||||
ep_rank=moe.moe_parallel_config.ep_rank,
|
||||
ep_size=moe.moe_parallel_config.ep_size,
|
||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||
|
||||
Reference in New Issue
Block a user