[MoE Refactor] Split up compressed_tensors_moe.py (#38960)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2026-04-06 20:07:54 -04:00
committed by GitHub
parent 00d7b497b3
commit b2b2c5239e
12 changed files with 2770 additions and 2543 deletions

View File

@@ -57,8 +57,8 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod]
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
- [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoEMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod]
- [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w4a4_nvfp4.CompressedTensorsW4A4Nvfp4MoEMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w8a8_fp8.CompressedTensorsW8A8Fp8MoEMethod]
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]

View File

@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod,
)
__all__ = [
"CompressedTensorsMoEMethod",
]

View File

@@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import (
ActivationOrdering,
QuantizationStrategy,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer,
)
from vllm.platforms import current_platform
logger = init_logger(__name__)
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
layer_name: str,
) -> FusedMoEMethodBase:
# FusedMoE was made by combining multiple Linears so need to
# make sure quantization config for Linear can target it
quant_config._add_fused_moe_to_target_scheme_map()
unfused_names = [
layer_name + proj_name
for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
]
# TODO: refactor this to use expert_mapping and check all layer numbers
all_scheme_dicts = [
quant_config.get_scheme_dict(layer, name) for name in unfused_names
]
scheme_dict = all_scheme_dicts.pop()
# multiple schemes found
if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
raise ValueError(
"All MoE projections need to have same "
"quantization scheme but found multiple"
)
if scheme_dict is None: # ignored layer
return UnquantizedFusedMoEMethod(layer.moe_config)
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
format = scheme_dict.get("format")
if quant_config._is_mxfp4(weight_quant):
from .compressed_tensors_moe_w4a4_mxfp4 import (
CompressedTensorsW4A4Mxfp4MoEMethod,
)
return CompressedTensorsW4A4Mxfp4MoEMethod(layer.moe_config)
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
# group_size=None means channelwise
group_size = weight_quant.group_size or -1
valid_format_and_bits = (
weight_quant.num_bits in WNA16_SUPPORTED_BITS
and format == CompressionFormat.pack_quantized.value
)
if not valid_format_and_bits:
raise ValueError(
"For Fused MoE layers, only format: ",
f"{CompressionFormat.pack_quantized.value} ",
f" and bits: {WNA16_SUPPORTED_BITS} is supported ",
f"but got format: {CompressionFormat.pack_quantized.value} "
f" and bits: {weight_quant.num_bits}",
)
# Prefer to use the MarlinMoE kernel when it is supported.
if (
not check_moe_marlin_supports_layer(layer, group_size)
or current_platform.is_rocm()
):
from .compressed_tensors_moe_wna16 import (
CompressedTensorsWNA16MoEMethod,
)
if (
weight_quant.strategy == QuantizationStrategy.GROUP
and weight_quant.actorder
in (ActivationOrdering.GROUP, ActivationOrdering.DYNAMIC)
):
raise ValueError(
"WNA16MoE is not supported with actorder=group/dynamic."
)
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(
weight_quant, input_quant, layer.moe_config
)
else:
from .compressed_tensors_moe_wna16_marlin import (
CompressedTensorsWNA16MarlinMoEMethod,
)
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_nvfp4_format(weight_quant):
from .compressed_tensors_moe_w4a4_nvfp4 import (
CompressedTensorsW4A4Nvfp4MoEMethod,
)
_is_valid_nvfp4_activations = (
quant_config._is_nvfp4_format(input_quant) or input_quant is None
)
if not _is_valid_nvfp4_activations:
raise ValueError(
"For NVFP4 weights, input quantization must also be NVFP4 format ",
f"or None for NVFP4A16, found {input_quant}",
)
return CompressedTensorsW4A4Nvfp4MoEMethod(
layer.moe_config, layer_name, use_a16=(input_quant is None)
)
elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant)
):
from .compressed_tensors_moe_w8a8_fp8 import (
CompressedTensorsW8A8Fp8MoEMethod,
)
return CompressedTensorsW8A8Fp8MoEMethod(
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
from .compressed_tensors_moe_w8a8_int8 import (
CompressedTensorsW8A8Int8MoEMethod,
)
return CompressedTensorsW8A8Int8MoEMethod(
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_fp8_w4a8_sm90(weight_quant, input_quant):
from .compressed_tensors_moe_w4a8_fp8 import (
CompressedTensorsW4A8Fp8MoEMethod,
)
logger.info_once("Using CompressedTensorsW4A8Fp8MoEMethod")
return CompressedTensorsW4A8Fp8MoEMethod(
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
from .compressed_tensors_moe_w4a8_int8 import (
CompressedTensorsW4A8Int8MoEMethod,
)
return CompressedTensorsW4A8Int8MoEMethod(
weight_quant, input_quant, layer.moe_config
)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
)

View File

@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
Mxfp4MoeBackend,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501
CompressedTensorsMoEMethod,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe):
super().__init__(moe)
self.group_size = 32
self.mxfp4_backend = Mxfp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.num_experts = num_experts
layer.params_dtype = params_dtype
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
requires_grad=False,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // 2,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.group_size,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // self.group_size,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return make_mxfp4_moe_quant_config(
mxfp4_backend=self.mxfp4_backend,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
def process_weights_after_loading(self, layer: FusedMoE) -> None:
layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False
)
delattr(layer, "w13_weight_packed")
layer.w2_weight = torch.nn.Parameter(
layer.w2_weight_packed.data, requires_grad=False
)
delattr(layer, "w2_weight_packed")
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression "
"will be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
prepare_moe_fp4_layer_for_marlin(layer)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None:
self.moe_kernel = make_mxfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
mxfp4_backend=self.mxfp4_backend,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)

View File

@@ -0,0 +1,306 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501
CompressedTensorsMoEMethod,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
logger = init_logger(__name__)
class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
moe: FusedMoEConfig,
layer_name: str | None = None,
use_a16: bool = False,
):
super().__init__(moe)
self.group_size = 16
# Select experts implementation.
self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
config=self.moe,
weight_key=kNvfp4Static,
activation_key=None if use_a16 else kNvfp4Dynamic,
)
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.num_experts = num_experts
layer.params_dtype = params_dtype
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
w13_num_shards * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
requires_grad=False,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // 2,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# Weight Scales
w13_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
w13_num_shards * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.group_size,
dtype=torch.float8_e4m3fn,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // self.group_size,
dtype=torch.float8_e4m3fn,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# Weight Global Scales
w13_weight_scale_2 = torch.nn.Parameter(
torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)
w2_weight_scale_2 = torch.nn.Parameter(
torch.empty(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)
# Input Global Scales
w13_input_scale = torch.nn.Parameter(
torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_input_global_scale", w13_input_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.empty(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_input_global_scale", w2_input_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: FusedMoE) -> None:
"""
Convert NVFP4 MoE weights into kernel format and setup the kernel.
"""
# NOTE(rob): wN_weight_packed -> wN_weight is because ModularKernelMethod
# requires this naming convention. However, the name change breaks
# reloading because the state dict no longer matches disk. Once we
# remove MKM, we should revert this change to ensure compatibility.
layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False
)
delattr(layer, "w13_weight_packed")
layer.w2_weight = torch.nn.Parameter(
layer.w2_weight_packed.data, requires_grad=False
)
delattr(layer, "w2_weight_packed")
# Use a single gscale for w13.
if self.moe.is_act_and_mul and not torch.allclose(
layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1]
):
logger.warning_once(
"w1_weight_global_scale must match w3_weight_global_scale. "
"Accuracy may be affected.",
)
w13_weight_global_scale = layer.w13_weight_global_scale[:, 0].contiguous()
# Shuffle weights into the NvFp4 kernel format.
(
w13,
w13_scale,
w13_scale_2,
a13_scale,
w2,
w2_scale,
w2_scale_2,
a2_scale,
) = convert_to_nvfp4_moe_kernel_format(
nvfp4_backend=self.nvfp4_backend,
layer=layer,
w13=layer.w13_weight,
w13_scale=layer.w13_weight_scale,
w13_scale_2=(1.0 / w13_weight_global_scale),
a13_scale=(1.0 / layer.w13_input_global_scale),
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
w2_scale_2=(1.0 / layer.w2_weight_global_scale),
a2_scale=(1.0 / layer.w2_input_global_scale),
is_act_and_mul=self.moe.is_act_and_mul,
)
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w2_weight_scale", w2_scale)
layer.w13_weight_scale_2 = w13_scale_2
layer.w2_weight_scale_2 = w2_scale_2
layer.w13_input_scale = a13_scale
layer.w2_input_scale = a2_scale
# Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.experts_cls is not None
self.moe_kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
self.moe_kernel.fused_experts.process_weights_after_loading(layer)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w13_scale_2=layer.w13_weight_scale_2,
w2_scale_2=layer.w2_weight_scale_2,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)

View File

@@ -0,0 +1,343 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from compressed_tensors.quantization import (
QuantizationArgs,
)
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEActivationFormat,
FusedMoEExpertsModular,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
int4_w4afp8_moe_quant_config,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501
CompressedTensorsMoEMethod,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
logger = init_logger(__name__)
class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.weight_quant = weight_quant
self.input_quant = input_quant
self.group_size = self.weight_quant.group_size
self.num_bits = self.weight_quant.num_bits
self.packed_factor = 32 // self.num_bits
assert self.weight_quant.symmetric, (
"Only symmetric quantization is supported for W4A8 MoE"
)
assert self.weight_quant.actorder != "group"
assert self.group_size == 128, "Only group size 128 supported for W4A8 MoE"
self.disable_expert_map = False
self.layer_name = layer_name
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
# requirement for CUTLASS reorder_tensor
assert hidden_size % 256 == 0, f"{hidden_size=} must be divisible by 256"
assert intermediate_size_per_partition % 256 == 0, (
f"{intermediate_size_per_partition=} must be divisible by 256"
)
# storage type, pack 8xint4 into int32
params_dtype = torch.int32
# WEIGHTS
w13_weight_packed = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // self.packed_factor,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_packed", w13_weight_packed)
set_weight_attrs(w13_weight_packed, extra_weight_attrs)
w2_weight_packed = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // self.packed_factor,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_packed", w2_weight_packed)
set_weight_attrs(w2_weight_packed, extra_weight_attrs)
# SCALES
# weight_scale refers to the group-wise scales
# they are initially loaded as bf16, we will convert to fp8
# after loading
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // self.group_size,
dtype=layer.orig_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
hidden_size,
intermediate_size_per_partition // self.group_size,
dtype=layer.orig_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-GROUP quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# weight shapes
w2_weight_shape = torch.nn.Parameter(
torch.empty(num_experts, 2), requires_grad=False
)
layer.register_parameter("w2_weight_shape", w2_weight_shape)
set_weight_attrs(w2_weight_shape, extra_weight_attrs)
w13_weight_shape = torch.nn.Parameter(
torch.empty(num_experts, 2), requires_grad=False
)
layer.register_parameter("w13_weight_shape", w13_weight_shape)
set_weight_attrs(w13_weight_shape, extra_weight_attrs)
# don't use input scales
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer):
device = layer.w13_weight_packed.device
# STRIDES
# A, C
self.a_strides1_c_strides2 = torch.full(
(layer.local_num_experts,),
layer.hidden_size,
device=device,
dtype=torch.int64,
)
self.a_strides2 = torch.full(
(layer.local_num_experts,),
layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
self.c_strides1 = torch.full(
(layer.local_num_experts,),
2 * layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
# S (group-wise scales)
# sizeof(StrideS) = 16 bytes, so we need to use 2xint64 to encode it
self.s_strides1 = torch.zeros(
(layer.local_num_experts, 2), device=device, dtype=torch.int64
)
self.s_strides1[:, 0] = 2 * layer.intermediate_size_per_partition
self.s_strides2 = torch.zeros(
(layer.local_num_experts, 2), device=device, dtype=torch.int64
)
self.s_strides2[:, 0] = layer.hidden_size
# encode and reorder weight tensors, and get the layout to pass to
# the grouped gemm kernel. `b_strides1/2` specifies the entire layout
convert_packed_uint4b8_to_signed_int4_inplace(layer.w13_weight_packed)
w13_weight_shuffled, self.b_strides1 = (
ops.cutlass_encode_and_reorder_int4b_grouped(layer.w13_weight_packed)
)
replace_parameter(layer, "w13_weight_packed", w13_weight_shuffled)
convert_packed_uint4b8_to_signed_int4_inplace(layer.w2_weight_packed)
w2_weight_shuffled, self.b_strides2 = (
ops.cutlass_encode_and_reorder_int4b_grouped(layer.w2_weight_packed)
)
replace_parameter(layer, "w2_weight_packed", w2_weight_shuffled)
# convert bf16 scales to (fp8_scales, channel_scales)
w13_weight_scale, w13_weight_chan_scale = convert_bf16_scales_to_fp8(
self.quant_fp8, layer.w13_weight_scale
)
w2_weight_scale, w2_weight_chan_scale = convert_bf16_scales_to_fp8(
self.quant_fp8, layer.w2_weight_scale
)
# register channel scales
layer.register_parameter(
"w13_weight_chan_scale",
torch.nn.Parameter(w13_weight_chan_scale, requires_grad=False),
)
layer.register_parameter(
"w2_weight_chan_scale",
torch.nn.Parameter(w2_weight_chan_scale, requires_grad=False),
)
# The scales are stored as (E, N, K // 128) but the kernel expects
# (E, K // 128, N) in row-major format, so we need to permute the last 2 dims
# and make it contiguous
w13_weight_scale_packed = ops.cutlass_pack_scale_fp8(
w13_weight_scale.permute(0, 2, 1).contiguous()
)
replace_parameter(layer, "w13_weight_scale", w13_weight_scale_packed)
w2_weight_scale_packed = ops.cutlass_pack_scale_fp8(
w2_weight_scale.permute(0, 2, 1).contiguous()
)
replace_parameter(layer, "w2_weight_scale", w2_weight_scale_packed)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
return super().maybe_make_prepare_finalize(routing_tables)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
# Store quantization scales; both per-group and per-channel
# Note we haven't specified the group size here because
# the quant config logic assumes group-wise scaling
# and channel-wise scaling are exclusive.
return int4_w4afp8_moe_quant_config(
w1_scale=layer.w13_weight_scale, # group scale
w2_scale=layer.w2_weight_scale, # group scale
g1_alphas=layer.w13_weight_chan_scale,
g2_alphas=layer.w2_weight_chan_scale,
per_act_token_quant=True, # always use dynamic per-token
per_out_ch_quant=True, # always use per-channel
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEExpertsModular:
assert self.moe_quant_config is not None
assert (
prepare_finalize.activation_format == FusedMoEActivationFormat.Standard
), "BatchedExperts not supported"
from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8
experts: FusedMoEExpertsModular
logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__)
experts = CutlassExpertsW4A8Fp8(
out_dtype=self.moe.in_dtype,
a_strides1=self.a_strides1_c_strides2,
a_strides2=self.a_strides2,
b_strides1=self.b_strides1,
b_strides2=self.b_strides2,
c_strides1=self.c_strides1,
c_strides2=self.a_strides1_c_strides2,
s_strides1=self.s_strides1,
s_strides2=self.s_strides2,
moe_config=self.moe,
quant_config=self.moe_quant_config,
group_size=self.group_size,
)
num_dispatchers = prepare_finalize.num_dispatchers()
self.disable_expert_map = (
num_dispatchers > 1 or not experts.supports_expert_map()
)
return experts
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
)
assert self.moe_quant_config is not None
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_w4a8_fp8,
)
return cutlass_moe_w4a8_fp8(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
topk_weights,
topk_ids,
moe_config=self.moe,
quant_config=self.moe_quant_config,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=None if self.disable_expert_map else layer.expert_map,
a_strides1=self.a_strides1_c_strides2,
a_strides2=self.a_strides2,
b_strides1=self.b_strides1,
b_strides2=self.b_strides2,
c_strides1=self.c_strides1,
c_strides2=self.a_strides1_c_strides2,
s_strides1=self.s_strides1,
s_strides2=self.s_strides2,
group_size=self.group_size,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
@property
def supports_eplb(self) -> bool:
return False

View File

@@ -0,0 +1,349 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationStrategy,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
)
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501
CompressedTensorsMoEMethod,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import CpuArchEnum, current_platform
logger = init_logger(__name__)
class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
"""
CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform
- Weights: int4 (stored as int8 values in [-8,7], packed to uint8 nibbles)
- Scales: Fp32 for Channelwise , bf16 for groupwise quantization
- Bias: Same data type as original weights
- Activations: FP32/Bf16 dynamic per-token (A8 Int),
quantized inside the kernel
"""
def __init__(
self,
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.has_bias = self.moe.has_bias
self.weight_quant = weight_quant
self.input_quant = input_quant
# Validate scheme: weights=W4 (channel or group),
# activations=dynamic TOKEN (A8)
# Must be dynamic per-token activations
if (
input_quant.strategy != QuantizationStrategy.TOKEN
or not input_quant.dynamic
):
raise ValueError(
"W4A8-int MoE needs dynamic per-token activation quantization."
)
# Weight can be channel-wise (group_size=None) or group-wise
self.group_size = (
weight_quant.group_size if (weight_quant.group_size is not None) else -1
)
if weight_quant.num_bits != 4:
raise ValueError("This method only supports 4-bit weights (num_bits=4).")
# CPU only
if not current_platform.is_cpu():
raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.")
# Arm: check _dyn ops availability
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
try:
_ = torch.ops.aten._dyn_quant_matmul_4bit
_ = torch.ops.aten._dyn_quant_pack_4bit_weight
except AttributeError as err:
raise RuntimeError(
f"""PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops;
install a newer build."""
) from err
self.static_input_scales = False # always dynamic per token
# ---- parameter creation ----
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Shapes per local rank (TP/EP):
# w13: [E, 2*I_local, H] int8 (int4 values in [-8,7])
# w2 : [E, H, I_local] int8
# Scales:
# channel-wise: group_size=-1 -> per-output-row, single scale per row
# group-wise : group_size=g ->
# per-output-row, (in_features/g) scales
E = num_experts
H = hidden_size
IN = intermediate_size_per_partition
g = self.group_size
# Per-row scale columns
def _n_scale_cols(in_features: int) -> int:
return 1 if g == -1 else (in_features // g)
# Register unpacked int4-as-int8 weights the loader will fill.
w13 = torch.nn.Parameter(
torch.empty(E, 2 * IN, H, dtype=torch.int8), requires_grad=False
)
set_weight_attrs(w13, extra_weight_attrs)
layer.register_parameter("w13_weight", w13)
w2 = torch.nn.Parameter(
torch.empty(E, H, IN, dtype=torch.int8), requires_grad=False
)
set_weight_attrs(w2, extra_weight_attrs)
layer.register_parameter("w2_weight", w2)
# Register scales
# KleidiAI groupwise kernels accepts float32 scales
# KleidiAI groupwise kernels accepts bfloat16 scales
scale_dtype = torch.float32 if g == -1 else torch.bfloat16
w13_s = torch.nn.Parameter(
torch.ones(E, 2 * IN, _n_scale_cols(H), dtype=scale_dtype),
requires_grad=False,
)
set_weight_attrs(
w13_s,
{"quant_method": "channel" if g == -1 else "group", **extra_weight_attrs},
)
layer.register_parameter("w13_weight_scale", w13_s)
w2_s = torch.nn.Parameter(
torch.ones(E, H, _n_scale_cols(IN), dtype=scale_dtype), requires_grad=False
)
set_weight_attrs(
w2_s,
{"quant_method": "channel" if g == -1 else "group", **extra_weight_attrs},
)
layer.register_parameter("w2_weight_scale", w2_s)
if self.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(E, 2 * IN, dtype=params_dtype), requires_grad=False
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
# Placeholders for packed weights (will be replaced after packing)
layer.register_parameter(
"w13_weight_packed", torch.nn.Parameter(torch.empty(0), requires_grad=False)
)
set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs)
layer.register_parameter(
"w2_weight_packed", torch.nn.Parameter(torch.empty(0), requires_grad=False)
)
set_weight_attrs(layer.w2_weight_packed, extra_weight_attrs)
# dims for 4 bit fused matmuls
layer.w13_in_features = H
layer.w13_out_features = 2 * IN
layer.w2_in_features = IN
layer.w2_out_features = H
layer.group_size = g
# post-load packing to dyn-4bit KleidiAI kernel's format
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E = layer.w13_weight.shape[0]
H = layer.w13_in_features
I2 = layer.w13_out_features
IN = layer.w2_in_features
g = layer.group_size
def _pack_matrix(
int4_as_int8_2d: torch.Tensor,
scales_2d: torch.Tensor,
bias_1d: torch.Tensor | None,
in_features: int,
out_features: int,
) -> torch.Tensor:
# int4 values are stored as int8 in [-8,7].
# Shift to unsigned nibble and pack pairs along input-dim.
tmp = int4_as_int8_2d.add(8) # [out, in]
uint8_nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to(
torch.uint8
) # [out, in//2]
# KleidiAI groupwise kernels accepts float32 scales
# KleidiAI groupwise kernels accepts bfloat16 scales
scale_dtype = torch.float32 if g == -1 else torch.bfloat16
scales = scales_2d.to(scale_dtype)
bias = None if bias_1d is None else bias_1d.to(torch.float32)
return torch.ops.aten._dyn_quant_pack_4bit_weight(
uint8_nibbles,
scales,
bias,
g if g != -1 else in_features,
in_features,
out_features,
)
# Pack per expert
w13_packed_list = []
w2_packed_list = []
has_w13_bias = hasattr(layer, "w13_bias") and layer.w13_bias is not None
has_w2_bias = hasattr(layer, "w2_bias") and layer.w2_bias is not None
for e in range(E):
w13_packed_list.append(
_pack_matrix(
layer.w13_weight[e], # [2I, H]
layer.w13_weight_scale[e], # [2I, H/g or 1]
layer.w13_bias[e] if has_w13_bias else None, # [2I]
H,
I2,
)
)
w2_packed_list.append(
_pack_matrix(
# w2 shape is [H, IN]; we need [out, in] == [H, IN].
layer.w2_weight[e], # [H, IN]
layer.w2_weight_scale[e], # [H, IN/g or 1]
layer.w2_bias[e] if has_w2_bias else None, # [H]
IN,
layer.w2_out_features, # in_features=IN, out_features=H
)
)
# each packed tensor has identical shape per expert; stack on dim 0
w13_packed = torch.stack(w13_packed_list, dim=0)
w2_packed = torch.stack(w2_packed_list, dim=0)
replace_parameter(
layer,
"w13_weight_packed",
torch.nn.Parameter(w13_packed, requires_grad=False),
)
replace_parameter(
layer,
"w2_weight_packed",
torch.nn.Parameter(w2_packed, requires_grad=False),
)
# free raw tensors/scales/bias now that they're packed into the payload.
replace_parameter(
layer, "w13_weight", torch.nn.Parameter(torch.empty(0), requires_grad=False)
)
replace_parameter(
layer, "w2_weight", torch.nn.Parameter(torch.empty(0), requires_grad=False)
)
replace_parameter(
layer,
"w13_weight_scale",
torch.nn.Parameter(torch.empty(0), requires_grad=False),
)
replace_parameter(
layer,
"w2_weight_scale",
torch.nn.Parameter(torch.empty(0), requires_grad=False),
)
if has_w13_bias:
replace_parameter(
layer,
"w13_bias",
torch.nn.Parameter(torch.empty(0), requires_grad=False),
)
if has_w2_bias:
replace_parameter(
layer,
"w2_bias",
torch.nn.Parameter(torch.empty(0), requires_grad=False),
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
# CPU dynamic 4-bit MoE path does not use modular kernels or
# fused_experts; quant config is not needed.
return None
@property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet."
assert layer.activation in (
MoEActivation.SILU,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
), "Only SiLU/SwiGLUGU/SwiGLUUG are supported."
assert layer.expert_map is None, """expert_map/EP not implemented
for CPU dyn-4bit MoE."""
def _act_kind(s: MoEActivation) -> int:
# 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU
if s == MoEActivation.SWIGLUSTEP:
return 0
if s == MoEActivation.SWIGLUOAI:
return 1
if s == MoEActivation.SILU:
return 2
raise ValueError(f"Unknown activation '{s}'")
# Apply topk softmax on router output
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=layer.top_k,
use_grouped_topk=layer.use_grouped_topk,
renormalize=layer.renormalize,
)
return torch.ops._C.dynamic_4bit_int_moe(
x,
topk_ids.to(torch.long),
topk_weights,
layer.w13_weight_packed,
layer.w2_weight_packed,
layer.w2_out_features,
layer.w2_in_features,
layer.w13_out_features,
layer.group_size,
layer.apply_router_weight_on_input,
int(_act_kind(layer.activation)),
)

View File

@@ -0,0 +1,414 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationStrategy,
)
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501
CompressedTensorsMoEMethod,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Dynamic128Sym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
logger = init_logger(__name__)
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"""W8A8 FP8 MoE quantization using compressed tensors."""
def __init__(
self,
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.weight_quant = weight_quant
self.input_quant = input_quant
per_tensor = (
self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR
)
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN
)
if not (per_tensor or per_channel):
assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
self.weight_block_size = self.weight_quant.block_structure
assert self.weight_quant.dynamic is not None
else:
self.weight_block_size = None
self.block_quant = self.weight_block_size is not None
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales and per_channel:
raise ValueError(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)
ct2vllm_weight = {
QuantizationStrategy.CHANNEL: kFp8StaticChannelSym,
QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
QuantizationStrategy.BLOCK: kFp8Static128BlockSym,
}
ct2vllm_act = {
QuantizationStrategy.TOKEN: kFp8DynamicTokenSym,
QuantizationStrategy.TENSOR: (
kFp8StaticTensorSym if self.static_input_scales else kFp8Dynamic128Sym
),
}
weight_key = ct2vllm_weight[self.weight_quant.strategy]
if weight_key == kFp8Static128BlockSym:
activation_key = kFp8Dynamic128Sym
else:
activation_key = ct2vllm_act[self.input_quant.strategy]
# Select Fp8 MoE backend
self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
config=self.moe,
weight_key=weight_key,
activation_key=activation_key,
allow_vllm_cutlass=True,
)
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
params_dtype = torch.float8_e4m3fn
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = (
self.weight_block_size[0],
self.weight_block_size[1],
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
# Required by row parallel
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}."
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
w13_num_shards * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
# For non-gated MoE, allocate 1 scale for w13.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, w13_num_shards, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
w13_num_shards * intermediate_size_per_partition,
1,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
w13_num_shards
* ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_n - 1) // block_n,
(intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: FusedMoE) -> None:
# Allow for accessing weights and scales in standard way.
w13 = layer.w13_weight
w2 = layer.w2_weight
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
w13_input_scale = layer.w13_input_scale
w2_input_scale = layer.w2_input_scale
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
if current_platform.is_fp8_fnuz():
w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
w13, w13_scale, w13_input_scale
)
w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
w2, w2_scale, w2_input_scale
)
# Per tensor kernels require single activation scale. Use the max.
if self.static_input_scales:
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
assert w13_input_scale is not None and w2_input_scale is not None
w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
w13_input_scale, w2_input_scale
)
replace_parameter(layer, "w13_input_scale", w13_input_scale)
replace_parameter(layer, "w2_input_scale", w2_input_scale)
# Per-tensor kernels use a single scale, for W13, but on disk there
# is a separate scale for W1 and W3. Requantize with the max scale.
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
w13,
w13_scale,
shard_size=layer.intermediate_size_per_partition,
num_experts=layer.local_num_experts,
is_act_and_mul=self.moe.is_act_and_mul,
)
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
fp8_backend=self.fp8_backend,
layer=layer,
w13=w13,
w2=w2,
w13_scale=w13_scale,
w2_scale=w2_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
)
# Replace parameters with updated versions. Note that this helper
# function ensures the replacement is compatible with RL weight reloads.
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
is_per_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
return make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=is_per_token,
per_out_ch_quant=is_per_token,
block_shape=self.weight_block_size,
)
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
# TODO(rob): investigate the disable_expert_map introduced by:
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
@property
def supports_eplb(self) -> bool:
return True

View File

@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationStrategy,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
int8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501
CompressedTensorsMoEMethod,
)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.weight_quant = weight_quant
self.input_quant = input_quant
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN
)
if not per_channel:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}"
)
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales."
)
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
params_dtype = torch.int8
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
w13_num_shards * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
w13_num_shards * intermediate_size_per_partition,
1,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
assert not self.static_input_scales
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return int8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)

View File

@@ -0,0 +1,267 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from compressed_tensors.quantization import (
QuantizationArgs,
)
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501
CompressedTensorsMoEMethod,
)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs | None,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.weight_quant = weight_quant
self.input_quant = input_quant
# Extract properties from weight_quant
self.num_bits = weight_quant.num_bits
self.packed_factor = 32 // weight_quant.num_bits
self.strategy = weight_quant.strategy
# channelwise is not supported by this kernel
assert weight_quant.strategy == "group"
self.group_size = weight_quant.group_size
# grouped actorder isn't supported by this kernel
assert weight_quant.actorder != "group"
assert weight_quant.symmetric, (
"Only symmetric quantization is supported for MoE"
)
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
extra_weight_attrs.update(
{"is_transposed": True, "quant_method": self.strategy}
)
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size // self.packed_factor,
w13_num_shards * intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition // self.packed_factor,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w2_scales_size = intermediate_size_per_partition
if self.strategy == "channel":
num_groups_w2 = num_groups_w13 = 1
self.group_size = -1
else:
num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size
w13_scale = torch.nn.Parameter(
torch.ones(
num_experts,
num_groups_w13,
w13_num_shards * intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_scale)
set_weight_attrs(w13_scale, extra_weight_attrs)
w2_scale = torch.nn.Parameter(
torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_scale)
set_weight_attrs(w2_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, {"load_full_w2": False})
w2_weight_shape = torch.nn.Parameter(
torch.empty(num_experts, 2), requires_grad=False
)
layer.register_parameter("w2_weight_shape", w2_weight_shape)
set_weight_attrs(w2_weight_shape, extra_weight_attrs)
w13_weight_shape = torch.nn.Parameter(
torch.empty(num_experts, 2), requires_grad=False
)
layer.register_parameter("w13_weight_shape", w13_weight_shape)
set_weight_attrs(w13_weight_shape, extra_weight_attrs)
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
layer.a13_scale = None
layer.a2_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Reconfigure packed weights and scales to match moe_wna16 format
layer.w13_weight_packed = torch.nn.Parameter(
layer.w13_weight_packed.transpose(1, 2).contiguous().view(torch.uint8),
requires_grad=False,
)
layer.w2_weight_packed = torch.nn.Parameter(
layer.w2_weight_packed.transpose(1, 2).contiguous().view(torch.uint8),
requires_grad=False,
)
layer.w13_weight_scale = torch.nn.Parameter(
layer.w13_weight_scale.transpose(1, 2).contiguous(), requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
layer.w2_weight_scale.transpose(1, 2).contiguous(), requires_grad=False
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
assert self.num_bits == 4 or self.num_bits == 8
config_builder = (
int4_w4a16_moe_quant_config
if self.num_bits == 4
else int8_w8a16_moe_quant_config
)
return config_builder(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, self.group_size],
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEExpertsModular:
if self.moe.is_lora_enabled:
assert self.moe_quant_config is not None
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.model_executor.layers.fused_moe import TritonWNA16Experts
layer.w13_weight = layer.w13_weight_packed
layer.w2_weight = layer.w2_weight_packed
return TritonWNA16Experts(
moe_config=self.moe, quant_config=self.moe_quant_config
)
else:
raise NotImplementedError(
"TritonExperts requires Triton. "
"Install triton or disable LoRA for MoE."
)
raise NotImplementedError
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
@property
def supports_eplb(self) -> bool:
return True

View File

@@ -0,0 +1,575 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from enum import Enum
import torch
from compressed_tensors.quantization import (
QuantizationArgs,
)
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
int4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
fused_marlin_moe,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501
CompressedTensorsMoEMethod,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_TYPES_MAP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import (
flashinfer_trtllm_mxint4_moe,
is_flashinfer_mxint4_moe_available,
prepare_static_weights_for_trtllm_mxint4_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_workspace_new,
marlin_moe_permute_scales,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
logger = init_logger(__name__)
class GPTQMarlinState(Enum):
REPACK = enum.auto()
READY = enum.auto()
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs | None,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.weight_quant = weight_quant
self.input_quant = input_quant
assert weight_quant.symmetric, (
"Only symmetric quantization is supported for MoE"
)
# Extract properties from weight_quant
self.num_bits = weight_quant.num_bits
self.packed_factor = 32 // weight_quant.num_bits
self.strategy = weight_quant.strategy
self.group_size = weight_quant.group_size
self.actorder = weight_quant.actorder
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
self.use_flashinfer_mxint4_moe = (
is_flashinfer_mxint4_moe_available()
and self.group_size == 32
and weight_quant.num_bits == 4
)
self.kernel_backend = (
"Flashinfer" if self.use_flashinfer_mxint4_moe else "Marlin"
)
logger.info_once(
f"Using {self.kernel_backend} backend for WNA16 MoE "
f"(group_size={self.group_size}, num_bits={self.num_bits})",
scope="local",
)
def get_weight_shape(
self,
weight_name: str,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
num_groups_w2: int | None = None,
num_groups_w13: int | None = None,
) -> tuple[int, int, int]:
"""
Get the shape of the weight based on the weight name, number of experts
hidden size, intermediate size per partition, number of groups for w2,
and number of groups for w13. Pass in num_groups_w2 and num_groups_w13
for weight scales.
"""
if weight_name == "w13_scale":
assert num_groups_w13 is not None, (
"num_groups_w13 must be provided for weight scales"
)
if weight_name == "w2_scale":
assert num_groups_w2 is not None, (
"num_groups_w2 must be provided for weight scales"
)
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
shape_map = {
"w13_weight": {
"Flashinfer": (
num_experts,
w13_num_shards * intermediate_size_per_partition,
hidden_size // self.packed_factor,
),
"Marlin": (
num_experts,
hidden_size // self.packed_factor,
w13_num_shards * intermediate_size_per_partition,
),
},
"w13_scale": {
"Flashinfer": (
num_experts,
w13_num_shards * intermediate_size_per_partition,
num_groups_w13,
),
"Marlin": (
num_experts,
num_groups_w13,
w13_num_shards * intermediate_size_per_partition,
),
},
"w2_weight": {
"Flashinfer": (
num_experts,
hidden_size,
intermediate_size_per_partition // self.packed_factor,
),
"Marlin": (
num_experts,
intermediate_size_per_partition // self.packed_factor,
hidden_size,
),
},
"w2_scale": {
"Flashinfer": (num_experts, hidden_size, num_groups_w2),
"Marlin": (num_experts, num_groups_w2, hidden_size),
},
}
return shape_map[weight_name][self.kernel_backend]
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
is_transposed = self.kernel_backend != "Flashinfer"
extra_weight_attrs.update(
{"is_transposed": is_transposed, "quant_method": self.strategy}
)
w13_weight = torch.nn.Parameter(
torch.empty(
*self.get_weight_shape(
"w13_weight",
num_experts,
hidden_size,
intermediate_size_per_partition,
),
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
*self.get_weight_shape(
"w2_weight",
num_experts,
hidden_size,
intermediate_size_per_partition,
),
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# In the case where we have actorder/g_idx,
# we do not partition the w2 scales
load_full_w2 = self.actorder and self.group_size != -1
w2_scales_size = (
intermediate_size_full if load_full_w2 else intermediate_size_per_partition
)
self.is_k_full = (not self.actorder) or (
intermediate_size_per_partition == intermediate_size_full
)
if self.strategy == "channel":
num_groups_w2 = num_groups_w13 = 1
self.group_size = -1
else:
num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size
layer.num_groups_w13 = num_groups_w13
layer.num_groups_w2 = num_groups_w2
w13_scale = torch.nn.Parameter(
torch.ones(
*self.get_weight_shape(
"w13_scale",
num_experts,
hidden_size,
intermediate_size_per_partition,
num_groups_w13=num_groups_w13,
),
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_scale)
set_weight_attrs(w13_scale, extra_weight_attrs)
w2_scale = torch.nn.Parameter(
torch.ones(
*self.get_weight_shape(
"w2_scale",
num_experts,
hidden_size,
intermediate_size_per_partition,
num_groups_w2=num_groups_w2,
),
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_scale)
set_weight_attrs(w2_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})
w2_weight_shape = torch.nn.Parameter(
torch.empty(num_experts, 2), requires_grad=False
)
layer.register_parameter("w2_weight_shape", w2_weight_shape)
set_weight_attrs(w2_weight_shape, extra_weight_attrs)
w13_weight_shape = torch.nn.Parameter(
torch.empty(num_experts, 2), requires_grad=False
)
layer.register_parameter("w13_weight_shape", w13_weight_shape)
set_weight_attrs(w13_weight_shape, extra_weight_attrs)
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
layer.a13_scale = None
layer.a2_scale = None
layer.marlin_state = GPTQMarlinState.REPACK
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_weight_g_idx.shape[0]
device = layer.w13_weight_g_idx.device
if self.kernel_backend == "Flashinfer":
dict_weights_mxint4 = prepare_static_weights_for_trtllm_mxint4_moe(
layer.w13_weight_packed,
layer.w13_weight_scale,
layer.w2_weight_packed,
layer.w2_weight_scale,
)
replace_parameter(
layer, "w13_weight_packed", dict_weights_mxint4["gemm1_weights"]
)
replace_parameter(
layer, "w13_weight_scale", dict_weights_mxint4["gemm1_scales"]
)
replace_parameter(
layer, "w2_weight_packed", dict_weights_mxint4["gemm2_weights"]
)
replace_parameter(
layer, "w2_weight_scale", dict_weights_mxint4["gemm2_scales"]
)
return None
is_a_8bit = (
self.marlin_input_dtype is not None
and self.marlin_input_dtype.itemsize == 1
)
if self.marlin_input_dtype == torch.float8_e4m3fn:
# NOTE: for non-zp quantization format only
ops.marlin_int4_fp8_preprocess(layer.w13_weight_packed, inplace=True)
ops.marlin_int4_fp8_preprocess(layer.w2_weight_packed, inplace=True)
layer.w13_weight_scale.data = layer.w13_weight_scale.data * 512
layer.w2_weight_scale.data = layer.w2_weight_scale.data * 512
# when running models with grouped act order,
# resort to g_idx values provided in checkpoint
if self.actorder == "group":
w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to(
torch.int32
)
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to(
torch.int32
)
w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
w13_g_idx_sort_indices[e]
]
w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]]
replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
else:
layer.w13_weight_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_weight_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_weight_packed,
layer.w13_g_idx_sort_indices,
layer.w13_weight_packed.shape[1] * self.packed_factor,
layer.w13_weight_packed.shape[2],
self.num_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
layer.w2_weight_packed.shape[1] * self.packed_factor,
layer.w2_weight_packed.shape[2],
self.num_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_weight_scale,
size_k=layer.w13_weight_packed.shape[2],
size_n=layer.w13_weight_scale.shape[2],
group_size=self.group_size,
is_a_8bit=is_a_8bit,
)
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_weight_scale,
size_k=layer.w2_weight_scale.shape[1]
* (self.group_size if self.group_size != -1 else self.packed_factor),
size_n=layer.w2_weight_scale.shape[2],
group_size=self.group_size,
is_a_8bit=is_a_8bit,
)
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
layer.workspace = marlin_make_workspace_new(device, 4)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if self.num_bits != 4:
return None
return int4_w4a16_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, self.group_size],
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEExpertsModular:
assert self.num_bits == 4, "only supporting w4"
layer.w13_weight = layer.w13_weight_packed
layer.w2_weight = layer.w2_weight_packed
assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
assert self.moe_quant_config is not None
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=layer.w13_weight_g_idx,
w2_g_idx=layer.w2_weight_g_idx,
w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
else:
return MarlinExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=layer.w13_weight_g_idx,
w2_g_idx=layer.w2_weight_g_idx,
w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
@property
def is_monolithic(self) -> bool:
return self.kernel_backend == "Flashinfer"
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
assert self.kernel_backend == "Flashinfer"
return flashinfer_trtllm_mxint4_moe(
x=x,
router_logits=router_logits,
w13_weight_packed=layer.w13_weight_packed,
w13_weight_scale=layer.w13_weight_scale,
w2_weight_packed=layer.w2_weight_packed,
w2_weight_scale=layer.w2_weight_scale,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
intermediate_size_per_partition=layer.intermediate_size_per_partition,
local_num_experts=layer.local_num_experts,
ep_rank=layer.ep_rank,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routing_method_type=layer.routing_method_type,
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
assert self.kernel_backend == "Marlin"
return fused_marlin_moe(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
activation=layer.activation,
expert_map=layer.expert_map,
g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
input_dtype=self.marlin_input_dtype,
is_k_full=self.is_k_full,
inplace=not self.moe.disable_inplace,
)