[Quant] Support MXFP4 W4A16 for compressed-tensors MoE models (#32285)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -10,6 +10,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
mxfp4_w4a16_moe_quant_config,
|
||||
nvfp4_moe_quant_config,
|
||||
nvfp4_w4a16_moe_quant_config,
|
||||
)
|
||||
@@ -193,6 +194,16 @@ def convert_to_nvfp4_moe_kernel_format(
|
||||
)
|
||||
|
||||
|
||||
def make_mxfp4_moe_quant_config(
|
||||
w13_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
) -> FusedMoEQuantConfig:
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_scale=w13_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
|
||||
|
||||
def make_nvfp4_moe_quant_config(
|
||||
backend: NvFp4MoeBackend,
|
||||
w13_scale: torch.Tensor,
|
||||
|
||||
@@ -52,6 +52,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||
NvFp4MoeBackend,
|
||||
convert_to_nvfp4_moe_kernel_format,
|
||||
is_global_sf_supported_for_nvfp4_backend,
|
||||
make_mxfp4_moe_quant_config,
|
||||
make_nvfp4_moe_kernel,
|
||||
make_nvfp4_moe_quant_config,
|
||||
select_nvfp4_moe_backend,
|
||||
@@ -79,6 +80,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
is_fp4_marlin_supported,
|
||||
prepare_moe_fp4_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
convert_bf16_scales_to_fp8,
|
||||
@@ -145,6 +147,9 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
input_quant = scheme_dict.get("input_activations")
|
||||
format = scheme_dict.get("format")
|
||||
|
||||
if quant_config._is_mxfp4(weight_quant):
|
||||
return CompressedTensorsW4A4Mxfp4MoEMethod(layer.moe_config)
|
||||
|
||||
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
# group_size=None means channelwise
|
||||
group_size = weight_quant.group_size or -1
|
||||
@@ -224,6 +229,140 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
|
||||
class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(self, moe):
|
||||
super().__init__(moe)
|
||||
self.group_size = 32
|
||||
self.mxfp4_backend = NvFp4MoeBackend.MARLIN
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.num_experts = num_experts
|
||||
layer.params_dtype = params_dtype
|
||||
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // 2,
|
||||
requires_grad=False,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_packed", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
intermediate_size_per_partition // 2,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_packed", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // self.group_size,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
|
||||
)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return make_mxfp4_moe_quant_config(
|
||||
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
layer.w13_weight_packed.data, requires_grad=False
|
||||
)
|
||||
delattr(layer, "w13_weight_packed")
|
||||
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
layer.w2_weight_packed.data, requires_grad=False
|
||||
)
|
||||
delattr(layer, "w2_weight_packed")
|
||||
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config is not None:
|
||||
self.kernel = make_nvfp4_moe_kernel(
|
||||
backend=self.mxfp4_backend,
|
||||
quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
router: FusedMoERouter,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if isinstance(x, tuple):
|
||||
x_routing, _ = x
|
||||
else:
|
||||
x_routing = x
|
||||
|
||||
topk_weights, topk_ids = router.select_experts(
|
||||
hidden_states=x_routing,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
assert self.kernel is not None
|
||||
return self.kernel(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user