[ROCm][Quantization][1/N] Refactor quark_moe w_mxfp4 w/ oracle (#38774)

Signed-off-by: Bowen Bao <bowenbao@amd.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Bowen Bao
2026-04-02 20:29:57 -07:00
committed by GitHub
parent 32e0c0bfa2
commit 103f0de565
6 changed files with 170 additions and 15 deletions

View File

@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Eval config: quark w-mxfp4 weights / bf16 activations gpt-oss-20b,
# served with the ROCm AITER MoE backend.
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
# Minimum score the eval run must reach to pass.
# NOTE(review): presumably GPQA accuracy per the surrounding suite — confirm.
metric_threshold: 0.568
reasoning_effort: low
# Server launch flags: ROCm AITER unified attention + AITER MoE backend.
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter"
# Environment needed so vLLM enables the ROCm AITER kernels at runtime.
env:
VLLM_ROCM_USE_AITER: "1"

View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Eval config: quark w-mxfp4 weights / bf16 activations gpt-oss-20b,
# served with the Triton MoE backend (no AITER env required).
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
# Minimum score the eval run must reach to pass.
# NOTE(review): presumably GPQA accuracy per the surrounding suite — confirm.
metric_threshold: 0.568
reasoning_effort: low
# Server launch flags: ROCm AITER unified attention + Triton MoE backend.
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton"

View File

@@ -1,4 +1,6 @@
# GFX950 model configurations for GPQA evaluation
# Tests different environment variable combinations
gpt-oss-20b-rocm-baseline.yaml
gpt-oss-20b-rocm-mxfp4-fp8.yaml
gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml
gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml
gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml

View File

@@ -54,8 +54,8 @@ class Mxfp4MoeBackend(Enum):
# Marlin
BATCHED_MARLIN = "BATCHED_MARLIN"
MARLIN = "MARLIN"
# ROCm AITER (CK)
CK = "CK"
# ROCm AITER
AITER = "AITER"
# Triton
TRITON = "TRITON"
TRITON_UNFUSED = "TRITON_UNFUSED"
@@ -130,7 +130,7 @@ def backend_to_kernel_cls(
return [BatchedMarlinExperts]
elif backend == Mxfp4MoeBackend.CK:
elif backend == Mxfp4MoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
@@ -155,7 +155,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
"flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
"triton": Mxfp4MoeBackend.TRITON,
"marlin": Mxfp4MoeBackend.MARLIN,
"ck": Mxfp4MoeBackend.CK,
"aiter": Mxfp4MoeBackend.AITER,
"xpu": Mxfp4MoeBackend.XPU,
}
if backend := mapping.get(runner_backend):
@@ -173,7 +173,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
"""
_AVAILABLE_BACKENDS = [
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
Mxfp4MoeBackend.CK,
Mxfp4MoeBackend.AITER,
Mxfp4MoeBackend.TRITON,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.TRITON_UNFUSED,
@@ -656,7 +656,7 @@ def convert_to_mxfp4_moe_kernel_format(
w2_bias,
)
elif mxfp4_backend == Mxfp4MoeBackend.CK:
elif mxfp4_backend == Mxfp4MoeBackend.AITER:
from vllm._aiter_ops import rocm_aiter_ops
if w13_bias is not None:
@@ -794,7 +794,7 @@ def make_mxfp4_moe_quant_config(
Mxfp4MoeBackend.TRITON_UNFUSED,
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.CK,
Mxfp4MoeBackend.AITER,
):
return mxfp4_w4a16_moe_quant_config(
w1_bias=w1_bias,

View File

@@ -5,6 +5,7 @@ from typing import Any
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
@@ -27,7 +28,11 @@ from vllm.model_executor.layers.fused_moe.config import (
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
TRITON_BACKENDS,
Mxfp4MoeBackend,
convert_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend,
)
@@ -47,7 +52,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
@@ -699,9 +704,16 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
f"Please check that the combination is supported in OCP_MX_Scheme."
)
self.mxfp4_backend: Mxfp4MoeBackend | None = None
self.mxfp4_backend: Mxfp4MoeBackend = Mxfp4MoeBackend.NONE
self.experts_cls: type[mk.FusedMoEExperts] | None = None
self.moe_kernel: mk.FusedMoEKernel | None = None
# Used for triton kernel precision configs
self.w13_precision_config = None
self.w2_precision_config = None
if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe)
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)
elif self.ocp_mx_scheme.startswith("w_mxfp4"):
# TODO(bowenbao): refactor and introduce backends for other OCP MX schemes.
self.mxfp4_backend = Mxfp4MoeBackend.NONE
@@ -738,9 +750,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
not current_platform.supports_mx()
or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (
self.mxfp4_backend is None
or self.mxfp4_backend is Mxfp4MoeBackend.NONE
or not self.use_rocm_aiter_moe
self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe
)
if self.emulate:
@@ -944,11 +954,23 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
w2_input_scale, requires_grad=False
)
# secondly, process mxfp weights
# For w_mxfp4, use oracle functions
if (
self.ocp_mx_scheme == "w_mxfp4"
and self.mxfp4_backend != Mxfp4MoeBackend.NONE
):
self._setup_kernel_via_oracle(layer)
return
# TODO(bowenbao): gradually migrate to oracles.
# secondly, process mxfp weights for other schemes
if self.emulate:
# Build quant config for emulation path
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
torch.accelerator.empty_cache()
return
# Existing AITER path for w_mxfp4_a_mxfp4 and other schemes
from aiter.utility.fp4_utils import e8m0_shuffle
# Pre-shuffle weight scales
@@ -980,11 +1002,87 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True
# Build quant config for AITER path
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
torch.accelerator.empty_cache()
def _setup_kernel_via_oracle(self, layer: FusedMoE) -> None:
"""Set up the fused-MoE kernel for the w_mxfp4 scheme via the mxfp4 oracle.

Converts the layer's mxfp4 weights/scales (and optional biases) into the
layout required by the selected ``self.mxfp4_backend``, re-registers them
on ``layer``, then builds ``self.moe_quant_config`` and — when both the
quant config and an experts class are available — ``self.moe_kernel``.
"""
w13 = layer.w13_weight
w2 = layer.w2_weight
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
# Biases are optional on the layer; fall back to None when absent.
w13_bias = getattr(layer, "w13_bias", None)
w2_bias = getattr(layer, "w2_bias", None)
# Convert weights to kernel format
# (may return new tensors in a backend-specific layout).
w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
convert_to_mxfp4_moe_kernel_format(
mxfp4_backend=self.mxfp4_backend,
layer=layer,
w13_weight=w13,
w2_weight=w2,
w13_weight_scale=w13_scale,
w2_weight_scale=w2_scale,
w13_bias=w13_bias,
w2_bias=w2_bias,
)
)
# For TRITON backends, weights are wrapped tensors from triton_kernels
# that don't support .detach(). Manually assign parameters.
if self.mxfp4_backend not in TRITON_BACKENDS:
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale)
else:
layer.w13_weight = w13
layer.w2_weight = w2
# Triton returns precision-config objects in place of raw scales;
# stash them for get_fused_moe_quant_config to pick up later.
self.w13_precision_config = w13_scale
self.w2_precision_config = w2_scale
# Biases are only swapped when both halves exist, keeping the
# layer consistent if the checkpoint has no bias tensors.
if w13_bias is not None and w2_bias is not None:
replace_parameter(layer, "w13_bias", w13_bias)
replace_parameter(layer, "w2_bias", w2_bias)
# Build quant config and kernel
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
# Kernel construction is skipped when either piece is missing;
# callers fall back to the non-oracle paths in apply().
if self.moe_quant_config is not None and self.experts_cls is not None:
self.moe_kernel = make_mxfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
mxfp4_backend=self.mxfp4_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
# For w_mxfp4 with oracle backend, use oracle function
if (
self.ocp_mx_scheme == "w_mxfp4"
and self.mxfp4_backend != Mxfp4MoeBackend.NONE
):
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
if self.mxfp4_backend in TRITON_BACKENDS:
w1_scale = self.w13_precision_config
w2_scale = self.w2_precision_config
return make_mxfp4_moe_quant_config(
mxfp4_backend=self.mxfp4_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=getattr(layer, "w13_bias", None),
w2_bias=getattr(layer, "w2_bias", None),
)
# Existing code for other schemes
# TODO(bowenbao): kept for emulation fallback, to be refactored into
# dedicated emulation backend.
if self.ocp_mx_scheme == "w_mxfp4":
return mxfp4_w4a16_moe_quant_config(
w1_scale=layer.w13_weight_scale,
@@ -1020,6 +1118,12 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
block_shape=None,
)
@property
def is_monolithic(self) -> bool:
    """Mirror the oracle kernel's ``is_monolithic`` flag.

    Returns False when no oracle kernel has been constructed (e.g. the
    scheme is handled by the emulation/AITER fallback paths).
    """
    kernel = self.moe_kernel
    return False if kernel is None else kernel.is_monolithic
def apply(
self,
layer: FusedMoE,
@@ -1028,6 +1132,22 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
# For w_mxfp4 with oracle kernel
if self.moe_kernel is not None:
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
shared_experts_input=shared_experts_input,
)
# Existing code for emulation/AITER paths
if not self.emulate:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
@@ -1061,6 +1181,25 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
quant_config=self.moe_quant_config,
)
def apply_monolithic(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    router_logits: torch.Tensor,
) -> torch.Tensor:
    """Run the full MoE forward pass through the monolithic oracle kernel.

    Only valid when ``self.is_monolithic`` is True, i.e. the selected
    kernel consumes raw ``router_logits`` and performs routing itself.
    """
    assert self.is_monolithic
    kernel = self.moe_kernel
    assert kernel is not None
    # Delegate routing and expert computation entirely to the fused kernel.
    return kernel.apply_monolithic(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        router_logits=router_logits,
        activation=layer.activation,
        global_num_experts=layer.global_num_experts,
        expert_map=layer.expert_map,
        apply_router_weight_on_input=layer.apply_router_weight_on_input,
    )
class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
def __init__(