[Feat] Support non-gated activations in NVFP4 modelopt path (#29004)

This commit is contained in:
Omer Ullman Argov
2025-11-30 18:02:40 +02:00
committed by GitHub
parent cd719de5cb
commit 39d28108f4
5 changed files with 98 additions and 22 deletions

View File

@@ -600,14 +600,20 @@ class FusedMoE(CustomOp):
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8MoEMethod,
ModelOptNvFp4FusedMoE,
)
if not isinstance(
self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
self.quant_method,
(
UnquantizedFusedMoEMethod,
ModelOptFp8MoEMethod,
ModelOptNvFp4FusedMoE,
),
):
raise NotImplementedError(
"is_act_and_mul=False is supported only for unquantized "
"and ModelOpt FP8 moe for now"
", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
)
if not current_platform.is_cuda():
raise NotImplementedError(
@@ -1277,7 +1283,7 @@ class FusedMoE(CustomOp):
self._load_combined_w13_weight_scale(
shard_dim=shard_dim,
loaded_weight=loaded_weight,
param=param,
param=expert_data,
tp_rank=self.tp_rank,
)
return True if return_success else None

View File

@@ -1216,7 +1216,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
dtype=weight_dtype,
@@ -1245,7 +1245,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_weight_scale = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.quant_config.group_size,
dtype=weight_scale_dtype,
@@ -1275,7 +1275,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
w13_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(num_experts, 2, dtype=torch.float32),
data=torch.empty(
num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32
),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
@@ -1296,7 +1298,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_scale_num_experts = global_num_experts if use_global_sf else num_experts
w13_input_scale = PerTensorScaleParameter(
data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
data=torch.empty(
global_scale_num_experts,
2 if self.moe.is_act_and_mul else 1,
dtype=torch.float32,
),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
@@ -1312,9 +1318,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
if self.allow_flashinfer and (
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
if (
self.allow_flashinfer
and (
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
)
and self.moe.is_act_and_mul
):
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
gemm1_weight, gemm1_weight_scale, dim=-2
@@ -1324,7 +1334,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)
# Common processing for w13_weight_scale_2
if not torch.allclose(
if self.moe.is_act_and_mul and not torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
):
logger.warning_once(
@@ -1437,11 +1447,39 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_blockscale_swizzled, requires_grad=False
)
w13_weight = layer.w13_weight
intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
if intermediate_size_pad:
# padding gated activations will require to split w1 and w3
# and pad them individually
assert not self.moe.is_act_and_mul, (
"The intermediate size required padding, "
"but padding is not implemented for gated activations"
)
layer.w13_weight = Parameter(
torch.nn.functional.pad(
w13_weight, (0, 0, 0, intermediate_size_pad)
),
requires_grad=False,
)
layer.w2_weight = Parameter(
torch.nn.functional.pad(
layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
),
requires_grad=False,
)
layer.w2_weight_scale = Parameter(
torch.nn.functional.pad(
layer.w2_weight_scale, (0, intermediate_size_pad // 16)
),
requires_grad=False,
)
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
layer.w2_weight_scale = Parameter(
w2_blockscale_swizzled, requires_grad=False
)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -1484,7 +1522,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported."
if not self.moe.is_act_and_mul:
assert (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
), (
"Non-gated activations are only supported by the"
" flashinfer CUTLASS backend for modelopt checkpoints"
)
if (
self.allow_flashinfer