[ROCm][Quantization] GPT_OSS in amd-quark format: model loading and emulation (#29008)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
xuebwang-amd
2026-02-10 23:08:05 +08:00
committed by GitHub
parent 599e4335a4
commit b129136c7a
13 changed files with 1094 additions and 213 deletions

View File

@@ -386,6 +386,10 @@ class FusedMoEQuantConfig:
def use_nvfp4_w4a4(self) -> bool:
return self.quant_dtype == "nvfp4"
@property
def use_mxfp4_w4a8(self) -> bool:
return self._a1.dtype == "fp8" and self._w1.dtype == "mxfp4"
def config_name(self, dtype: torch.dtype) -> str | None:
"""
Return a string used to construct the filename that contains the
@@ -532,6 +536,8 @@ def fp8_w8a8_moe_quant_config(
w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: list[int] | None = None,
@@ -549,6 +555,8 @@ def fp8_w8a8_moe_quant_config(
g1_alphas=g1_alphas,
w2_scale=w2_scale,
g2_alphas=g2_alphas,
w1_bias=w1_bias,
w2_bias=w2_bias,
a1_scale=a1_scale,
a1_gscale=a1_gscale,
a2_scale=a2_scale,
@@ -564,6 +572,8 @@ def int8_w8a8_moe_quant_config(
w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
per_act_token_quant: bool = False,
) -> FusedMoEQuantConfig:
"""
@@ -575,6 +585,8 @@ def int8_w8a8_moe_quant_config(
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False,
block_shape=None,
@@ -654,6 +666,26 @@ def mxfp4_mxfp8_moe_quant_config(
)
def mxfp4_w4a8_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for fp8 activations and mxfp4 weights.
"""
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc("fp8", None, a1_scale, None, None, None),
_a2=FusedMoEQuantDesc("fp8", None, a2_scale, None, None, None),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
)
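A hedged usage sketch (shapes and values assumed, not from this diff): the new constructor pairs an fp8 activation descriptor with mxfp4 weight descriptors, which is exactly what the use_mxfp4_w4a8 property added above keys off.

    import torch
    from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a8_moe_quant_config

    E, N, K = 8, 768, 2880  # experts, intermediate, hidden (illustrative)
    cfg = mxfp4_w4a8_moe_quant_config(
        w1_scale=torch.ones(E, 2 * N, K // 32, dtype=torch.uint8),  # 32-elem MX blocks
        w2_scale=torch.ones(E, K, N // 32, dtype=torch.uint8),
        w1_bias=torch.zeros(E, 2 * N, dtype=torch.float32),
        w2_bias=torch.zeros(E, K, dtype=torch.float32),
    )
    assert cfg.use_mxfp4_w4a8  # _a1.dtype == "fp8" and _w1.dtype == "mxfp4"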
def ocp_mx_moe_quant_config(
quant_dtype: str,
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
@@ -691,6 +723,8 @@ def nvfp4_moe_quant_config(
a2_gscale: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for nvfp4 activations and nvfp4 weights.
@@ -699,6 +733,8 @@ def nvfp4_moe_quant_config(
"nvfp4",
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
g1_alphas=g1_alphas,

View File

@@ -38,7 +38,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
@@ -1583,6 +1582,11 @@ def _get_config_quant_dtype(
return "mxfp6_e3m2"
elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
return "mxfp6_e2m3"
elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}:
return torch.bfloat16
elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}:
return torch.float8_e4m3fn
return None
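Condensed, the new branches select the activation-side quant dtype per scheme; a minimal restatement (sketch) of the logic above:

    import torch

    def activation_quant_dtype(ocp_mx_scheme: str):
        # Weight-only schemes: activations stay in bf16 (no activation quant).
        if ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}:
            return torch.bfloat16
        # fp8-activation schemes: quantize activations to fp8 (QDQ emulation).
        if ocp_mx_scheme.endswith("a_fp8"):
            return torch.float8_e4m3fn
        # mx-activation schemes return string dtypes, e.g. "mxfp6_e2m3".
        ...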
@@ -1617,17 +1621,10 @@ def fused_experts_impl(
if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
elif ocp_mx_scheme is not None:
if ocp_mx_scheme in {
"w_mxfp4_a_mxfp4",
"w_mxfp4_a_mxfp6_e3m2",
"w_mxfp4_a_mxfp6_e2m3",
}:
if ocp_mx_scheme.startswith("w_mxfp4"):
# 16bit activation and fp4x2 packed weight
assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
elif ocp_mx_scheme in {
"w_mxfp6_e3m2_a_mxfp6_e3m2",
"w_mxfp6_e2m3_a_mxfp6_e2m3",
}:
elif ocp_mx_scheme.startswith("w_mxfp6"):
assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
"hidden size mismatch"
)
@@ -1717,17 +1714,13 @@ def fused_experts_impl(
# TODO: On platforms for which `current_platform.supports_mx()` is True
# and for which we have a native OCP mx fused MOE kernel,
# this dequantization step should not be done.
if ocp_mx_scheme in {
OCP_MX_Scheme.w_mxfp4_a_mxfp4,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
}:
if ocp_mx_scheme.startswith("w_mxfp4"):
# Weights have to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
@@ -1736,7 +1729,7 @@ def fused_experts_impl(
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
@@ -1779,6 +1772,7 @@ def fused_experts_impl(
quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape,
ocp_mx_scheme=ocp_mx_scheme,
)
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
@@ -1846,6 +1840,7 @@ def fused_experts_impl(
quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape,
ocp_mx_scheme=ocp_mx_scheme,
)
if expert_map is not None:

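The emulation path above dequantizes OCP MX weights on platforms without native MX support. A pure-torch sketch of the underlying e8m0 block-scale decoding (illustrative only; the real kernels live in mxfp4_utils/mxfp6_utils):

    import torch

    def decode_e8m0(scale: torch.Tensor) -> torch.Tensor:
        # e8m0 stores a biased power-of-two exponent: value = 2**(e - 127)
        return torch.pow(2.0, scale.to(torch.float32) - 127.0)

    blocks = torch.randn(4, 32)                  # 4 blocks of 32 elements each
    scales = torch.tensor([127, 128, 126, 130])  # decode to 1.0, 2.0, 0.5, 8.0
    dequant = blocks * decode_e8m0(scales).unsqueeze(-1)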
View File

@@ -221,12 +221,14 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
)
# TODO(rob): move this down to the kernel.
def maybe_roundup_hidden_size(
hidden_size: int,
act_dtype: torch.dtype,
quant_config: QuantizationConfig | None,
moe_parallel_config: FusedMoEParallelConfig,
is_lora_enabled: bool,
model_type: str | None,
is_mxfp4_quant: bool,
) -> int:
"""
Given layer hidden size and MoE configurations, round up hidden_size
@@ -235,11 +237,12 @@ def maybe_roundup_hidden_size(
Args:
hidden_size: Layer hidden-size
act_dtype: Data type of the layer activations.
quant_config: Fused MoE quantization configuration.
moe_parallel_config: Fused MoE parallelization strategy configuration.
is_lora_enabled: True if the engine is enabled with LoRA. This
is used in the case of mxfp4 quantization in selecting the
MxFP4Backend.
model_type: The model type from the HF config, used to check whether the model is gpt_oss.
is_mxfp4_quant: Whether the layer is quantized with mxfp4.
Return:
Rounded up hidden_size if rounding up is required based on the configs.
@@ -254,7 +257,7 @@ def maybe_roundup_hidden_size(
)
# we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4":
if model_type == "gpt_oss" and is_mxfp4_quant:
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend,
get_mxfp4_backend,
@@ -398,15 +401,6 @@ class FusedMoE(CustomOp):
# Expert mapping used in self.load_weights
self.expert_mapping = expert_mapping
# Round up hidden size if needed.
hidden_size = maybe_roundup_hidden_size(
hidden_size,
moe_in_dtype,
quant_config,
self.moe_parallel_config,
is_lora_enabled=self.vllm_config.lora_config is not None,
)
# For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
@@ -508,7 +502,6 @@ class FusedMoE(CustomOp):
), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
@@ -548,6 +541,24 @@ class FusedMoE(CustomOp):
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
# Round up hidden size before creating moe_config.
# This way moe_config is created with the correct hidden_size from the start.
hidden_size = maybe_roundup_hidden_size(
hidden_size=hidden_size,
act_dtype=moe_in_dtype,
moe_parallel_config=self.moe_parallel_config,
is_lora_enabled=vllm_config.lora_config is not None,
model_type=(
self.vllm_config.model_config.hf_config.model_type
if self.vllm_config.model_config is not None
else None
),
is_mxfp4_quant=(
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
),
)
self.hidden_size = hidden_size
self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,

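Moving the rounding before moe_config creation means the config sees the padded hidden size from the start. A worked sketch (the multiple of 256 is an assumed backend requirement; 2880 is the gpt_oss hidden size):

    from vllm.utils.math_utils import round_up

    hidden_size = 2880
    padded = round_up(hidden_size, 256)  # -> 3072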
View File

@@ -23,6 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
per_tensor_dequantize,
)
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -241,7 +244,27 @@ def moe_kernel_quantize_input(
per_act_token_quant: bool,
block_shape: list[int] | None = None,
is_fp4_scale_swizzled: bool = True,
ocp_mx_scheme: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation
if ocp_mx_scheme is not None:
if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}:
pass # No QDQ needed for these schemes
elif ocp_mx_scheme.endswith("a_fp8"):
# Perform QDQ (quantize and dequantize) on activation for emulation
# purpose, because there is no native kernel for weight in ocp_mx_scheme
# and activation in FP8. The implementation is based on existing
# non-emulation ops.
qA, qA_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=False
)
A = per_tensor_dequantize(qA, qA_scale).to(A.dtype)
# After QDQ, we don't need further quantization
return A, None
# else: For other schemes (e.g., *_a_mxfp6_e3m2, *_a_mxfp6_e2m3),
# weights are already dequantized, and we proceed with normal
# activation quantization below.
if quant_dtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:

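A self-contained sketch of the fp8 QDQ emulation above: the activation is quantized to fp8 and immediately dequantized, so downstream kernels observe fp8 rounding error without a native mx-weight/fp8-activation kernel (pure torch, dynamic per-tensor scale assumed):

    import torch

    def qdq_fp8(a: torch.Tensor) -> torch.Tensor:
        scale = a.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
        q = (a / scale).to(torch.float8_e4m3fn)  # quantize
        return q.to(a.dtype) * scale             # dequantize

    x = torch.randn(4, 8, dtype=torch.bfloat16)
    x_qdq = qdq_fp8(x)  # same shape/dtype, values snapped to the fp8 grid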
View File

@@ -168,3 +168,19 @@ class QuantizationConfig(ABC):
Interface to update values after config initialization.
"""
pass
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
"""
Determine if mxfp4 quantization will be used for this config.
This allows hidden_size rounding to happen before moe_config creation
without needing to instantiate quant_method first.
Args:
prefix: The layer prefix/name in the model
layer: The layer module
Returns:
True if this config uses MXFP4 quantization, False otherwise
"""
return False

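The hook lets FusedMoE ask the config about mxfp4 before any quant method is instantiated (see maybe_roundup_hidden_size above). A minimal sketch of a subclass opting in (hypothetical class name; other abstract methods omitted):

    import torch
    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

    class AlwaysMxfp4Config(QuantizationConfig):  # hypothetical example
        def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
            return True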
View File

@@ -229,10 +229,15 @@ class Mxfp4Config(QuantizationConfig):
)
return None
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
"""MXFP4 config always uses MXFP4 quantization."""
return True
class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.weight_dtype = "mxfp4"
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.marlin_input_dtype = None

View File

@@ -320,38 +320,45 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
def _is_ocp_mx(
self,
weight_quant: dict[str, Any] | None,
input_quant: dict[str, Any] | None,
def _is_w_ocp_mx_a_x(
self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None
) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
"""
This check returns True only if it is an OCP-MX weight quantization.
The activation can be any data type (e.g., FP16/BF16, FP8, or OCP-MX format).
The rationale for checking only the weight type is that
the model loading concept and process primarily concerns the weights themselves.
"""
# Confirm weights quantized.
if weight_quant is None:
logger.debug(
"Quark model is not in OCP MX format: "
"weight_quant or input_quant not set"
"Quark model's weight quantization is incompatible with OCP_MX format: "
"weight_quant is not set."
)
return False
# Input and weight qscheme needs to be per group.
if (
weight_quant.get("qscheme") != "per_group"
or input_quant.get("qscheme") != "per_group"
):
logger.debug("Quark model is not in OCP MX format: not per_group")
if weight_quant.get("qscheme") != "per_group":
logger.debug(
"Quark model's weight quantization is incompatible with OCP MX format: "
"weight is not per_group."
)
return False
# Input and weight group size needs to be 32.
if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
logger.debug("Quark model is not in OCP MX format: not group_size=32")
if weight_quant.get("group_size") != 32:
logger.debug(
"Quark model's weight quantization is incompatible with OCP MX format: "
"group_size of weight is not 32."
)
return False
# Activations and weight scales need to be in e8m0 format.
if (
weight_quant.get("scale_format") != "e8m0"
or input_quant.get("scale_format") != "e8m0"
):
logger.debug("Quark model is not in OCP MX format: not scale_format e8m0")
if weight_quant.get("scale_format") != "e8m0":
logger.debug(
"Quark model's weight quantization is incompatible with OCP MX format: "
"scale_format of weight is not e8m0."
)
return False
# Input and weight dtypes need to be any of fp4,
@@ -360,14 +367,31 @@ class QuarkConfig(QuantizationConfig):
"fp4",
"fp6_e3m2",
"fp6_e2m3",
} or input_quant.get("dtype") not in {"fp4", "fp6_e3m2", "fp6_e2m3"}:
}:
logger.debug(
"Quark model is not in OCP MX format: dtype not fp4, fp6_e3m2, fp6_e2m3"
"Quark model's weight quantization is incompatible with OCP MX format: "
"dtype is not in {fp4, fp6_e3m2, fp6_e2m3}."
)
return False
return True
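For reference, a weight config that passes the relaxed check (field names as used above; `config` is assumed to be a QuarkConfig instance):

    weight_quant = {
        "dtype": "fp4",          # or "fp6_e3m2" / "fp6_e2m3"
        "qscheme": "per_group",
        "group_size": 32,
        "scale_format": "e8m0",
    }
    assert config._is_w_ocp_mx_a_x(weight_quant, input_quant=None)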
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
"""
For Quark, determine whether the layer uses OCP MXFP4 by checking the
config directly. This allows hidden_size rounding to happen before
moe_config creation.
"""
layer_quant_config = self._find_matched_config(prefix, layer)
weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors")
return (
self._is_w_ocp_mx_a_x(weight_config, input_config)
and weight_config is not None
and weight_config.get("dtype") == "fp4"
and getattr(torch, "float4_e2m1fn_x2", None) is not None
)
def _find_matched_config(
self, layer_name: str, module: torch.nn.Module
) -> dict[str, Any]:
@@ -441,7 +465,7 @@ class QuarkConfig(QuantizationConfig):
is_static_input_scheme=True,
input_symmetric=input_config.get("symmetric"),
)
elif self._is_ocp_mx(weight_config, input_config):
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX(weight_config, input_config)
raise NotImplementedError(

View File

@@ -8,6 +8,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
@@ -18,9 +19,15 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
mxfp4_w4a8_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend,
get_mxfp4_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
@@ -37,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
@@ -46,6 +54,7 @@ __all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
class QuarkMoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.has_bias = self.moe.has_bias
@staticmethod
def get_moe_method(
@@ -67,7 +76,7 @@ class QuarkMoEMethod(FusedMoEMethodBase):
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_ocp_mx(weight_config, input_config):
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
@@ -86,6 +95,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
self.weight_qscheme = self.weight_quant.get("qscheme")
self.input_qscheme = self.input_quant.get("qscheme")
self.weight_dtype = self.weight_quant.get("dtype", "").replace(
"fp8_e4m3", "fp8"
)
self.input_dtype = self.input_quant.get("dtype", "").replace("fp8_e4m3", "fp8")
per_tensor = (
self.weight_qscheme == "per_tensor" and self.input_qscheme == "per_tensor"
)
@@ -121,6 +134,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
self.model_type = getattr(
get_current_vllm_config().model_config.hf_config, "model_type", None
)
def create_weights(
self,
layer: torch.nn.Module,
@@ -166,9 +183,16 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
if self.weight_qscheme == "per_tensor":
# Allocate 2 scales for w1 and w3 respectively.
# They are combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
if self.model_type != "gpt_oss":
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
else:
# For gpt_oss, w1 (gate) and w3 (up) are fused into a single tensor,
# so there is only one weight scale per expert.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 1, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
@@ -220,6 +244,27 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer.w13_input_scale = None
layer.w2_input_scale = None
if self.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
else:
layer.w13_bias, layer.w2_bias = None, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
@@ -278,21 +323,40 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
# For gpt_oss, w1 and w3 are fused into a single combined
# gate_up_proj tensor with size 2*intermediate_size_per_partition
# and only one scale per expert.
# Process the entire weight tensor as one shard.
if self.model_type == "gpt_oss":
for expert_id in range(layer.local_num_experts):
# Process all 2*intermediate_size_per_partition rows at once
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
layer.w13_weight[expert_id],
layer.w13_weight_scale[expert_id][0],
)
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
layer.w13_weight[expert_id], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id]
)
start += shard_size
else:
# For non-gpt_oss, process w1 and w3 shards separately
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)
# Quark's scale is 1-dimensional.
elif self.weight_qscheme == "per_channel":
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
@@ -343,6 +407,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
per_act_token_quant=self.input_qscheme == "per_channel",
per_out_ch_quant=self.weight_qscheme == "per_channel",
)
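A pure-torch sketch of the per-tensor dequantize/requantize step used above for the fused gpt_oss gate_up weight (toy values; the real path uses per_tensor_dequantize and ops.scaled_fp8_quant):

    import torch

    w_q = torch.randn(16, 4).to(torch.float8_e4m3fn)    # toy fp8 expert weight
    scale, max_scale = torch.tensor(0.02), torch.tensor(0.05)
    dq = w_q.to(torch.float32) * scale                  # dequantize
    w_q_new = (dq / max_scale).to(torch.float8_e4m3fn)  # requantize at max scale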
@@ -563,7 +629,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def __init__(
self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
input_config: dict[str, Any] | None,
moe: FusedMoEConfig,
):
super().__init__(moe)
@@ -571,35 +637,79 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self.input_quant = input_config
weight_qscheme = self.weight_quant.get("qscheme")
input_qscheme = self.input_quant.get("qscheme")
if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
if weight_qscheme != "per_group":
raise ValueError(
"For MX(FP4) Fused MoE layers, only per-group scales "
"for weights and activations are supported. Found "
f"{weight_qscheme}, {input_qscheme}"
f"for weights are supported. Found {weight_qscheme}."
) # noqa E501
self.static_input_scales = not self.input_quant.get("is_dynamic")
self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
if self.input_quant is not None:
input_quant = self.input_quant["dtype"]
if input_quant in ["fp4", "fp6_e3m2", "fp6_e2m3"]:
self.input_dtype = input_quant.replace("fp", "mxfp")
elif input_quant == "fp8_e4m3":
self.input_dtype = input_quant.replace("fp8_e4m3", "fp8")
else:
raise NotImplementedError(
f"Current input dtype {input_quant} is not compatible \
with OCP MX (weight) MoE quantization. Please open an issue"
)
else:
self.input_dtype = None
self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
self.input_dtype, self.weight_dtype
)
if self.static_input_scales:
if self.ocp_mx_scheme is None:
raise ValueError(
f"Unsupported OCP MX dtype combination for MoE: "
f"input_dtype={self.input_dtype}, weight_dtype={self.weight_dtype}. "
f"Please check that the combination is supported in OCP_MX_Scheme."
)
self.mxfp4_backend: Mxfp4Backend | None = None
if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
if self.input_quant is not None:
self.static_input_scales = not self.input_quant.get("is_dynamic")
else:
self.static_input_scales = False
if any(
self.ocp_mx_scheme.endswith(a_scheme)
for a_scheme in ["a_mxfp4", "a_mxfp6_e3m2", "a_mxfp6_e2m3"]
):
if self.static_input_scales:
raise NotImplementedError(
"QuarkOCP_MX_MoEMethod with static input scales is currently "
f"not implemented for OCP MX scheme {self.ocp_mx_scheme}. "
"Please open an issue."
)
elif self.ocp_mx_scheme.endswith("a_fp8") and not self.static_input_scales:
raise NotImplementedError(
"QuarkOCP_MX_MoEMethod with static input scales is currently "
"not implemented. Please open an issue."
"QuarkOCP_MX_MoEMethod with dynamic input scales is currently "
f"not implemented for OCP MX scheme {self.ocp_mx_scheme}. "
"Please open an issue."
)
self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
self.emulate = not current_platform.supports_mx() or not (
self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
self.model_type = getattr(
get_current_vllm_config().model_config.hf_config, "model_type", None
)
self._emulate = (
not current_platform.supports_mx()
or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
self.emulate = self.model_type == "gpt_oss" or self._emulate
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
@@ -640,12 +750,23 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
params_dtype = torch.uint8
if self.model_type == "gpt_oss":
if current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256
)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64
)
else:
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
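# Worked example (assumed: gpt_oss intermediate size 2880 with TP=2, i.e.
# 1440 per partition): round_up(1440, 256) = 1536 on ROCm,
# round_up(1440, 64) = 1472 elsewhere.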
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
2 * intermediate_size_per_partition_after_pad,
self.get_packed_dim(hidden_size, self.weight_dtype),
dtype=params_dtype,
),
@@ -659,7 +780,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
torch.empty(
num_experts,
hidden_size,
self.get_packed_dim(intermediate_size_per_partition, self.weight_dtype),
self.get_packed_dim(
intermediate_size_per_partition_after_pad, self.weight_dtype
),
dtype=params_dtype,
),
requires_grad=False,
@@ -672,7 +795,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * intermediate_size_per_partition,
2 * intermediate_size_per_partition_after_pad,
hidden_size // OCP_MX_BLOCK_SIZE,
dtype=params_dtype,
),
@@ -682,7 +805,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
torch.ones(
num_experts,
hidden_size,
intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
intermediate_size_per_partition_after_pad // OCP_MX_BLOCK_SIZE,
dtype=params_dtype,
),
requires_grad=False,
@@ -693,8 +816,96 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
if self.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
else:
layer.w13_bias, layer.w2_bias = None, None
# INPUT_SCALES
if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer):
if self.static_input_scales:
# First, process activation scales for static fp8 input.
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)
if current_platform.is_fp8_fnuz():
# Normalize the weights and scales
_, _, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
torch.empty_like(layer.w13_weight, dtype=torch.float8_e4m3fnuz),
torch.empty_like(
layer.w13_weight_scale, dtype=layer.w13_weight_scale.dtype
),
layer.w13_input_scale,
)
_, _, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
torch.empty_like(layer.w2_weight, dtype=torch.float8_e4m3fnuz),
torch.empty_like(
layer.w2_weight_scale, dtype=layer.w13_weight_scale.dtype
),
layer.w2_input_scale,
)
# Reset the parameter
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False
)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)
# Second, process the mxfp weights.
if self.emulate:
torch.cuda.empty_cache()
return
from aiter.utility.fp4_utils import e8m0_shuffle
@@ -725,15 +936,40 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return ocp_mx_moe_quant_config(
quant_dtype=self.input_dtype,
weight_dtype=self.weight_dtype,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
a2_scale=None,
block_shape=None,
)
if self.ocp_mx_scheme == "w_mxfp4":
return mxfp4_w4a16_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
)
elif self.ocp_mx_scheme == "w_mxfp4_a_fp8":
return mxfp4_w4a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
block_shape=None,
)
elif self.ocp_mx_scheme in ["w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"]:
raise NotImplementedError(
"Currently there is no corresponding fused moe quant config configured "
f"in vLLM for OCP MX scheme {self.ocp_mx_scheme}. Please open an issue."
)
else:
return ocp_mx_moe_quant_config(
quant_dtype=self.input_dtype,
weight_dtype=self.weight_dtype,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
a1_scale=None,
a2_scale=None,
block_shape=None,
)
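Condensed, the dispatch above maps OCP MX schemes to config constructors (sketch mirroring get_fused_moe_quant_config; names from the imports earlier in this file):

    def _pick_ctor(scheme: str):
        if scheme == "w_mxfp4":
            return mxfp4_w4a16_moe_quant_config   # weight-only, 16-bit activations
        if scheme == "w_mxfp4_a_fp8":
            return mxfp4_w4a8_moe_quant_config    # fp8 activations (QDQ emulated)
        if scheme in ("w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"):
            raise NotImplementedError(scheme)     # no fused MoE config yet
        return ocp_mx_moe_quant_config            # remaining OCP MX schemes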
def apply(
self,
@@ -743,24 +979,34 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.emulate:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
if (
self.model_type == "gpt_oss"
and self.mxfp4_backend == Mxfp4Backend.TRITON
):
raise NotImplementedError(
"Triton kernel implemented fused MoE for GPT_OSS model "
"in Quark(MoE) format is not integrated or provided yet."
)
out = rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
else:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
out = fused_experts(
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
@@ -773,5 +1019,3 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
return out

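The emulation decision introduced in this file reduces to a small predicate; restated as a sketch (semantics inferred from the diff):

    def should_emulate(supports_mx: bool, scheme: str,
                       mxfp4_backend, use_aiter: bool,
                       model_type: str | None) -> bool:
        emulate = (not supports_mx or not scheme.startswith("w_mxfp4")) and (
            mxfp4_backend is None or not use_aiter
        )
        # gpt_oss currently always takes the emulated path.
        return model_type == "gpt_oss" or emulate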
View File

@@ -20,26 +20,44 @@ SUPPORTED_OCP_MX_DTYPES = {"mxfp4", "mxfp6_e3m2", "mxfp6_e2m3"}
class OCP_MX_Scheme(str, Enum):
w_mxfp4 = "w_mxfp4"
w_mxfp4_a_mxfp4 = "w_mxfp4_a_mxfp4"
w_mxfp4_a_mxfp6_e3m2 = "w_mxfp4_a_mxfp6_e3m2"
w_mxfp4_a_mxfp6_e2m3 = "w_mxfp4_a_mxfp6_e2m3"
w_mxfp4_a_fp8 = "w_mxfp4_a_fp8"
w_mxfp6_e3m2 = "w_mxfp6_e3m2"
w_mxfp6_e3m2_a_mxfp6_e3m2 = "w_mxfp6_e3m2_a_mxfp6_e3m2"
w_mxfp6_e3m2_a_fp8 = "w_mxfp6_e3m2_a_fp8"
w_mxfp6_e2m3 = "w_mxfp6_e2m3"
w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3"
w_mxfp6_e2m3_a_fp8 = "w_mxfp6_e2m3_a_fp8"
@classmethod
def from_quant_dtype(cls, input_dtype: str | None, weight_dtype: str | None):
if input_dtype not in OCP_MX_DTYPES or weight_dtype not in OCP_MX_DTYPES:
if input_dtype not in OCP_MX_DTYPES and weight_dtype not in OCP_MX_DTYPES:
return None
elif input_dtype is None and weight_dtype == "mxfp4":
return cls.w_mxfp4
elif input_dtype is None and weight_dtype == "mxfp6_e3m2":
return cls.w_mxfp6_e3m2
elif input_dtype is None and weight_dtype == "mxfp6_e2m3":
return cls.w_mxfp6_e2m3
elif input_dtype == "mxfp4" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp4
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp6_e3m2
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp6_e2m3
elif input_dtype == "fp8" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_fp8
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp6_e3m2":
return cls.w_mxfp6_e3m2_a_mxfp6_e3m2
elif input_dtype == "fp8" and weight_dtype == "mxfp6_e3m2":
return cls.w_mxfp6_e3m2_a_fp8
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp6_e2m3":
return cls.w_mxfp6_e2m3_a_mxfp6_e2m3
elif input_dtype == "fp8" and weight_dtype == "mxfp6_e2m3":
return cls.w_mxfp6_e2m3_a_fp8
else:
logger.warning(
"input_dtype='%s' and"