[Feat] Support non-gated activations in NVFP4 modelopt path (#29004)
parent cd719de5cb · commit 39d28108f4
@@ -600,14 +600,20 @@ class FusedMoE(CustomOp):
             # Avoid circular import
             from vllm.model_executor.layers.quantization.modelopt import (
                 ModelOptFp8MoEMethod,
+                ModelOptNvFp4FusedMoE,
             )

             if not isinstance(
-                self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
+                self.quant_method,
+                (
+                    UnquantizedFusedMoEMethod,
+                    ModelOptFp8MoEMethod,
+                    ModelOptNvFp4FusedMoE,
+                ),
             ):
                 raise NotImplementedError(
                     "is_act_and_mul=False is supported only for unquantized "
-                    "and ModelOpt FP8 moe for now"
+                    ", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
                 )
             if not current_platform.is_cuda():
                 raise NotImplementedError(
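For context: `is_act_and_mul=True` is the usual gated expert MLP, where the activation of the gate projection is multiplied elementwise with the up projection, while `is_act_and_mul=False` applies the activation to a single projection. A minimal sketch of the two variants, with hypothetical tensor names:

import torch
import torch.nn.functional as F

def expert_mlp(x, w13, w2, is_act_and_mul):
    # Sketch only: w13 stacks the gate and up projections when gated,
    # and holds a single projection when non-gated.
    h = x @ w13.t()
    if is_act_and_mul:
        gate, up = h.chunk(2, dim=-1)   # needs 2 * intermediate_size rows in w13
        h = F.silu(gate) * up
    else:
        h = F.silu(h)                   # only intermediate_size rows in w13
    return h @ w2.t()

x = torch.randn(4, 64)
print(expert_mlp(x, torch.randn(2 * 32, 64), torch.randn(64, 32), True).shape)
print(expert_mlp(x, torch.randn(32, 64), torch.randn(64, 32), False).shape)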
@@ -1277,7 +1283,7 @@ class FusedMoE(CustomOp):
            self._load_combined_w13_weight_scale(
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
-                param=param,
+                param=expert_data,
                tp_rank=self.tp_rank,
            )
            return True if return_success else None
@@ -1216,7 +1216,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         w13_weight = ModelWeightParameter(
             data=torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // 2,
                 dtype=weight_dtype,
@@ -1245,7 +1245,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         w13_weight_scale = ModelWeightParameter(
             data=torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // self.quant_config.group_size,
                 dtype=weight_scale_dtype,
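The `(2 if self.moe.is_act_and_mul else 1)` factor above is what halves the fused w13 allocations for non-gated checkpoints. A quick shape check with made-up sizes and placeholder dtypes (the real parameters use FP4 values packed two per byte and an FP8 block-scale dtype; a group size of 16 is assumed here):

import torch

num_experts, intermediate, hidden, group_size = 8, 1536, 4096, 16

for is_act_and_mul in (True, False):
    rows = (2 if is_act_and_mul else 1) * intermediate
    w13_weight = torch.empty(num_experts, rows, hidden // 2, dtype=torch.uint8)  # 2 fp4 values per byte
    w13_weight_scale = torch.empty(num_experts, rows, hidden // group_size, dtype=torch.uint8)
    print(is_act_and_mul, tuple(w13_weight.shape), tuple(w13_weight_scale.shape))
# True  (8, 3072, 2048) (8, 3072, 256)
# False (8, 1536, 2048) (8, 1536, 256)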
@@ -1275,7 +1275,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         )

         w13_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(
+                num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32
+            ),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
@@ -1296,7 +1298,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         global_scale_num_experts = global_num_experts if use_global_sf else num_experts

         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
+            data=torch.empty(
+                global_scale_num_experts,
+                2 if self.moe.is_act_and_mul else 1,
+                dtype=torch.float32,
+            ),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
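The same factor applies to the per-tensor scales: a gated checkpoint carries one `weight_scale_2` / `input_scale` value per fused projection (w1 and w3), a non-gated checkpoint just one per expert. Illustration only:

import torch

num_experts = 8
for is_act_and_mul in (True, False):
    cols = 2 if is_act_and_mul else 1   # one scale per fused projection
    print(torch.empty(num_experts, cols, dtype=torch.float32).shape)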
@@ -1312,9 +1318,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         gemm1_weight = layer.w13_weight.data
         gemm1_weight_scale = layer.w13_weight_scale.data

-        if self.allow_flashinfer and (
-            self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
-            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        if (
+            self.allow_flashinfer
+            and (
+                self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+                or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+            )
+            and self.moe.is_act_and_mul
         ):
             gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                 gemm1_weight, gemm1_weight_scale, dim=-2
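The new `and self.moe.is_act_and_mul` term matters because `reorder_w1w3_to_w3w1` swaps the two stacked halves of the fused tensor along the intermediate dimension; with a non-gated checkpoint there is only a single projection, so there are no halves to swap. A rough stand-in for the reorder (assumed behaviour, not the actual helper):

import torch

def reorder_halves(w, dim=-2):
    # Assumed equivalent of reorder_w1w3_to_w3w1: turn a [w1; w3] layout
    # into [w3; w1] by swapping the two halves along `dim`.
    first, second = w.chunk(2, dim=dim)
    return torch.cat([second, first], dim=dim)

w13 = torch.randn(8, 2 * 1536, 2048)  # gated layout: two stacked halves
print(reorder_halves(w13).shape)      # (8, 3072, 2048), halves swapped in place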
@@ -1324,7 +1334,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

         # Common processing for w13_weight_scale_2
-        if not torch.allclose(
+        if self.moe.is_act_and_mul and not torch.allclose(
             layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
         ):
             logger.warning_once(
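The added `self.moe.is_act_and_mul and` guard is needed because a non-gated checkpoint has only one scale column, so `[:, 1]` does not exist and the w1-vs-w3 consistency warning is meaningless there. A minimal illustration:

import torch

gated = torch.ones(8, 2)
print(torch.allclose(gated[:, 0], gated[:, 1]))  # the comparison behind the warning

non_gated = torch.ones(8, 1)
# non_gated[:, 1] would raise an IndexError, hence the guard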
@@ -1437,11 +1447,39 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 w13_blockscale_swizzled, requires_grad=False
             )

+            w13_weight = layer.w13_weight
+            intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
+            if intermediate_size_pad:
+                # padding gated activations will require to split w1 and w3
+                # and pad them individually
+                assert not self.moe.is_act_and_mul, (
+                    "The intermediate size required padding, "
+                    "but padding is not implemented for gated activations"
+                )
+
+                layer.w13_weight = Parameter(
+                    torch.nn.functional.pad(
+                        w13_weight, (0, 0, 0, intermediate_size_pad)
+                    ),
+                    requires_grad=False,
+                )
+                layer.w2_weight = Parameter(
+                    torch.nn.functional.pad(
+                        layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
+                    ),
+                    requires_grad=False,
+                )
+                layer.w2_weight_scale = Parameter(
+                    torch.nn.functional.pad(
+                        layer.w2_weight_scale, (0, intermediate_size_pad // 16)
+                    ),
+                    requires_grad=False,
+                )
+
             w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
             layer.w2_weight_scale = Parameter(
                 w2_blockscale_swizzled, requires_grad=False
             )
             layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
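A note on the padding calls (sketch with toy shapes, assuming the NVFP4 block size of 16 and two FP4 values packed per byte): `torch.nn.functional.pad` consumes its pad tuple from the last dimension backwards, so `(0, 0, 0, pad)` grows the second-to-last (intermediate) dimension of w13, while `(0, pad // 2, 0, 0)` grows the last (packed input) dimension of w2:

import torch
import torch.nn.functional as F

w13 = torch.zeros(8, 1536, 2048)  # (experts, intermediate, packed hidden)
w2 = torch.zeros(8, 4096, 768)    # (experts, hidden, packed intermediate)
pad = 128

print(F.pad(w13, (0, 0, 0, pad)).shape)      # torch.Size([8, 1664, 2048])
print(F.pad(w2, (0, pad // 2, 0, 0)).shape)  # torch.Size([8, 4096, 832])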
@@ -1484,7 +1522,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert activation == "silu", "Only SiLU activation is supported."
-
+        if not self.moe.is_act_and_mul:
+            assert (
+                self.allow_flashinfer
+                and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            ), (
+                "Non-gated activations are only supported by the"
+                " flashinfer CUTLASS backend for modelopt checkpoints"
+            )
        if (
            self.allow_flashinfer