[ROCm][Quantization][1/N] Refactor quark_moe w_mxfp4 w/ oracle (#38774)
Signed-off-by: Bowen Bao <bowenbao@amd.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
|
||||
metric_threshold: 0.568
|
||||
reasoning_effort: low
|
||||
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter"
|
||||
env:
|
||||
VLLM_ROCM_USE_AITER: "1"
|
||||
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
|
||||
metric_threshold: 0.568
|
||||
reasoning_effort: low
|
||||
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton"
|
||||
@@ -1,4 +1,6 @@
|
||||
# GFX950 model configurations for GPQA evaluation
|
||||
# Tests different environment variable combinations
|
||||
gpt-oss-20b-rocm-baseline.yaml
|
||||
gpt-oss-20b-rocm-mxfp4-fp8.yaml
|
||||
gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml
|
||||
gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml
|
||||
gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml
|
||||
|
||||
@@ -54,8 +54,8 @@ class Mxfp4MoeBackend(Enum):
|
||||
# Marlin
|
||||
BATCHED_MARLIN = "BATCHED_MARLIN"
|
||||
MARLIN = "MARLIN"
|
||||
# ROCm AITER (CK)
|
||||
CK = "CK"
|
||||
# ROCm AITER
|
||||
AITER = "AITER"
|
||||
# Triton
|
||||
TRITON = "TRITON"
|
||||
TRITON_UNFUSED = "TRITON_UNFUSED"
|
||||
@@ -130,7 +130,7 @@ def backend_to_kernel_cls(
|
||||
|
||||
return [BatchedMarlinExperts]
|
||||
|
||||
elif backend == Mxfp4MoeBackend.CK:
|
||||
elif backend == Mxfp4MoeBackend.AITER:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
AiterExperts,
|
||||
)
|
||||
@@ -155,7 +155,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
|
||||
"flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
|
||||
"triton": Mxfp4MoeBackend.TRITON,
|
||||
"marlin": Mxfp4MoeBackend.MARLIN,
|
||||
"ck": Mxfp4MoeBackend.CK,
|
||||
"aiter": Mxfp4MoeBackend.AITER,
|
||||
"xpu": Mxfp4MoeBackend.XPU,
|
||||
}
|
||||
if backend := mapping.get(runner_backend):
|
||||
@@ -173,7 +173,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
|
||||
"""
|
||||
_AVAILABLE_BACKENDS = [
|
||||
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.CK,
|
||||
Mxfp4MoeBackend.AITER,
|
||||
Mxfp4MoeBackend.TRITON,
|
||||
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.TRITON_UNFUSED,
|
||||
@@ -656,7 +656,7 @@ def convert_to_mxfp4_moe_kernel_format(
|
||||
w2_bias,
|
||||
)
|
||||
|
||||
elif mxfp4_backend == Mxfp4MoeBackend.CK:
|
||||
elif mxfp4_backend == Mxfp4MoeBackend.AITER:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
if w13_bias is not None:
|
||||
@@ -794,7 +794,7 @@ def make_mxfp4_moe_quant_config(
|
||||
Mxfp4MoeBackend.TRITON_UNFUSED,
|
||||
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.CK,
|
||||
Mxfp4MoeBackend.AITER,
|
||||
):
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=w1_bias,
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
@@ -27,7 +28,11 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
|
||||
TRITON_BACKENDS,
|
||||
Mxfp4MoeBackend,
|
||||
convert_to_mxfp4_moe_kernel_format,
|
||||
make_mxfp4_moe_kernel,
|
||||
make_mxfp4_moe_quant_config,
|
||||
mxfp4_round_up_hidden_size_and_intermediate_size,
|
||||
select_mxfp4_moe_backend,
|
||||
)
|
||||
@@ -47,7 +52,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
per_tensor_dequantize,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
@@ -699,9 +704,16 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
f"Please check that the combination is supported in OCP_MX_Scheme."
|
||||
)
|
||||
|
||||
self.mxfp4_backend: Mxfp4MoeBackend | None = None
|
||||
self.mxfp4_backend: Mxfp4MoeBackend = Mxfp4MoeBackend.NONE
|
||||
self.experts_cls: type[mk.FusedMoEExperts] | None = None
|
||||
self.moe_kernel: mk.FusedMoEKernel | None = None
|
||||
|
||||
# Used for triton kernel precision configs
|
||||
self.w13_precision_config = None
|
||||
self.w2_precision_config = None
|
||||
|
||||
if self.ocp_mx_scheme == "w_mxfp4":
|
||||
self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe)
|
||||
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)
|
||||
elif self.ocp_mx_scheme.startswith("w_mxfp4"):
|
||||
# TODO(bowenbao): refactor and introduce backends for other OCP MX schemes.
|
||||
self.mxfp4_backend = Mxfp4MoeBackend.NONE
|
||||
@@ -738,9 +750,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
not current_platform.supports_mx()
|
||||
or not self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||
) and (
|
||||
self.mxfp4_backend is None
|
||||
or self.mxfp4_backend is Mxfp4MoeBackend.NONE
|
||||
or not self.use_rocm_aiter_moe
|
||||
self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe
|
||||
)
|
||||
|
||||
if self.emulate:
|
||||
@@ -944,11 +954,23 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
w2_input_scale, requires_grad=False
|
||||
)
|
||||
|
||||
# secondly, process mxfp weights
|
||||
# For w_mxfp4, use oracle functions
|
||||
if (
|
||||
self.ocp_mx_scheme == "w_mxfp4"
|
||||
and self.mxfp4_backend != Mxfp4MoeBackend.NONE
|
||||
):
|
||||
self._setup_kernel_via_oracle(layer)
|
||||
return
|
||||
|
||||
# TODO(bowenbao): gradually migrate to oracles.
|
||||
# secondly, process mxfp weights for other schemes
|
||||
if self.emulate:
|
||||
# Build quant config for emulation path
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
torch.accelerator.empty_cache()
|
||||
return
|
||||
|
||||
# Existing AITER path for w_mxfp4_a_mxfp4 and other schemes
|
||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
|
||||
# Pre-shuffle weight scales
|
||||
@@ -980,11 +1002,87 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
layer.w13_weight.is_shuffled = True
|
||||
layer.w2_weight.is_shuffled = True
|
||||
|
||||
# Build quant config for AITER path
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
torch.accelerator.empty_cache()
|
||||
|
||||
def _setup_kernel_via_oracle(self, layer: FusedMoE):
|
||||
"""Setup kernel using oracle functions for w_mxfp4 scheme."""
|
||||
w13 = layer.w13_weight
|
||||
w2 = layer.w2_weight
|
||||
w13_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
w13_bias = getattr(layer, "w13_bias", None)
|
||||
w2_bias = getattr(layer, "w2_bias", None)
|
||||
|
||||
# Convert weights to kernel format
|
||||
w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
|
||||
convert_to_mxfp4_moe_kernel_format(
|
||||
mxfp4_backend=self.mxfp4_backend,
|
||||
layer=layer,
|
||||
w13_weight=w13,
|
||||
w2_weight=w2,
|
||||
w13_weight_scale=w13_scale,
|
||||
w2_weight_scale=w2_scale,
|
||||
w13_bias=w13_bias,
|
||||
w2_bias=w2_bias,
|
||||
)
|
||||
)
|
||||
|
||||
# For TRITON backends, weights are wrapped tensors from triton_kernels
|
||||
# that don't support .detach(). Manually assign parameters.
|
||||
if self.mxfp4_backend not in TRITON_BACKENDS:
|
||||
replace_parameter(layer, "w13_weight", w13)
|
||||
replace_parameter(layer, "w2_weight", w2)
|
||||
replace_parameter(layer, "w13_weight_scale", w13_scale)
|
||||
replace_parameter(layer, "w2_weight_scale", w2_scale)
|
||||
else:
|
||||
layer.w13_weight = w13
|
||||
layer.w2_weight = w2
|
||||
self.w13_precision_config = w13_scale
|
||||
self.w2_precision_config = w2_scale
|
||||
|
||||
if w13_bias is not None and w2_bias is not None:
|
||||
replace_parameter(layer, "w13_bias", w13_bias)
|
||||
replace_parameter(layer, "w2_bias", w2_bias)
|
||||
|
||||
# Build quant config and kernel
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config is not None and self.experts_cls is not None:
|
||||
self.moe_kernel = make_mxfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
mxfp4_backend=self.mxfp4_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
# For w_mxfp4 with oracle backend, use oracle function
|
||||
if (
|
||||
self.ocp_mx_scheme == "w_mxfp4"
|
||||
and self.mxfp4_backend != Mxfp4MoeBackend.NONE
|
||||
):
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
if self.mxfp4_backend in TRITON_BACKENDS:
|
||||
w1_scale = self.w13_precision_config
|
||||
w2_scale = self.w2_precision_config
|
||||
return make_mxfp4_moe_quant_config(
|
||||
mxfp4_backend=self.mxfp4_backend,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=getattr(layer, "w13_bias", None),
|
||||
w2_bias=getattr(layer, "w2_bias", None),
|
||||
)
|
||||
|
||||
# Existing code for other schemes
|
||||
# TODO(bowenbao): kept for emulation fallback, to be refactored into
|
||||
# dedicated emulation backend.
|
||||
if self.ocp_mx_scheme == "w_mxfp4":
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
@@ -1020,6 +1118,12 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
block_shape=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
if self.moe_kernel is not None:
|
||||
return self.moe_kernel.is_monolithic
|
||||
return False
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -1028,6 +1132,22 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
# For w_mxfp4 with oracle kernel
|
||||
if self.moe_kernel is not None:
|
||||
return self.moe_kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
# Existing code for emulation/AITER paths
|
||||
if not self.emulate:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
@@ -1061,6 +1181,25 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert self.is_monolithic
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
router_logits=router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user