[ROCm][Bugfix] fix exception related to trust_remote_code for MiniMax-M2.1-MXFP4 (#37698)

Signed-off-by: Hongxia Yang <hongxiay.yang@amd.com>
Co-authored-by: Hongxia Yang <hongxiay.yang@amd.com>
This commit is contained in:
Hongxia Yang
2026-03-30 11:49:23 -04:00
committed by GitHub
parent e8b055a5ac
commit dbdd9ae067
9 changed files with 142 additions and 19 deletions

View File

@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for QuarkConfig.maybe_update_config.
Fetches real HF configs (metadata only, no model weights) to verify
that dynamic_mxfp4_quant is only enabled for DeepSeek-V3-family models.
Run: pytest tests/quantization/test_quark_maybe_update_config.py -v
"""
import pytest
from transformers import AutoConfig
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
def _make_quark_config() -> QuarkConfig:
    """Build a bare-bones QuarkConfig for exercising maybe_update_config."""
    # Empty quantization settings are enough here: the tests only inspect
    # the dynamic_mxfp4_quant flag after maybe_update_config runs.
    minimal_kwargs = {
        "quant_config": {},
        "kv_cache_group": [],
        "pack_method": "reorder",
    }
    return QuarkConfig(**minimal_kwargs)
# ---------------------------------------------------------------------------
# Non-deepseek models must not flip dynamic_mxfp4_quant
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("model_name", ["amd/MiniMax-M2.1-MXFP4"])
def test_non_deepseek_model_stays_false(model_name: str):
    """Models outside the DeepSeek-V3 family must leave the flag disabled."""
    quark_cfg = _make_quark_config()
    # Only the HF config is fetched; trust_remote_code lets custom config
    # classes shipped with the checkpoint load.
    remote_cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

    quark_cfg.maybe_update_config(model_name, hf_config=remote_cfg)

    flag = quark_cfg.dynamic_mxfp4_quant
    assert flag is False
# ---------------------------------------------------------------------------
# DeepSeek-V3 family + fp4 must enable dynamic_mxfp4_quant
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("model_name", ["amd/DeepSeek-R1-MXFP4-ASQ"])
def test_deepseek_family_fp4_enables_flag(model_name: str):
    """A DeepSeek-V3-family fp4 checkpoint must turn dynamic_mxfp4_quant on."""
    quark_cfg = _make_quark_config()
    # Only the HF config is fetched; trust_remote_code lets custom config
    # classes shipped with the checkpoint load.
    remote_cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

    quark_cfg.maybe_update_config(model_name, hf_config=remote_cfg)

    flag = quark_cfg.dynamic_mxfp4_quant
    assert flag is True
# ---------------------------------------------------------------------------
# Missing hf_config → flag must stay False (no config is fetched)
# ---------------------------------------------------------------------------
def test_missing_hf_config_stays_false():
    """Calling maybe_update_config without hf_config leaves the flag off."""
    quark_cfg = _make_quark_config()

    # hf_config is omitted on purpose, so it takes its default of None.
    quark_cfg.maybe_update_config("some/model")

    flag = quark_cfg.dynamic_mxfp4_quant
    assert flag is False

View File

@@ -526,7 +526,10 @@ class VllmConfig:
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}"
)
quant_config.maybe_update_config(model_config.model)
quant_config.maybe_update_config(
model_config.model,
hf_config=model_config.hf_config,
)
return quant_config
return None

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Union
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
from vllm.logger import init_logger
@@ -146,7 +147,12 @@ class AWQConfig(QuantizationConfig):
self.modules_to_not_convert
)
def maybe_update_config(self, model_name: str, revision: str | None = None):
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_to_not_convert:
return

View File

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn import Parameter
from transformers import PretrainedConfig
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
@@ -332,7 +333,12 @@ class AWQMarlinConfig(QuantizationConfig):
self.modules_to_not_convert
)
def maybe_update_config(self, model_name: str, revision: str | None = None):
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_to_not_convert:
return

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
import torch
from torch import nn
from transformers import PretrainedConfig
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -168,10 +169,23 @@ class QuantizationConfig(ABC):
# TODO (@kylesayrs): add implementations for all subclasses
pass
def maybe_update_config(self, model_name: str): # noqa: B027
def maybe_update_config(  # noqa: B027
    self,
    model_name: str,
    hf_config: PretrainedConfig | None = None,
    revision: str | None = None,
):
    """
    Interface to update values after config initialization.

    The base implementation intentionally does nothing; quantization
    configs that need to adjust their settings once the model is known
    override this method.

    Args:
        model_name: The name of the model
        hf_config: The Hugging Face config of the model, or None when
            the caller does not supply one
        revision: The revision of the model

    Returns:
        None. Overrides are expected to update the config in place.
    """
    # TODO: revision is never passed currently in vllm.py,
    # but is used in subclasses, should we remove this parameter?
    pass
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:

View File

@@ -5,6 +5,7 @@ from typing import Any
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig
from vllm._custom_ops import (
cpu_gemm_wna16,
@@ -133,7 +134,12 @@ class CPUAWQConfig(QuantizationConfig):
self.modules_to_not_convert
)
def maybe_update_config(self, model_name: str, revision: str | None = None):
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_to_not_convert:
return

View File

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Union
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn.parameter import Parameter
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
from vllm.logger import init_logger
@@ -193,7 +194,12 @@ class GPTQConfig(QuantizationConfig):
self.modules_in_block_to_quantize
)
def maybe_update_config(self, model_name: str, revision: str | None = None):
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]

View File

@@ -6,6 +6,7 @@ from typing import Any
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
@@ -299,7 +300,12 @@ class GPTQMarlinConfig(QuantizationConfig):
self.modules_in_block_to_quantize
)
def maybe_update_config(self, model_name: str, revision: str | None = None):
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]

View File

@@ -5,6 +5,7 @@ import fnmatch
from typing import TYPE_CHECKING, Any, cast
import torch
from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
@@ -36,7 +37,6 @@ from vllm.model_executor.layers.quantization.quark.utils import (
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -45,6 +45,10 @@ __all__ = ["QuarkLinearMethod"]
logger = init_logger(__name__)
# model_type values that use dynamic MXFP4 re-quantization for
# OCP MX fp4 Quark checkpoints
_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3"})
class QuarkConfig(QuantizationConfig):
def __init__(
@@ -63,19 +67,28 @@ class QuarkConfig(QuantizationConfig):
self.pack_method = pack_method
self.dynamic_mxfp4_quant = False
def maybe_update_config(self, model_name: str, revision: str | None = None):
self.hf_config = get_config(
model=model_name,
trust_remote_code=False, # or get from model_config if available
revision=revision,
config_format="auto",
)
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
"""Enable dynamic MXFP4 only for DeepSeek-V3-family + fp4 Quark checkpoints."""
quant_config = getattr(self.hf_config, "quantization_config", None)
if (
getattr(hf_config, "model_type", None)
not in _DEEPSEEK_V3_FAMILY_MODEL_TYPES
):
return
quant_config = getattr(hf_config, "quantization_config", None)
if quant_config is not None:
quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
model_type = self.hf_config.model_type
if quant_dtype == "fp4" and model_type == "deepseek_v3":
quant_dtype = (
quant_config.get("global_quant_config", {})
.get("weight", {})
.get("dtype")
)
if quant_dtype == "fp4":
self.dynamic_mxfp4_quant = True
def get_linear_method(self) -> "QuarkLinearMethod":