[Bugfix][ROCm][MoE] Fix mxfp4 oracle regressions from #37128 (#37787)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-24 19:17:33 -05:00
committed by GitHub
parent 8bbb7c7f20
commit 679c6a3ecc
11 changed files with 69 additions and 15 deletions

View File

@@ -229,6 +229,9 @@ class FusedMoEQuantConfig:
_w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc
is_nvfp4_scale_swizzled: bool = True
# CK MXFP4 (gfx950) padding info for rocm_aiter_ops.fused_moe()
hidden_pad: int = 0
intermediate_pad: int = 0
def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (

View File

@@ -257,7 +257,7 @@ def triton_kernel_moe_forward(
# sparse_logits.indx contains global expert IDs remap to local.
topk_ids = expert_map[sparse_logits.indx.to(torch.long)]
topk_weights = sparse_logits.vals
local_num_experts = w1.size(0)
local_num_experts = w1.shape[0]
routing_data, gather_idx, scatter_idx = make_routing_data(
topk_ids, topk_weights, local_num_experts
)
@@ -604,8 +604,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
require a specialized implementation, like MarlinExperts, they are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, _, N = w1.size()
assert len(w1.shape) == 3 and len(w2.shape) == 3
E, _, N = w1.shape
K = a1.size(-1)
assert a1.dim() == 2
@@ -683,7 +683,7 @@ class OAITritonExperts(BaseOAITritonExperts):
if expert_map is not None:
topk_ids = expert_map[topk_ids]
local_num_experts = w1.size(0)
local_num_experts = w1.shape[0]
if global_num_experts == -1:
global_num_experts = local_num_experts
@@ -781,7 +781,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
if expert_map is not None:
topk_ids = expert_map[topk_ids]
local_num_experts = w1.size(0)
local_num_experts = w1.shape[0]
if global_num_experts == -1:
global_num_experts = local_num_experts

View File

@@ -567,6 +567,13 @@ class FusedMoE(CustomOp):
# for heuristic purposes, so it must be initialized first.
self.quant_method: FusedMoEMethodBase = _get_quant_method()
# Quant methods (e.g. Mxfp4MoEMethod) may round up hidden_dim
# and intermediate_size in moe_config during __init__. Sync
# self.hidden_size so downstream consumers (e.g. LoRA) see the
# padded value.
if self.moe_config.hidden_dim != self.hidden_size:
self.hidden_size = self.moe_config.hidden_dim
if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike():
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA and ROCm for now"
@@ -586,7 +593,7 @@ class FusedMoE(CustomOp):
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
"hidden_size": self.hidden_size,
"unpadded_hidden_size": unpadded_hidden_size,
"intermediate_size_per_partition": self.intermediate_size_per_partition,
"params_dtype": params_dtype,

View File

@@ -768,8 +768,8 @@ class FusedMoEExpertsModular(FusedMoEExperts):
require a specialized implementation, like MarlinExperts, they are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
assert len(w1.shape) == 3 and len(w2.shape) == 3
E, N, _ = w1.shape
K = a1.size(-1)
if a1.dim() == 2:
@@ -1349,7 +1349,7 @@ class FusedMoEKernelModularImpl:
else:
output = torch.empty_like(hidden_states)
local_num_experts = w1.size(0)
local_num_experts = w1.shape[0]
if global_num_experts == -1:
global_num_experts = local_num_experts

View File

@@ -212,7 +212,11 @@ def select_mxfp4_moe_backend(
# LoRA: separate experts backend path
if config.is_lora_enabled:
if not current_platform.is_cuda():
raise NotImplementedError("Mxfp4 LoRA only supported on CUDA Platform.")
# ROCm: Triton mxfp4 LoRA hits GPU memory faults due to
# triton_kernels.tensor.Tensor / HIP read-only page issues
# during weight swizzle and LoRA forward. Needs work from
# the triton_kernels/aiter side.
raise NotImplementedError("Mxfp4 LoRA is currently only supported on CUDA.")
if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
logger.info_once("Using Triton backend for mxfp4 lora")
return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls(
@@ -775,6 +779,8 @@ def make_mxfp4_moe_quant_config(
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
) -> FusedMoEQuantConfig | None:
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
if mxfp4_backend in (
@@ -796,12 +802,16 @@ def make_mxfp4_moe_quant_config(
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.CK,
):
return mxfp4_w4a16_moe_quant_config(
config = mxfp4_w4a16_moe_quant_config(
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
if mxfp4_backend == Mxfp4MoeBackend.CK:
config.hidden_pad = hidden_pad
config.intermediate_pad = intermediate_pad
return config
else:
return ocp_mx_moe_quant_config(
quant_dtype="mxfp4",

View File

@@ -292,6 +292,8 @@ def rocm_aiter_fused_experts(
doweight_stage1=apply_router_weight_on_input,
num_local_tokens=num_local_tokens,
output_dtype=output_dtype,
hidden_pad=quant_config.hidden_pad,
intermediate_pad=quant_config.intermediate_pad,
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
)
@@ -332,7 +334,15 @@ class AiterExperts(mk.FusedMoEExpertsModular):
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kMxfp4Static, None),
]
return (weight_key, activation_key) in SUPPORTED_W_A
if (weight_key, activation_key) not in SUPPORTED_W_A:
return False
# CK MXFP4 MoE kernels are only supported on gfx950.
if weight_key == kMxfp4Static:
from vllm.platforms.rocm import on_gfx950
if not on_gfx950():
return False
return True
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:

View File

@@ -158,6 +158,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition_after_pad
)
# CK (gfx950) padding info for rocm_aiter_ops.fused_moe()
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
self.intermediate_pad = (
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
)
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
@@ -362,6 +368,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
hidden_pad=self.hidden_pad,
intermediate_pad=self.intermediate_pad,
)
def select_gemm_impl(