[Refactor] Move FusedMoE hidden_size roundup to quant_method (#34285)
Signed-off-by: Bowen Bao <bowenbao@amd.com>
This commit is contained in:
@@ -229,9 +229,6 @@ class FusedMoEQuantConfig:
|
||||
_w1: FusedMoEQuantDesc
|
||||
_w2: FusedMoEQuantDesc
|
||||
is_nvfp4_scale_swizzled: bool = True
|
||||
# CK MXFP4 (gfx950) padding info for rocm_aiter_ops.fused_moe()
|
||||
hidden_pad: int = 0
|
||||
intermediate_pad: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
assert not self.per_act_token_quant or self.block_shape is None, (
|
||||
@@ -1172,6 +1169,11 @@ class FusedMoEConfig:
|
||||
# Defaults to in_dtype if not specified.
|
||||
router_logits_dtype: torch.dtype | None = None
|
||||
|
||||
# Defaults to hidden_dim if not specified.
|
||||
hidden_dim_unpadded: int | None = None
|
||||
# Defaults to intermediate_size_per_partition if not specified.
|
||||
intermediate_size_per_partition_unpadded: int | None = None
|
||||
|
||||
moe_backend: str = "auto"
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||
has_bias: bool = False
|
||||
@@ -1195,6 +1197,13 @@ class FusedMoEConfig:
|
||||
if self.router_logits_dtype is None:
|
||||
self.router_logits_dtype = self.in_dtype
|
||||
|
||||
if self.hidden_dim_unpadded is None:
|
||||
self.hidden_dim_unpadded = self.hidden_dim
|
||||
if self.intermediate_size_per_partition_unpadded is None:
|
||||
self.intermediate_size_per_partition_unpadded = (
|
||||
self.intermediate_size_per_partition
|
||||
)
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
|
||||
@@ -9,6 +9,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
@@ -65,6 +66,38 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
"""
|
||||
return False
|
||||
|
||||
def maybe_roundup_sizes(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
act_dtype: torch.dtype,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Given layer hidden size and intermediate size per partition and MoE
|
||||
configurations, round up hidden_size and intermediate_size_per_partition
|
||||
if necessary.
|
||||
|
||||
Args:
|
||||
hidden_size: Layer hidden-size
|
||||
intermediate_size_per_partition: Intermediate size per partition for
|
||||
the layer.
|
||||
act_dtype: Data type of the layer activations.
|
||||
moe_parallel_config: Fused MoE parallelization strategy configuration.
|
||||
|
||||
Return:
|
||||
A tuple of (rounded_hidden_size, rounded_intermediate_size_per_partition),
|
||||
where:
|
||||
- rounded_hidden_size is the possibly rounded up hidden size.
|
||||
- rounded_intermediate_size_per_partition is the possibly rounded
|
||||
up intermediate size per partition.
|
||||
"""
|
||||
from .all2all_utils import maybe_roundup_layer_hidden_size
|
||||
|
||||
return maybe_roundup_layer_hidden_size(
|
||||
hidden_size, act_dtype, moe_parallel_config
|
||||
), intermediate_size_per_partition
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
|
||||
@@ -428,13 +428,9 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
|
||||
assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
|
||||
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
|
||||
|
||||
# Shape check: when weights are padded (e.g. hidden_size padded for
|
||||
# GFX950 swizzle), unpadded_K_w1 carries the original dimension.
|
||||
expected_K_w1 = unpadded_K_w1 if unpadded_K_w1 is not None else w1.shape[-2]
|
||||
assert hidden_states.shape[-1] == expected_K_w1, (
|
||||
f"hidden_states K={hidden_states.shape[-1]} != "
|
||||
f"expected K={expected_K_w1} (w1 K={w1.shape[-2]})"
|
||||
)
|
||||
# Shape check: weights are padded (e.g. hidden_size padded for
|
||||
# GFX950 swizzle).
|
||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||
assert w2.shape[-1] == w1.shape[1]
|
||||
|
||||
E, _, N = w1.shape
|
||||
@@ -494,12 +490,6 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
|
||||
unpadded_K=unpadded_K_w2,
|
||||
)
|
||||
|
||||
# When hidden_size was padded for alignment (e.g. GFX950 swizzle),
|
||||
# the kernel output has the padded dimension. Slice back to the
|
||||
# original hidden_size so downstream layers see the expected shape.
|
||||
if unpadded_N_w2 is not None and intermediate_cache3.shape[-1] != unpadded_N_w2:
|
||||
intermediate_cache3 = intermediate_cache3[..., :unpadded_N_w2].contiguous()
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
|
||||
@@ -210,42 +210,6 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
|
||||
)
|
||||
|
||||
|
||||
# TODO(rob): move this down to the kernel.
|
||||
def maybe_roundup_hidden_size(
|
||||
hidden_size: int,
|
||||
act_dtype: torch.dtype,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
is_lora_enabled: bool,
|
||||
model_type: str | None,
|
||||
) -> int:
|
||||
"""
|
||||
Given layer hidden size and MoE configurations, round up hidden_size
|
||||
if necessary.
|
||||
|
||||
Args:
|
||||
hidden_size: Layer hidden-size
|
||||
act_dtype: Data type of the layer activations.
|
||||
moe_parallel_config: Fused MoE parallelization strategy configuration.
|
||||
is_lora_enabled: True if the engine is enabled with LoRA. This
|
||||
is used in the case of mxfp4 quantization in selecting the
|
||||
MxFP4Backend.
|
||||
model_type: for checking if gpt-oss
|
||||
|
||||
Return:
|
||||
Rounded up hidden_size if rounding up is required based on the configs.
|
||||
Original hidden size otherwise.
|
||||
"""
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_roundup_layer_hidden_size,
|
||||
)
|
||||
|
||||
hidden_size = maybe_roundup_layer_hidden_size(
|
||||
hidden_size, act_dtype, moe_parallel_config
|
||||
)
|
||||
|
||||
return hidden_size
|
||||
|
||||
|
||||
# --8<-- [start:fused_moe]
|
||||
@CustomOp.register("fused_moe")
|
||||
class FusedMoE(CustomOp):
|
||||
@@ -459,7 +423,7 @@ class FusedMoE(CustomOp):
|
||||
), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
|
||||
|
||||
assert intermediate_size % self.tp_size == 0
|
||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
self.reduce_results = reduce_results
|
||||
self.renormalize = renormalize
|
||||
|
||||
@@ -501,28 +465,13 @@ class FusedMoE(CustomOp):
|
||||
)
|
||||
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
|
||||
|
||||
# Round up hidden size before creating moe_config.
|
||||
# This way moe_config is created with the correct hidden_size from the start.
|
||||
unpadded_hidden_size = hidden_size
|
||||
self.model_type = (
|
||||
self.vllm_config.model_config.hf_config.model_type
|
||||
if self.vllm_config.model_config is not None
|
||||
else None
|
||||
)
|
||||
hidden_size = maybe_roundup_hidden_size(
|
||||
hidden_size=hidden_size,
|
||||
act_dtype=moe_in_dtype,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
is_lora_enabled=vllm_config.lora_config is not None,
|
||||
model_type=self.model_type,
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.moe_config: FusedMoEConfig = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
intermediate_size_per_partition=self.intermediate_size_per_partition,
|
||||
hidden_dim_unpadded=hidden_size,
|
||||
intermediate_size_per_partition=intermediate_size_per_partition,
|
||||
intermediate_size_per_partition_unpadded=intermediate_size_per_partition,
|
||||
num_local_experts=self.local_num_experts,
|
||||
num_logical_experts=self.logical_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
@@ -567,13 +516,6 @@ class FusedMoE(CustomOp):
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
self.quant_method: FusedMoEMethodBase = _get_quant_method()
|
||||
|
||||
# Quant methods (e.g. Mxfp4MoEMethod) may round up hidden_dim
|
||||
# and intermediate_size in moe_config during __init__. Sync
|
||||
# self.hidden_size so downstream consumers (e.g. LoRA) see the
|
||||
# padded value.
|
||||
if self.moe_config.hidden_dim != self.hidden_size:
|
||||
self.hidden_size = self.moe_config.hidden_dim
|
||||
|
||||
if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike():
|
||||
raise NotImplementedError(
|
||||
"is_act_and_mul=False is supported only for CUDA and ROCm for now"
|
||||
@@ -591,11 +533,24 @@ class FusedMoE(CustomOp):
|
||||
f"EPLB is not supported {self.quant_method.__class__.__name__}."
|
||||
)
|
||||
|
||||
# Round up hidden size and update moe_config.
|
||||
hidden_size, intermediate_size_per_partition = (
|
||||
self.quant_method.maybe_roundup_sizes(
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
moe_in_dtype,
|
||||
self.moe_parallel_config,
|
||||
)
|
||||
)
|
||||
self.moe_config.hidden_dim = hidden_size
|
||||
self.moe_config.intermediate_size_per_partition = (
|
||||
intermediate_size_per_partition
|
||||
)
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": self.hidden_size,
|
||||
"unpadded_hidden_size": unpadded_hidden_size,
|
||||
"intermediate_size_per_partition": self.intermediate_size_per_partition,
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size_per_partition": intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
"global_num_experts": self.global_num_experts,
|
||||
@@ -933,9 +888,17 @@ class FusedMoE(CustomOp):
|
||||
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
|
||||
# and we're not loading the full weight
|
||||
if not load_full and loaded_weight.ndim > 0:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
)
|
||||
# Handle padding: loaded_weight might be smaller than shard_size on last
|
||||
# TP rank
|
||||
start_offset = shard_size * tp_rank
|
||||
available = loaded_weight.shape[shard_dim] - start_offset
|
||||
if available <= 0:
|
||||
# If there is no available weight to load for this TP rank
|
||||
# (can happen on last TP rank with padding), we can skip
|
||||
# loading and return early
|
||||
return
|
||||
narrow_size = min(shard_size, available)
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, start_offset, narrow_size)
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
if shard_id == "w1":
|
||||
@@ -944,6 +907,13 @@ class FusedMoE(CustomOp):
|
||||
else:
|
||||
assert shard_id == "w3"
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
|
||||
# Handle padding: if loaded_weight is smaller than expert_data (can happen
|
||||
# on last TP shard with padding), copy to top-left corner
|
||||
if expert_data.shape != loaded_weight.shape:
|
||||
expert_data = expert_data[
|
||||
: loaded_weight.shape[0], : loaded_weight.shape[1]
|
||||
]
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_w2(
|
||||
@@ -961,10 +931,24 @@ class FusedMoE(CustomOp):
|
||||
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
|
||||
# and we're not loading the full weight
|
||||
if not load_full and loaded_weight.ndim > 0:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
)
|
||||
# Handle padding: loaded_weight might be smaller than shard_size on last
|
||||
# TP rank
|
||||
start_offset = shard_size * tp_rank
|
||||
available = loaded_weight.shape[shard_dim] - start_offset
|
||||
if available <= 0:
|
||||
# If there is no available weight to load for this TP rank
|
||||
# (can happen on last TP rank with padding), we can skip
|
||||
# loading and return early
|
||||
return
|
||||
narrow_size = min(shard_size, available)
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, start_offset, narrow_size)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
# Handle padding: if loaded_weight is smaller than expert_data (can happen
|
||||
# on last TP shard with padding), copy to top-left corner
|
||||
if expert_data.shape != loaded_weight.shape:
|
||||
expert_data = expert_data[
|
||||
: loaded_weight.shape[0], : loaded_weight.shape[1]
|
||||
]
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_single_value(
|
||||
@@ -1549,6 +1533,14 @@ class FusedMoE(CustomOp):
|
||||
]
|
||||
]
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.moe_config.hidden_dim
|
||||
|
||||
@property
|
||||
def intermediate_size_per_partition(self) -> int:
|
||||
return self.moe_config.intermediate_size_per_partition
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = (
|
||||
f"global_num_experts={self.global_num_experts}, "
|
||||
|
||||
@@ -779,8 +779,6 @@ def make_mxfp4_moe_quant_config(
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
hidden_pad: int = 0,
|
||||
intermediate_pad: int = 0,
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
|
||||
if mxfp4_backend in (
|
||||
@@ -802,16 +800,12 @@ def make_mxfp4_moe_quant_config(
|
||||
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.CK,
|
||||
):
|
||||
config = mxfp4_w4a16_moe_quant_config(
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
if mxfp4_backend == Mxfp4MoeBackend.CK:
|
||||
config.hidden_pad = hidden_pad
|
||||
config.intermediate_pad = intermediate_pad
|
||||
return config
|
||||
else:
|
||||
return ocp_mx_moe_quant_config(
|
||||
quant_dtype="mxfp4",
|
||||
|
||||
@@ -10,6 +10,7 @@ from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
@@ -186,6 +187,7 @@ def rocm_aiter_fused_experts(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
moe_config: FusedMoEConfig,
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
@@ -276,6 +278,17 @@ def rocm_aiter_fused_experts(
|
||||
"Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
)
|
||||
|
||||
# Compute padding on-the-fly for CK MXFP4 kernels
|
||||
hidden_pad = 0
|
||||
intermediate_pad = 0
|
||||
assert moe_config.hidden_dim_unpadded is not None
|
||||
assert moe_config.intermediate_size_per_partition_unpadded is not None
|
||||
hidden_pad = hidden_states.shape[1] - moe_config.hidden_dim_unpadded
|
||||
intermediate_pad = (
|
||||
moe_config.intermediate_size_per_partition
|
||||
- moe_config.intermediate_size_per_partition_unpadded
|
||||
)
|
||||
|
||||
return rocm_aiter_ops.fused_moe(
|
||||
hidden_states,
|
||||
w1,
|
||||
@@ -292,8 +305,8 @@ def rocm_aiter_fused_experts(
|
||||
doweight_stage1=apply_router_weight_on_input,
|
||||
num_local_tokens=num_local_tokens,
|
||||
output_dtype=output_dtype,
|
||||
hidden_pad=quant_config.hidden_pad,
|
||||
intermediate_pad=quant_config.intermediate_pad,
|
||||
hidden_pad=hidden_pad,
|
||||
intermediate_pad=intermediate_pad,
|
||||
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
|
||||
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
|
||||
)
|
||||
@@ -419,6 +432,7 @@ class AiterExperts(mk.FusedMoEExpertsModular):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.quant_config,
|
||||
moe_config=self.moe_config,
|
||||
a1q_scale=a1q_scale,
|
||||
num_local_tokens=num_local_tokens,
|
||||
output_dtype=output.dtype,
|
||||
|
||||
@@ -715,8 +715,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.hidden_size = hidden_size
|
||||
layer.num_experts = num_experts
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
@@ -2274,8 +2272,6 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.hidden_size = hidden_size
|
||||
layer.num_experts = num_experts
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
@@ -672,8 +672,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.hidden_size = hidden_size
|
||||
layer.num_experts = num_experts
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
@@ -1011,8 +1009,6 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.hidden_size = hidden_size
|
||||
layer.num_experts = num_experts
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
|
||||
@@ -107,18 +108,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
||||
self.moe_kernel: mk.FusedMoEKernel | None = None
|
||||
|
||||
# Round up dims once based on backend. This mutates the shared
|
||||
# FusedMoEConfig in-place so that create_weights() and all
|
||||
# downstream code see the padded dimensions. This must happen
|
||||
# before create_weights() is called.
|
||||
self.moe.hidden_dim, self.moe.intermediate_size_per_partition = (
|
||||
mxfp4_round_up_hidden_size_and_intermediate_size(
|
||||
self.mxfp4_backend,
|
||||
self.moe.hidden_dim,
|
||||
self.moe.intermediate_size_per_partition,
|
||||
)
|
||||
)
|
||||
|
||||
# Used for triton kernel precision configs
|
||||
self.w13_precision_config = None
|
||||
self.w2_precision_config = None
|
||||
@@ -129,6 +118,23 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
# so can skip the padding in the forward before applying the moe method
|
||||
return self.mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8
|
||||
|
||||
def maybe_roundup_sizes(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
act_dtype: torch.dtype,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
) -> tuple[int, int]:
|
||||
hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size_per_partition=intermediate_size_per_partition,
|
||||
act_dtype=act_dtype,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
)
|
||||
return mxfp4_round_up_hidden_size_and_intermediate_size(
|
||||
self.mxfp4_backend, hidden_size, intermediate_size_per_partition
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -143,32 +149,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
scale_dtype = torch.uint8
|
||||
mxfp4_block = 32
|
||||
|
||||
# Use pre-rounded sizes from config
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad = (
|
||||
self.moe.intermediate_size_per_partition
|
||||
)
|
||||
self.hidden_size = hidden_size = self.moe.hidden_dim
|
||||
|
||||
# Expose padded dimensions on the layer for LoRA and Marlin code
|
||||
# that reads layer.hidden_size / layer.intermediate_size_per_partition.
|
||||
layer.params_dtype = params_dtype
|
||||
layer.num_experts = num_experts
|
||||
layer.hidden_size = hidden_size
|
||||
layer.intermediate_size_per_partition = (
|
||||
intermediate_size_per_partition_after_pad
|
||||
)
|
||||
|
||||
# CK (gfx950) padding info for rocm_aiter_ops.fused_moe()
|
||||
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
|
||||
self.intermediate_pad = (
|
||||
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
|
||||
)
|
||||
self.intermediate_size = intermediate_size_per_partition
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
@@ -180,7 +170,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
),
|
||||
@@ -194,7 +184,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // 2,
|
||||
intermediate_size_per_partition // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -206,7 +196,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // mxfp4_block,
|
||||
intermediate_size_per_partition // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -218,7 +208,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
w13_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -368,8 +358,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
hidden_pad=self.hidden_pad,
|
||||
intermediate_pad=self.intermediate_pad,
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
|
||||
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
MoEActivation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
mxfp4_w4a8_moe_quant_config,
|
||||
@@ -27,13 +28,13 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
|
||||
Mxfp4MoeBackend,
|
||||
mxfp4_round_up_hidden_size_and_intermediate_size,
|
||||
select_mxfp4_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_fp8_moe_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
CK_MXFP4_MOE_DIM_ALIGNMENT,
|
||||
_swizzle_mxfp4,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
@@ -49,7 +50,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -173,8 +173,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.hidden_size = hidden_size
|
||||
layer.num_experts = num_experts
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
@@ -182,7 +180,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
@@ -194,7 +192,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
@@ -461,6 +459,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
quant_config=self.moe_quant_config,
|
||||
moe_config=layer.moe_config,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
elif self.use_marlin:
|
||||
@@ -527,7 +526,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
):
|
||||
params_dtype = torch.uint32
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // 8, # INT32 packing for W4
|
||||
@@ -536,7 +535,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // 8, # INT32 packing for W4
|
||||
@@ -649,6 +648,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
quant_config=self.moe_quant_config,
|
||||
moe_config=layer.moe_config,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
|
||||
@@ -702,6 +702,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
self.mxfp4_backend: Mxfp4MoeBackend | None = None
|
||||
if self.ocp_mx_scheme == "w_mxfp4":
|
||||
self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe)
|
||||
elif self.ocp_mx_scheme.startswith("w_mxfp4"):
|
||||
# TODO(bowenbao): refactor and introduce backends for other OCP MX schemes.
|
||||
self.mxfp4_backend = Mxfp4MoeBackend.NONE
|
||||
|
||||
if self.input_quant is not None:
|
||||
self.static_input_scales = not self.input_quant.get("is_dynamic")
|
||||
@@ -734,36 +737,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
self.emulate = (
|
||||
not current_platform.supports_mx()
|
||||
or not self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
|
||||
|
||||
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
|
||||
# alignment requirements. When violated (e.g. MiniMax-M2.1 with
|
||||
# TP=4 yields intermediate_size_per_partition=384), AITER raises:
|
||||
# "device_gemm ... does not support this GEMM problem".
|
||||
# Fall back to emulation in that case.
|
||||
# For gpt_oss models, create_weights rounds up the dimensions
|
||||
# internally, so the alignment check is skipped.
|
||||
if (
|
||||
not self.emulate
|
||||
and self.use_rocm_aiter_moe
|
||||
and self.ocp_mx_scheme is not None
|
||||
and self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||
and self.model_type != "gpt_oss"
|
||||
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
|
||||
):
|
||||
logger.warning_once(
|
||||
"AITER CK MXFP4 MoE GEMM does not support "
|
||||
"intermediate_size_per_partition=%d (not a multiple of %d). "
|
||||
"This typically happens when intermediate_size / "
|
||||
"tensor_parallel_size produces an incompatible dimension. "
|
||||
"Falling back to emulation mode. To avoid this overhead, "
|
||||
"use a compatible tensor_parallel_size or set "
|
||||
"VLLM_ROCM_USE_AITER_MOE=0.",
|
||||
moe.intermediate_size_per_partition,
|
||||
CK_MXFP4_MOE_DIM_ALIGNMENT,
|
||||
)
|
||||
self.use_rocm_aiter_moe = False
|
||||
self.emulate = True
|
||||
) and (
|
||||
self.mxfp4_backend is None
|
||||
or self.mxfp4_backend is Mxfp4MoeBackend.NONE
|
||||
or not self.use_rocm_aiter_moe
|
||||
)
|
||||
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
@@ -780,6 +758,27 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
"The current mode supports native MoE MXFP4 computation"
|
||||
)
|
||||
|
||||
def maybe_roundup_sizes(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
act_dtype: torch.dtype,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
) -> tuple[int, int]:
|
||||
hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size_per_partition=intermediate_size_per_partition,
|
||||
act_dtype=act_dtype,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
)
|
||||
if self.mxfp4_backend is not None:
|
||||
hidden_size, intermediate_size_per_partition = (
|
||||
mxfp4_round_up_hidden_size_and_intermediate_size(
|
||||
self.mxfp4_backend, hidden_size, intermediate_size_per_partition
|
||||
)
|
||||
)
|
||||
return hidden_size, intermediate_size_per_partition
|
||||
|
||||
def get_packed_dim(self, dim: int, quant_dtype: str):
|
||||
if quant_dtype == "mxfp4":
|
||||
assert dim % 2 == 0
|
||||
@@ -805,40 +804,12 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
)
|
||||
|
||||
params_dtype = torch.uint8
|
||||
self.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
if self.model_type == "gpt_oss":
|
||||
if current_platform.is_rocm():
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 256
|
||||
)
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 64
|
||||
)
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
|
||||
|
||||
self.unpadded_hidden_size = extra_weight_attrs.get(
|
||||
"unpadded_hidden_size", hidden_size
|
||||
)
|
||||
|
||||
# On GFX950, the GFX950MXScaleLayout swizzle requires
|
||||
# hidden_size to be a multiple of 256 (SCALE_K = hidden_size / 32
|
||||
# must be divisible by 8). Pad hidden_size for weight/scale
|
||||
# allocation; the original value is preserved in unpadded_hidden_size.
|
||||
# Only applies to the native (non-emulated) CK path on GFX950.
|
||||
if (
|
||||
self.model_type == "gpt_oss"
|
||||
and current_platform.is_rocm()
|
||||
and not self.emulate
|
||||
):
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
2 * intermediate_size_per_partition,
|
||||
self.get_packed_dim(hidden_size, self.weight_dtype),
|
||||
dtype=params_dtype,
|
||||
),
|
||||
@@ -849,12 +820,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
self.get_packed_dim(
|
||||
intermediate_size_per_partition_after_pad, self.weight_dtype
|
||||
),
|
||||
self.get_packed_dim(intermediate_size_per_partition, self.weight_dtype),
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -867,7 +836,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // OCP_MX_BLOCK_SIZE,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
@@ -877,7 +846,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
torch.ones(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // OCP_MX_BLOCK_SIZE,
|
||||
intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -892,7 +861,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
w13_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -1072,6 +1041,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
moe_config=layer.moe_config,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
else:
|
||||
@@ -1204,6 +1174,8 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
|
||||
assert self.moe.hidden_dim_unpadded is not None
|
||||
assert self.moe.intermediate_size_per_partition_unpadded is not None
|
||||
return triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=self.w13_weight_triton_tensor,
|
||||
@@ -1215,8 +1187,8 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
unpadded_N_w1=self.intermediate_size_per_partition * 2,
|
||||
unpadded_K_w1=self.unpadded_hidden_size,
|
||||
unpadded_N_w2=self.unpadded_hidden_size,
|
||||
unpadded_K_w2=self.intermediate_size_per_partition,
|
||||
unpadded_N_w1=self.moe.intermediate_size_per_partition_unpadded * 2,
|
||||
unpadded_K_w1=self.moe.hidden_dim_unpadded,
|
||||
unpadded_N_w2=self.moe.hidden_dim_unpadded,
|
||||
unpadded_K_w2=self.moe.intermediate_size_per_partition_unpadded,
|
||||
)
|
||||
|
||||
@@ -254,7 +254,6 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
|
||||
w13, w13_scale, w2, w2_scale, is_act_and_mul, min_alignment
|
||||
)
|
||||
)
|
||||
layer.intermediate_size_per_partition = padded_intermediate
|
||||
layer.moe_config.intermediate_size_per_partition = padded_intermediate
|
||||
|
||||
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
|
||||
|
||||
@@ -439,7 +439,6 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
layer.moe_config.is_act_and_mul,
|
||||
min_alignment,
|
||||
)
|
||||
layer.intermediate_size_per_partition = new_intermediate
|
||||
layer.moe_config.intermediate_size_per_partition = new_intermediate
|
||||
|
||||
# FI kernels require W31 layout rather than W13.
|
||||
|
||||
Reference in New Issue
Block a user