[ROCm][Bugfix] fix exception related to trust_remote_code for MiniMax-M2.1-MXFP4 (#37698)
Signed-off-by: Hongxia Yang <hongxiay.yang@amd.com> Co-authored-by: Hongxia Yang <hongxiay.yang@amd.com>
This commit is contained in:
63
tests/quantization/test_quark_maybe_update_config.py
Normal file
63
tests/quantization/test_quark_maybe_update_config.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for QuarkConfig.maybe_update_config.
|
||||
|
||||
Fetches real HF configs (metadata only, no model weights) to verify
|
||||
that dynamic_mxfp4_quant is only enabled for DeepSeek-V3-family models.
|
||||
|
||||
Run: pytest tests/quantization/test_quark_maybe_update_config.py -v
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
|
||||
|
||||
|
||||
def _make_quark_config() -> QuarkConfig:
|
||||
"""Create a minimal QuarkConfig for testing."""
|
||||
return QuarkConfig(quant_config={}, kv_cache_group=[], pack_method="reorder")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-deepseek models must not flip dynamic_mxfp4_quant
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
["amd/MiniMax-M2.1-MXFP4"],
|
||||
)
|
||||
def test_non_deepseek_model_stays_false(model_name: str):
|
||||
"""Non-deepseek_v3 models must not enable dynamic_mxfp4_quant."""
|
||||
hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
qcfg = _make_quark_config()
|
||||
|
||||
qcfg.maybe_update_config(model_name, hf_config=hf_config)
|
||||
|
||||
assert qcfg.dynamic_mxfp4_quant is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeepSeek-V3 family + fp4 must enable dynamic_mxfp4_quant
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
["amd/DeepSeek-R1-MXFP4-ASQ"],
|
||||
)
|
||||
def test_deepseek_family_fp4_enables_flag(model_name: str):
|
||||
hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
qcfg = _make_quark_config()
|
||||
|
||||
qcfg.maybe_update_config(model_name, hf_config=hf_config)
|
||||
|
||||
assert qcfg.dynamic_mxfp4_quant is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Missing hf_config → warn and stay False
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_missing_hf_config_stays_false():
|
||||
qcfg = _make_quark_config()
|
||||
|
||||
qcfg.maybe_update_config("some/model")
|
||||
|
||||
assert qcfg.dynamic_mxfp4_quant is False
|
||||
@@ -526,7 +526,10 @@ class VllmConfig:
|
||||
f"method {model_config.quantization}. Supported dtypes: "
|
||||
f"{supported_dtypes}"
|
||||
)
|
||||
quant_config.maybe_update_config(model_config.model)
|
||||
quant_config.maybe_update_config(
|
||||
model_config.model,
|
||||
hf_config=model_config.hf_config,
|
||||
)
|
||||
return quant_config
|
||||
return None
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
@@ -146,7 +147,12 @@ class AWQConfig(QuantizationConfig):
|
||||
self.modules_to_not_convert
|
||||
)
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
def maybe_update_config(
|
||||
self,
|
||||
model_name: str,
|
||||
hf_config: PretrainedConfig | None = None,
|
||||
revision: str | None = None,
|
||||
):
|
||||
if self.modules_to_not_convert:
|
||||
return
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from torch.nn import Parameter
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
@@ -332,7 +333,12 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
self.modules_to_not_convert
|
||||
)
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
def maybe_update_config(
|
||||
self,
|
||||
model_name: str,
|
||||
hf_config: PretrainedConfig | None = None,
|
||||
revision: str | None = None,
|
||||
):
|
||||
if self.modules_to_not_convert:
|
||||
return
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -168,10 +169,23 @@ class QuantizationConfig(ABC):
|
||||
# TODO (@kylesayrs): add implementations for all subclasses
|
||||
pass
|
||||
|
||||
def maybe_update_config(self, model_name: str): # noqa: B027
|
||||
def maybe_update_config( # noqa: B027
|
||||
self,
|
||||
model_name: str,
|
||||
hf_config: PretrainedConfig | None = None,
|
||||
revision: str | None = None,
|
||||
):
|
||||
"""
|
||||
Interface to update values after config initialization.
|
||||
|
||||
Args:
|
||||
model_name: The name of the model
|
||||
hf_config: The Hugging Face config of the model
|
||||
revision: The revision of the model
|
||||
Returns:
|
||||
"""
|
||||
# TODO: revision is never passed currently in vllm.py,
|
||||
# but is used in subclasses, should we remove this parameter?
|
||||
pass
|
||||
|
||||
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm._custom_ops import (
|
||||
cpu_gemm_wna16,
|
||||
@@ -133,7 +134,12 @@ class CPUAWQConfig(QuantizationConfig):
|
||||
self.modules_to_not_convert
|
||||
)
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
def maybe_update_config(
|
||||
self,
|
||||
model_name: str,
|
||||
hf_config: PretrainedConfig | None = None,
|
||||
revision: str | None = None,
|
||||
):
|
||||
if self.modules_to_not_convert:
|
||||
return
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Union
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from torch.nn.parameter import Parameter
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
@@ -193,7 +194,12 @@ class GPTQConfig(QuantizationConfig):
|
||||
self.modules_in_block_to_quantize
|
||||
)
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
def maybe_update_config(
|
||||
self,
|
||||
model_name: str,
|
||||
hf_config: PretrainedConfig | None = None,
|
||||
revision: str | None = None,
|
||||
):
|
||||
if self.modules_in_block_to_quantize:
|
||||
if is_list_of(self.modules_in_block_to_quantize, list):
|
||||
# original modules_in_block_to_quantize: list[list[str]]
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
@@ -299,7 +300,12 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
self.modules_in_block_to_quantize
|
||||
)
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
def maybe_update_config(
|
||||
self,
|
||||
model_name: str,
|
||||
hf_config: PretrainedConfig | None = None,
|
||||
revision: str | None = None,
|
||||
):
|
||||
if self.modules_in_block_to_quantize:
|
||||
if is_list_of(self.modules_in_block_to_quantize, list):
|
||||
# original modules_in_block_to_quantize: list[list[str]]
|
||||
|
||||
@@ -5,6 +5,7 @@ import fnmatch
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
@@ -36,7 +37,6 @@ from vllm.model_executor.layers.quantization.quark.utils import (
|
||||
)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
@@ -45,6 +45,10 @@ __all__ = ["QuarkLinearMethod"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# model_type values that use dynamic MXFP4 re-quantization for
|
||||
# OCP MX fp4 Quark checkpoints
|
||||
_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3"})
|
||||
|
||||
|
||||
class QuarkConfig(QuantizationConfig):
|
||||
def __init__(
|
||||
@@ -63,19 +67,28 @@ class QuarkConfig(QuantizationConfig):
|
||||
self.pack_method = pack_method
|
||||
self.dynamic_mxfp4_quant = False
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
self.hf_config = get_config(
|
||||
model=model_name,
|
||||
trust_remote_code=False, # or get from model_config if available
|
||||
revision=revision,
|
||||
config_format="auto",
|
||||
)
|
||||
def maybe_update_config(
|
||||
self,
|
||||
model_name: str,
|
||||
hf_config: PretrainedConfig | None = None,
|
||||
revision: str | None = None,
|
||||
):
|
||||
"""Enable dynamic MXFP4 only for DeepSeek-V3-family + fp4 Quark checkpoints."""
|
||||
|
||||
quant_config = getattr(self.hf_config, "quantization_config", None)
|
||||
if (
|
||||
getattr(hf_config, "model_type", None)
|
||||
not in _DEEPSEEK_V3_FAMILY_MODEL_TYPES
|
||||
):
|
||||
return
|
||||
|
||||
quant_config = getattr(hf_config, "quantization_config", None)
|
||||
if quant_config is not None:
|
||||
quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
|
||||
model_type = self.hf_config.model_type
|
||||
if quant_dtype == "fp4" and model_type == "deepseek_v3":
|
||||
quant_dtype = (
|
||||
quant_config.get("global_quant_config", {})
|
||||
.get("weight", {})
|
||||
.get("dtype")
|
||||
)
|
||||
if quant_dtype == "fp4":
|
||||
self.dynamic_mxfp4_quant = True
|
||||
|
||||
def get_linear_method(self) -> "QuarkLinearMethod":
|
||||
|
||||
Reference in New Issue
Block a user