diff --git a/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml new file mode 100644 index 000000000..850a6d28b --- /dev/null +++ b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16 +metric_threshold: 0.568 +reasoning_effort: low +server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter" +env: + VLLM_ROCM_USE_AITER: "1" diff --git a/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml new file mode 100644 index 000000000..903f30e59 --- /dev/null +++ b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16 +metric_threshold: 0.568 +reasoning_effort: low +server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton" \ No newline at end of file diff --git a/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-mxfp4-fp8.yaml b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml similarity index 100% rename from tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-mxfp4-fp8.yaml rename to tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml diff --git a/tests/evals/gpt_oss/configs/models-gfx950.txt b/tests/evals/gpt_oss/configs/models-gfx950.txt index 5085aa9f2..d25f4e3a5 100644 --- a/tests/evals/gpt_oss/configs/models-gfx950.txt +++ b/tests/evals/gpt_oss/configs/models-gfx950.txt @@ -1,4 +1,6 @@ # GFX950 model configurations for GPQA evaluation # Tests different environment variable combinations gpt-oss-20b-rocm-baseline.yaml -gpt-oss-20b-rocm-mxfp4-fp8.yaml \ No newline at end of file +gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml +gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml +gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index 7249b425f..d4a0817e0 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -54,8 +54,8 @@ class Mxfp4MoeBackend(Enum): # Marlin BATCHED_MARLIN = "BATCHED_MARLIN" MARLIN = "MARLIN" - # ROCm AITER (CK) - CK = "CK" + # ROCm AITER + AITER = "AITER" # Triton TRITON = "TRITON" TRITON_UNFUSED = "TRITON_UNFUSED" @@ -130,7 +130,7 @@ def backend_to_kernel_cls( return [BatchedMarlinExperts] - elif backend == Mxfp4MoeBackend.CK: + elif backend == Mxfp4MoeBackend.AITER: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( AiterExperts, ) @@ -155,7 +155,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend: "flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, "triton": Mxfp4MoeBackend.TRITON, "marlin": Mxfp4MoeBackend.MARLIN, - "ck": Mxfp4MoeBackend.CK, + "aiter": Mxfp4MoeBackend.AITER, "xpu": Mxfp4MoeBackend.XPU, } if backend := mapping.get(runner_backend): @@ -173,7 +173,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]: """ _AVAILABLE_BACKENDS = [ Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, - Mxfp4MoeBackend.CK, + Mxfp4MoeBackend.AITER, Mxfp4MoeBackend.TRITON, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, Mxfp4MoeBackend.TRITON_UNFUSED, @@ -656,7 +656,7 @@ def convert_to_mxfp4_moe_kernel_format( w2_bias, ) - elif mxfp4_backend == Mxfp4MoeBackend.CK: + elif mxfp4_backend == Mxfp4MoeBackend.AITER: from vllm._aiter_ops import rocm_aiter_ops if w13_bias is not None: @@ -794,7 +794,7 @@ def make_mxfp4_moe_quant_config( Mxfp4MoeBackend.TRITON_UNFUSED, Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, - Mxfp4MoeBackend.CK, + Mxfp4MoeBackend.AITER, ): return mxfp4_w4a16_moe_quant_config( w1_bias=w1_bias, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index c48e49fe8..3f7ddbfd7 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -5,6 +5,7 @@ from typing import Any import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm import envs from vllm._aiter_ops import rocm_aiter_ops @@ -27,7 +28,11 @@ from vllm.model_executor.layers.fused_moe.config import ( ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( + TRITON_BACKENDS, Mxfp4MoeBackend, + convert_to_mxfp4_moe_kernel_format, + make_mxfp4_moe_kernel, + make_mxfp4_moe_quant_config, mxfp4_round_up_hidden_size_and_intermediate_size, select_mxfp4_moe_backend, ) @@ -47,7 +52,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, ) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -699,9 +704,16 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): f"Please check that the combination is supported in OCP_MX_Scheme." ) - self.mxfp4_backend: Mxfp4MoeBackend | None = None + self.mxfp4_backend: Mxfp4MoeBackend = Mxfp4MoeBackend.NONE + self.experts_cls: type[mk.FusedMoEExperts] | None = None + self.moe_kernel: mk.FusedMoEKernel | None = None + + # Used for triton kernel precision configs + self.w13_precision_config = None + self.w2_precision_config = None + if self.ocp_mx_scheme == "w_mxfp4": - self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe) + self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) elif self.ocp_mx_scheme.startswith("w_mxfp4"): # TODO(bowenbao): refactor and introduce backends for other OCP MX schemes. self.mxfp4_backend = Mxfp4MoeBackend.NONE @@ -738,9 +750,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): not current_platform.supports_mx() or not self.ocp_mx_scheme.startswith("w_mxfp4") ) and ( - self.mxfp4_backend is None - or self.mxfp4_backend is Mxfp4MoeBackend.NONE - or not self.use_rocm_aiter_moe + self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe ) if self.emulate: @@ -944,11 +954,23 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): w2_input_scale, requires_grad=False ) - # secondly, process mxfp weights + # For w_mxfp4, use oracle functions + if ( + self.ocp_mx_scheme == "w_mxfp4" + and self.mxfp4_backend != Mxfp4MoeBackend.NONE + ): + self._setup_kernel_via_oracle(layer) + return + + # TODO(bowenbao): gradually migrate to oracles. + # secondly, process mxfp weights for other schemes if self.emulate: + # Build quant config for emulation path + self.moe_quant_config = self.get_fused_moe_quant_config(layer) torch.accelerator.empty_cache() return + # Existing AITER path for w_mxfp4_a_mxfp4 and other schemes from aiter.utility.fp4_utils import e8m0_shuffle # Pre-shuffle weight scales @@ -980,11 +1002,87 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) layer.w13_weight.is_shuffled = True layer.w2_weight.is_shuffled = True + + # Build quant config for AITER path + self.moe_quant_config = self.get_fused_moe_quant_config(layer) torch.accelerator.empty_cache() + def _setup_kernel_via_oracle(self, layer: FusedMoE): + """Setup kernel using oracle functions for w_mxfp4 scheme.""" + w13 = layer.w13_weight + w2 = layer.w2_weight + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + w13_bias = getattr(layer, "w13_bias", None) + w2_bias = getattr(layer, "w2_bias", None) + + # Convert weights to kernel format + w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = ( + convert_to_mxfp4_moe_kernel_format( + mxfp4_backend=self.mxfp4_backend, + layer=layer, + w13_weight=w13, + w2_weight=w2, + w13_weight_scale=w13_scale, + w2_weight_scale=w2_scale, + w13_bias=w13_bias, + w2_bias=w2_bias, + ) + ) + + # For TRITON backends, weights are wrapped tensors from triton_kernels + # that don't support .detach(). Manually assign parameters. + if self.mxfp4_backend not in TRITON_BACKENDS: + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w2_weight_scale", w2_scale) + else: + layer.w13_weight = w13 + layer.w2_weight = w2 + self.w13_precision_config = w13_scale + self.w2_precision_config = w2_scale + + if w13_bias is not None and w2_bias is not None: + replace_parameter(layer, "w13_bias", w13_bias) + replace_parameter(layer, "w2_bias", w2_bias) + + # Build quant config and kernel + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + if self.moe_quant_config is not None and self.experts_cls is not None: + self.moe_kernel = make_mxfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + mxfp4_backend=self.mxfp4_backend, + experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=layer.shared_experts, + ) + def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: + # For w_mxfp4 with oracle backend, use oracle function + if ( + self.ocp_mx_scheme == "w_mxfp4" + and self.mxfp4_backend != Mxfp4MoeBackend.NONE + ): + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + if self.mxfp4_backend in TRITON_BACKENDS: + w1_scale = self.w13_precision_config + w2_scale = self.w2_precision_config + return make_mxfp4_moe_quant_config( + mxfp4_backend=self.mxfp4_backend, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_bias=getattr(layer, "w13_bias", None), + w2_bias=getattr(layer, "w2_bias", None), + ) + + # Existing code for other schemes + # TODO(bowenbao): kept for emulation fallback, to be refactored into + # dedicated emulation backend. if self.ocp_mx_scheme == "w_mxfp4": return mxfp4_w4a16_moe_quant_config( w1_scale=layer.w13_weight_scale, @@ -1020,6 +1118,12 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): block_shape=None, ) + @property + def is_monolithic(self) -> bool: + if self.moe_kernel is not None: + return self.moe_kernel.is_monolithic + return False + def apply( self, layer: FusedMoE, @@ -1028,6 +1132,22 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor: + # For w_mxfp4 with oracle kernel + if self.moe_kernel is not None: + return self.moe_kernel.apply( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + expert_map=layer.expert_map, + shared_experts_input=shared_experts_input, + ) + + # Existing code for emulation/AITER paths if not self.emulate: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_fused_experts, @@ -1061,6 +1181,25 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): quant_config=self.moe_quant_config, ) + def apply_monolithic( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor: + assert self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + router_logits=router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) + class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod): def __init__(