[Refactor] Move FusedMoE hidden_size roundup to quant_method (#34285)

Signed-off-by: Bowen Bao <bowenbao@amd.com>
Bowen Bao
2026-03-26 23:38:26 -07:00
committed by GitHub
parent c2b17d71af
commit 0ae89f18fd
12 changed files with 204 additions and 222 deletions

View File

@@ -229,9 +229,6 @@ class FusedMoEQuantConfig:
_w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc
is_nvfp4_scale_swizzled: bool = True
# CK MXFP4 (gfx950) padding info for rocm_aiter_ops.fused_moe()
hidden_pad: int = 0
intermediate_pad: int = 0
def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (
@@ -1172,6 +1169,11 @@ class FusedMoEConfig:
# Defaults to in_dtype if not specified.
router_logits_dtype: torch.dtype | None = None
# Defaults to hidden_dim if not specified.
hidden_dim_unpadded: int | None = None
# Defaults to intermediate_size_per_partition if not specified.
intermediate_size_per_partition_unpadded: int | None = None
moe_backend: str = "auto"
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False
@@ -1195,6 +1197,13 @@ class FusedMoEConfig:
if self.router_logits_dtype is None:
self.router_logits_dtype = self.in_dtype
if self.hidden_dim_unpadded is None:
self.hidden_dim_unpadded = self.hidden_dim
if self.intermediate_size_per_partition_unpadded is None:
self.intermediate_size_per_partition_unpadded = (
self.intermediate_size_per_partition
)
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
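For reference, a minimal sketch of the new defaulting behavior (the dataclass below is illustrative, not the real FusedMoEConfig): when nothing rounds the dimensions up, the *_unpadded fields simply equal the padded ones; when a quant method does round up, the original dimensions stay recoverable. The concrete sizes are hypothetical.

from dataclasses import dataclass

@dataclass
class _MoEConfigSketch:
    # Illustrative subset of FusedMoEConfig, mirroring the diff above.
    hidden_dim: int
    intermediate_size_per_partition: int
    hidden_dim_unpadded: int | None = None
    intermediate_size_per_partition_unpadded: int | None = None

    def __post_init__(self):
        # Same defaulting logic as in FusedMoEConfig.__post_init__ above.
        if self.hidden_dim_unpadded is None:
            self.hidden_dim_unpadded = self.hidden_dim
        if self.intermediate_size_per_partition_unpadded is None:
            self.intermediate_size_per_partition_unpadded = (
                self.intermediate_size_per_partition
            )

# Hypothetical sizes: hidden 2880 padded to 3072, intermediate 384 to 512.
cfg = _MoEConfigSketch(
    hidden_dim=3072,
    intermediate_size_per_partition=512,
    hidden_dim_unpadded=2880,
    intermediate_size_per_partition_unpadded=384,
)
assert cfg.hidden_dim - cfg.hidden_dim_unpadded == 192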

View File

@@ -9,6 +9,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
@@ -65,6 +66,38 @@ class FusedMoEMethodBase(QuantizeMethodBase):
"""
return False
def maybe_roundup_sizes(
self,
hidden_size: int,
intermediate_size_per_partition: int,
act_dtype: torch.dtype,
moe_parallel_config: FusedMoEParallelConfig,
) -> tuple[int, int]:
"""
Given the layer hidden size, the intermediate size per partition, and the
MoE configuration, round up hidden_size and intermediate_size_per_partition
if necessary.
Args:
hidden_size: Layer hidden-size
intermediate_size_per_partition: Intermediate size per partition for
the layer.
act_dtype: Data type of the layer activations.
moe_parallel_config: Fused MoE parallelization strategy configuration.
Return:
A tuple of (rounded_hidden_size, rounded_intermediate_size_per_partition),
where:
- rounded_hidden_size is the possibly rounded up hidden size.
- rounded_intermediate_size_per_partition is the possibly rounded
up intermediate size per partition.
"""
from .all2all_utils import maybe_roundup_layer_hidden_size
return maybe_roundup_layer_hidden_size(
hidden_size, act_dtype, moe_parallel_config
), intermediate_size_per_partition
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
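To show how the new extension point composes, here is a hedged sketch of a quant method overriding maybe_roundup_sizes: it first applies the generic all2all rounding via super(), then adds a backend-specific alignment. The class itself is hypothetical; the 256-element alignment mirrors the ROCm MXFP4 paths later in this diff, and round_up comes from vllm.utils.math_utils as used elsewhere in this commit.

import torch

from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.utils.math_utils import round_up

class _AlignedMoEMethodSketch(FusedMoEMethodBase):  # base class defined above
    # Hypothetical backend alignment; 256 mirrors the ROCm MXFP4 paths.
    ALIGN = 256

    def maybe_roundup_sizes(
        self,
        hidden_size: int,
        intermediate_size_per_partition: int,
        act_dtype: torch.dtype,
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> tuple[int, int]:
        # Generic all2all rounding first, as in the base class.
        hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes(
            hidden_size=hidden_size,
            intermediate_size_per_partition=intermediate_size_per_partition,
            act_dtype=act_dtype,
            moe_parallel_config=moe_parallel_config,
        )
        # Then backend-specific alignment on both dimensions.
        return (
            round_up(hidden_size, self.ALIGN),
            round_up(intermediate_size_per_partition, self.ALIGN),
        )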

View File

@@ -428,13 +428,9 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
# Shape check: when weights are padded (e.g. hidden_size padded for
# GFX950 swizzle), unpadded_K_w1 carries the original dimension.
expected_K_w1 = unpadded_K_w1 if unpadded_K_w1 is not None else w1.shape[-2]
assert hidden_states.shape[-1] == expected_K_w1, (
f"hidden_states K={hidden_states.shape[-1]} != "
f"expected K={expected_K_w1} (w1 K={w1.shape[-2]})"
)
# Shape check: hidden_states must already match the (possibly padded)
# K dimension of w1 (e.g. hidden_size padded for GFX950 swizzle).
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]
E, _, N = w1.shape
@@ -494,12 +490,6 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
unpadded_K=unpadded_K_w2,
)
# When hidden_size was padded for alignment (e.g. GFX950 swizzle),
# the kernel output has the padded dimension. Slice back to the
# original hidden_size so downstream layers see the expected shape.
if unpadded_N_w2 is not None and intermediate_cache3.shape[-1] != unpadded_N_w2:
intermediate_cache3 = intermediate_cache3[..., :unpadded_N_w2].contiguous()
return intermediate_cache3

View File

@@ -210,42 +210,6 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
)
# TODO(rob): move this down to the kernel.
def maybe_roundup_hidden_size(
hidden_size: int,
act_dtype: torch.dtype,
moe_parallel_config: FusedMoEParallelConfig,
is_lora_enabled: bool,
model_type: str | None,
) -> int:
"""
Given layer hidden size and MoE configurations, round up hidden_size
if necessary.
Args:
hidden_size: Layer hidden-size
act_dtype: Data type of the layer activations.
moe_parallel_config: Fused MoE parallelization strategy configuration.
is_lora_enabled: True if the engine is enabled with LoRA. This
is used in the case of mxfp4 quantization in selecting the
MxFP4Backend.
model_type: for checking if gpt-oss
Return:
Rounded up hidden_size if rounding up is required based on the configs.
Original hidden size otherwise.
"""
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_roundup_layer_hidden_size,
)
hidden_size = maybe_roundup_layer_hidden_size(
hidden_size, act_dtype, moe_parallel_config
)
return hidden_size
# --8<-- [start:fused_moe]
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
@@ -459,7 +423,7 @@ class FusedMoE(CustomOp):
), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
@@ -501,28 +465,13 @@ class FusedMoE(CustomOp):
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
# Round up hidden size before creating moe_config.
# This way moe_config is created with the correct hidden_size from the start.
unpadded_hidden_size = hidden_size
self.model_type = (
self.vllm_config.model_config.hf_config.model_type
if self.vllm_config.model_config is not None
else None
)
hidden_size = maybe_roundup_hidden_size(
hidden_size=hidden_size,
act_dtype=moe_in_dtype,
moe_parallel_config=self.moe_parallel_config,
is_lora_enabled=vllm_config.lora_config is not None,
model_type=self.model_type,
)
self.hidden_size = hidden_size
self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
intermediate_size_per_partition=self.intermediate_size_per_partition,
hidden_dim_unpadded=hidden_size,
intermediate_size_per_partition=intermediate_size_per_partition,
intermediate_size_per_partition_unpadded=intermediate_size_per_partition,
num_local_experts=self.local_num_experts,
num_logical_experts=self.logical_num_experts,
moe_parallel_config=self.moe_parallel_config,
@@ -567,13 +516,6 @@ class FusedMoE(CustomOp):
# for heuristic purposes, so it must be initialized first.
self.quant_method: FusedMoEMethodBase = _get_quant_method()
# Quant methods (e.g. Mxfp4MoEMethod) may round up hidden_dim
# and intermediate_size in moe_config during __init__. Sync
# self.hidden_size so downstream consumers (e.g. LoRA) see the
# padded value.
if self.moe_config.hidden_dim != self.hidden_size:
self.hidden_size = self.moe_config.hidden_dim
if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike():
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA and ROCm for now"
@@ -591,11 +533,24 @@ class FusedMoE(CustomOp):
f"EPLB is not supported {self.quant_method.__class__.__name__}."
)
# Round up hidden size and intermediate size, then update moe_config.
hidden_size, intermediate_size_per_partition = (
self.quant_method.maybe_roundup_sizes(
hidden_size,
intermediate_size_per_partition,
moe_in_dtype,
self.moe_parallel_config,
)
)
self.moe_config.hidden_dim = hidden_size
self.moe_config.intermediate_size_per_partition = (
intermediate_size_per_partition
)
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": self.hidden_size,
"unpadded_hidden_size": unpadded_hidden_size,
"intermediate_size_per_partition": self.intermediate_size_per_partition,
"hidden_size": hidden_size,
"intermediate_size_per_partition": intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
"global_num_experts": self.global_num_experts,
@@ -933,9 +888,17 @@ class FusedMoE(CustomOp):
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
# and we're not loading the full weight
if not load_full and loaded_weight.ndim > 0:
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)
# Handle padding: loaded_weight might be smaller than shard_size on the
# last TP rank
start_offset = shard_size * tp_rank
available = loaded_weight.shape[shard_dim] - start_offset
if available <= 0:
# No weight is available for this TP rank (can happen on the
# last TP rank with padding), so skip loading and return early
return
narrow_size = min(shard_size, available)
loaded_weight = loaded_weight.narrow(shard_dim, start_offset, narrow_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
@@ -944,6 +907,13 @@ class FusedMoE(CustomOp):
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
# Handle padding: if loaded_weight is smaller than expert_data (can happen
# on the last TP shard with padding), copy into the top-left corner
if expert_data.shape != loaded_weight.shape:
expert_data = expert_data[
: loaded_weight.shape[0], : loaded_weight.shape[1]
]
expert_data.copy_(loaded_weight)
def _load_w2(
@@ -961,10 +931,24 @@ class FusedMoE(CustomOp):
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
# and we're not loading the full weight
if not load_full and loaded_weight.ndim > 0:
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)
# Handle padding: loaded_weight might be smaller than shard_size on the
# last TP rank
start_offset = shard_size * tp_rank
available = loaded_weight.shape[shard_dim] - start_offset
if available <= 0:
# No weight is available for this TP rank (can happen on the
# last TP rank with padding), so skip loading and return early
return
narrow_size = min(shard_size, available)
loaded_weight = loaded_weight.narrow(shard_dim, start_offset, narrow_size)
# w2, down_proj: Load into only logical weight of w2.
# Handle padding: if loaded_weight is smaller than expert_data (can happen
# on the last TP shard with padding), copy into the top-left corner
if expert_data.shape != loaded_weight.shape:
expert_data = expert_data[
: loaded_weight.shape[0], : loaded_weight.shape[1]
]
expert_data.copy_(loaded_weight)
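A standalone sketch of the clamped-narrow logic above, with hypothetical sizes: a 1500-row checkpoint dimension loaded into padded 512-row shards over tp_size=4. The helper name is invented for illustration.

import torch

def _narrow_with_padding(loaded_weight, shard_dim, shard_size, tp_rank):
    # Mirrors the logic above: clamp the narrow to what the checkpoint
    # actually holds, and load nothing past its end.
    start_offset = shard_size * tp_rank
    available = loaded_weight.shape[shard_dim] - start_offset
    if available <= 0:
        return None
    return loaded_weight.narrow(shard_dim, start_offset, min(shard_size, available))

w = torch.randn(1500, 64)
print(_narrow_with_padding(w, 0, 512, 1).shape)  # (512, 64): full shard
print(_narrow_with_padding(w, 0, 512, 2).shape)  # (476, 64): partial last shard
print(_narrow_with_padding(w, 0, 512, 3))        # None: past the checkpoint end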
def _load_single_value(
@@ -1549,6 +1533,14 @@ class FusedMoE(CustomOp):
]
]
@property
def hidden_size(self) -> int:
return self.moe_config.hidden_dim
@property
def intermediate_size_per_partition(self) -> int:
return self.moe_config.intermediate_size_per_partition
def extra_repr(self) -> str:
s = (
f"global_num_experts={self.global_num_experts}, "

View File

@@ -779,8 +779,6 @@ def make_mxfp4_moe_quant_config(
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
) -> FusedMoEQuantConfig | None:
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
if mxfp4_backend in (
@@ -802,16 +800,12 @@ def make_mxfp4_moe_quant_config(
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.CK,
):
config = mxfp4_w4a16_moe_quant_config(
return mxfp4_w4a16_moe_quant_config(
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
if mxfp4_backend == Mxfp4MoeBackend.CK:
config.hidden_pad = hidden_pad
config.intermediate_pad = intermediate_pad
return config
else:
return ocp_mx_moe_quant_config(
quant_dtype="mxfp4",

View File

@@ -10,6 +10,7 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
@@ -186,6 +187,7 @@ def rocm_aiter_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
moe_config: FusedMoEConfig,
activation: MoEActivation = MoEActivation.SILU,
apply_router_weight_on_input: bool = False,
expert_map: torch.Tensor | None = None,
@@ -276,6 +278,17 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
# Compute padding on-the-fly for CK MXFP4 kernels
assert moe_config.hidden_dim_unpadded is not None
assert moe_config.intermediate_size_per_partition_unpadded is not None
hidden_pad = hidden_states.shape[1] - moe_config.hidden_dim_unpadded
intermediate_pad = (
moe_config.intermediate_size_per_partition
- moe_config.intermediate_size_per_partition_unpadded
)
return rocm_aiter_ops.fused_moe(
hidden_states,
w1,
@@ -292,8 +305,8 @@ def rocm_aiter_fused_experts(
doweight_stage1=apply_router_weight_on_input,
num_local_tokens=num_local_tokens,
output_dtype=output_dtype,
hidden_pad=quant_config.hidden_pad,
intermediate_pad=quant_config.intermediate_pad,
hidden_pad=hidden_pad,
intermediate_pad=intermediate_pad,
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
)
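Worked numbers for the on-the-fly computation above, assuming the 256-element alignment used by the GFX950/CK paths in this diff and hypothetical model dimensions:

def round_up(x: int, a: int) -> int:
    return (x + a - 1) // a * a

hidden_unpadded, intermediate_unpadded = 2880, 384       # hypothetical
hidden_dim = round_up(hidden_unpadded, 256)              # 3072
intermediate = round_up(intermediate_unpadded, 256)      # 512
hidden_pad = hidden_dim - hidden_unpadded                # 192
intermediate_pad = intermediate - intermediate_unpadded  # 128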
@@ -419,6 +432,7 @@ class AiterExperts(mk.FusedMoEExpertsModular):
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.quant_config,
moe_config=self.moe_config,
a1q_scale=a1q_scale,
num_local_tokens=num_local_tokens,
output_dtype=output.dtype,

View File

@@ -715,8 +715,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
@@ -2274,8 +2272,6 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None

View File

@@ -672,8 +672,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
@@ -1011,8 +1009,6 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None

View File

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
@@ -107,18 +108,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
self.moe_kernel: mk.FusedMoEKernel | None = None
# Round up dims once based on backend. This mutates the shared
# FusedMoEConfig in-place so that create_weights() and all
# downstream code see the padded dimensions. This must happen
# before create_weights() is called.
self.moe.hidden_dim, self.moe.intermediate_size_per_partition = (
mxfp4_round_up_hidden_size_and_intermediate_size(
self.mxfp4_backend,
self.moe.hidden_dim,
self.moe.intermediate_size_per_partition,
)
)
# Used for triton kernel precision configs
self.w13_precision_config = None
self.w2_precision_config = None
@@ -129,6 +118,23 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# so can skip the padding in the forward before applying the moe method
return self.mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8
def maybe_roundup_sizes(
self,
hidden_size: int,
intermediate_size_per_partition: int,
act_dtype: torch.dtype,
moe_parallel_config: FusedMoEParallelConfig,
) -> tuple[int, int]:
hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes(
hidden_size=hidden_size,
intermediate_size_per_partition=intermediate_size_per_partition,
act_dtype=act_dtype,
moe_parallel_config=moe_parallel_config,
)
return mxfp4_round_up_hidden_size_and_intermediate_size(
self.mxfp4_backend, hidden_size, intermediate_size_per_partition
)
def create_weights(
self,
layer: torch.nn.Module,
@@ -143,32 +149,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
scale_dtype = torch.uint8
mxfp4_block = 32
# Use pre-rounded sizes from config
self.intermediate_size = intermediate_size_per_partition_after_pad = (
self.moe.intermediate_size_per_partition
)
self.hidden_size = hidden_size = self.moe.hidden_dim
# Expose padded dimensions on the layer for LoRA and Marlin code
# that reads layer.hidden_size / layer.intermediate_size_per_partition.
layer.params_dtype = params_dtype
layer.num_experts = num_experts
layer.hidden_size = hidden_size
layer.intermediate_size_per_partition = (
intermediate_size_per_partition_after_pad
)
# CK (gfx950) padding info for rocm_aiter_ops.fused_moe()
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
self.intermediate_pad = (
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
)
self.intermediate_size = intermediate_size_per_partition
self.hidden_size = hidden_size
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
2 * intermediate_size_per_partition,
hidden_size // 2,
dtype=weight_dtype,
),
@@ -180,7 +170,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
2 * intermediate_size_per_partition,
hidden_size // mxfp4_block,
dtype=scale_dtype,
),
@@ -194,7 +184,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition_after_pad // 2,
intermediate_size_per_partition // 2,
dtype=weight_dtype,
),
requires_grad=False,
@@ -206,7 +196,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition_after_pad // mxfp4_block,
intermediate_size_per_partition // mxfp4_block,
dtype=scale_dtype,
),
requires_grad=False,
@@ -218,7 +208,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
2 * intermediate_size_per_partition,
dtype=torch.bfloat16,
),
requires_grad=False,
@@ -368,8 +358,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
hidden_pad=self.hidden_pad,
intermediate_pad=self.intermediate_pad,
)
def select_gemm_impl(
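For create_weights above, a brief shape sketch with hypothetical dimensions already rounded by maybe_roundup_sizes: MXFP4 packs two 4-bit values per byte, so the hidden dim is halved in the packed weight and divided by the 32-element block size in the scales (uint8 storage assumed, as for the scales above).

import torch

E, I, H = 8, 512, 3072  # hypothetical, already padded
w13 = torch.zeros(E, 2 * I, H // 2, dtype=torch.uint8)         # packed fp4 pairs
w13_scale = torch.zeros(E, 2 * I, H // 32, dtype=torch.uint8)  # one scale per block
w2 = torch.zeros(E, H, I // 2, dtype=torch.uint8)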

View File

@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe import (
MoEActivation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
mxfp4_w4a8_moe_quant_config,
@@ -27,13 +28,13 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
Mxfp4MoeBackend,
mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
CK_MXFP4_MOE_DIM_ALIGNMENT,
_swizzle_mxfp4,
)
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
@@ -49,7 +50,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
@@ -173,8 +173,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
@@ -182,7 +180,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
@@ -194,7 +192,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition,
@@ -461,6 +459,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
quant_config=self.moe_quant_config,
moe_config=layer.moe_config,
expert_map=layer.expert_map,
)
elif self.use_marlin:
@@ -527,7 +526,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
):
params_dtype = torch.uint32
w13_weight = torch.nn.Parameter(
torch.empty(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // 8, # INT32 packing for W4
@@ -536,7 +535,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
requires_grad=False,
)
w2_weight = torch.nn.Parameter(
torch.empty(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition // 8, # INT32 packing for W4
@@ -649,6 +648,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
quant_config=self.moe_quant_config,
moe_config=layer.moe_config,
expert_map=layer.expert_map,
)
@@ -702,6 +702,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self.mxfp4_backend: Mxfp4MoeBackend | None = None
if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe)
elif self.ocp_mx_scheme.startswith("w_mxfp4"):
# TODO(bowenbao): refactor and introduce backends for other OCP MX schemes.
self.mxfp4_backend = Mxfp4MoeBackend.NONE
if self.input_quant is not None:
self.static_input_scales = not self.input_quant.get("is_dynamic")
@@ -734,36 +737,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self.emulate = (
not current_platform.supports_mx()
or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
# alignment requirements. When violated (e.g. MiniMax-M2.1 with
# TP=4 yields intermediate_size_per_partition=384), AITER raises:
# "device_gemm ... does not support this GEMM problem".
# Fall back to emulation in that case.
# For gpt_oss models, create_weights rounds up the dimensions
# internally, so the alignment check is skipped.
if (
not self.emulate
and self.use_rocm_aiter_moe
and self.ocp_mx_scheme is not None
and self.ocp_mx_scheme.startswith("w_mxfp4")
and self.model_type != "gpt_oss"
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
):
logger.warning_once(
"AITER CK MXFP4 MoE GEMM does not support "
"intermediate_size_per_partition=%d (not a multiple of %d). "
"This typically happens when intermediate_size / "
"tensor_parallel_size produces an incompatible dimension. "
"Falling back to emulation mode. To avoid this overhead, "
"use a compatible tensor_parallel_size or set "
"VLLM_ROCM_USE_AITER_MOE=0.",
moe.intermediate_size_per_partition,
CK_MXFP4_MOE_DIM_ALIGNMENT,
)
self.use_rocm_aiter_moe = False
self.emulate = True
) and (
self.mxfp4_backend is None
or self.mxfp4_backend is Mxfp4MoeBackend.NONE
or not self.use_rocm_aiter_moe
)
if self.emulate:
logger.warning_once(
@@ -780,6 +758,27 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
"The current mode supports native MoE MXFP4 computation"
)
def maybe_roundup_sizes(
self,
hidden_size: int,
intermediate_size_per_partition: int,
act_dtype: torch.dtype,
moe_parallel_config: FusedMoEParallelConfig,
) -> tuple[int, int]:
hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes(
hidden_size=hidden_size,
intermediate_size_per_partition=intermediate_size_per_partition,
act_dtype=act_dtype,
moe_parallel_config=moe_parallel_config,
)
if self.mxfp4_backend is not None:
hidden_size, intermediate_size_per_partition = (
mxfp4_round_up_hidden_size_and_intermediate_size(
self.mxfp4_backend, hidden_size, intermediate_size_per_partition
)
)
return hidden_size, intermediate_size_per_partition
def get_packed_dim(self, dim: int, quant_dtype: str):
if quant_dtype == "mxfp4":
assert dim % 2 == 0
@@ -805,40 +804,12 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
params_dtype = torch.uint8
self.intermediate_size_per_partition = intermediate_size_per_partition
if self.model_type == "gpt_oss":
if current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256
)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64
)
else:
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
self.unpadded_hidden_size = extra_weight_attrs.get(
"unpadded_hidden_size", hidden_size
)
# On GFX950, the GFX950MXScaleLayout swizzle requires
# hidden_size to be a multiple of 256 (SCALE_K = hidden_size / 32
# must be divisible by 8). Pad hidden_size for weight/scale
# allocation; the original value is preserved in unpadded_hidden_size.
# Only applies to the native (non-emulated) CK path on GFX950.
if (
self.model_type == "gpt_oss"
and current_platform.is_rocm()
and not self.emulate
):
hidden_size = round_up(hidden_size, 256)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
2 * intermediate_size_per_partition,
self.get_packed_dim(hidden_size, self.weight_dtype),
dtype=params_dtype,
),
@@ -849,12 +820,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
torch.zeros(
num_experts,
hidden_size,
self.get_packed_dim(
intermediate_size_per_partition_after_pad, self.weight_dtype
),
self.get_packed_dim(intermediate_size_per_partition, self.weight_dtype),
dtype=params_dtype,
),
requires_grad=False,
@@ -867,7 +836,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * intermediate_size_per_partition_after_pad,
2 * intermediate_size_per_partition,
hidden_size // OCP_MX_BLOCK_SIZE,
dtype=params_dtype,
),
@@ -877,7 +846,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
torch.ones(
num_experts,
hidden_size,
intermediate_size_per_partition_after_pad // OCP_MX_BLOCK_SIZE,
intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=params_dtype,
),
requires_grad=False,
@@ -892,7 +861,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
2 * intermediate_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
@@ -1072,6 +1041,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
moe_config=layer.moe_config,
expert_map=layer.expert_map,
)
else:
@@ -1204,6 +1174,8 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
triton_kernel_moe_forward,
)
assert self.moe.hidden_dim_unpadded is not None
assert self.moe.intermediate_size_per_partition_unpadded is not None
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
@@ -1215,8 +1187,8 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
expert_map=expert_map,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
unpadded_N_w1=self.intermediate_size_per_partition * 2,
unpadded_K_w1=self.unpadded_hidden_size,
unpadded_N_w2=self.unpadded_hidden_size,
unpadded_K_w2=self.intermediate_size_per_partition,
unpadded_N_w1=self.moe.intermediate_size_per_partition_unpadded * 2,
unpadded_K_w1=self.moe.hidden_dim_unpadded,
unpadded_N_w2=self.moe.hidden_dim_unpadded,
unpadded_K_w2=self.moe.intermediate_size_per_partition_unpadded,
)

View File

@@ -254,7 +254,6 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
w13, w13_scale, w2, w2_scale, is_act_and_mul, min_alignment
)
)
layer.intermediate_size_per_partition = padded_intermediate
layer.moe_config.intermediate_size_per_partition = padded_intermediate
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(

View File

@@ -439,7 +439,6 @@ def prepare_fp8_moe_layer_for_fi(
layer.moe_config.is_act_and_mul,
min_alignment,
)
layer.intermediate_size_per_partition = new_intermediate
layer.moe_config.intermediate_size_per_partition = new_intermediate
# FI kernels require W31 layout rather than W13.