[Bugfix] Handle ParallelLMHead in compressed-tensors get_quant_method (#37291)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Matthias Gehre
2026-03-30 16:30:52 +02:00
committed by GitHub
parent 246dc7d864
commit e8b055a5ac
2 changed files with 88 additions and 1 deletion

View File

@@ -5,13 +5,20 @@
Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
from unittest.mock import Mock
import pytest
import torch
from compressed_tensors.quantization import QuantizationType
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
)
from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig,
CompressedTensorsLinearMethod,
CompressedTensorsW4A4Fp4,
CompressedTensorsW4A8Fp8,
@@ -26,6 +33,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -558,3 +566,72 @@ def test_w4a16_moe_torch_compile(vllm_runner):
) as llm:
output = llm.generate_greedy("Hi", max_tokens=1)
assert output
def _make_ct_config(*, target: str = "Linear") -> CompressedTensorsConfig:
    """Construct a minimal CompressedTensorsConfig using INT8 channel-wise
    weight quantization (no input-activation quantization) for *target*,
    in the pack-quantized format.
    """
    int8_channel_weights = QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    )
    scheme = {
        "weights": int8_channel_weights,
        "input_activations": None,
        "format": "pack-quantized",
    }
    return CompressedTensorsConfig(
        target_scheme_map={target: scheme},
        ignore=[],
        quant_format="pack-quantized",
        sparsity_scheme_map={},
        sparsity_ignore_list=[],
    )
def test_get_quant_method_returns_linear_method_for_parallel_lm_head():
    """A ParallelLMHead whose prefix matches a target must be quantized."""
    # Target regex matches the "model.lm_head" prefix used below.
    cfg = _make_ct_config(target="re:.*lm_head")

    # Mock stands in for a real ParallelLMHead; forcing __class__ makes
    # isinstance(..., ParallelLMHead) checks inside the config succeed.
    fake_head = Mock(spec=ParallelLMHead)
    fake_head.__class__ = ParallelLMHead

    method = cfg.get_quant_method(fake_head, prefix="model.lm_head")
    assert isinstance(method, CompressedTensorsLinearMethod), (
        f"Expected CompressedTensorsLinearMethod, got {type(method).__name__}"
    )
def test_get_quant_method_returns_none_for_ignored_parallel_lm_head():
    """A ParallelLMHead on the ignore list must stay unquantized (None)."""
    cfg = _make_ct_config(target="re:.*lm_head")
    # The ignore entry overrides the matching target above.
    cfg.ignore = ["re:.*lm_head"]

    # Force __class__ so isinstance(..., ParallelLMHead) passes on the mock.
    fake_head = Mock(spec=ParallelLMHead)
    fake_head.__class__ = ParallelLMHead

    method = cfg.get_quant_method(fake_head, prefix="model.lm_head")
    assert method is None, (
        f"Expected None for ignored ParallelLMHead, got {type(method).__name__}"
    )
def test_get_quant_method_returns_none_for_unmatched_parallel_lm_head():
    """A ParallelLMHead that matches no target must not raise.

    Typical compressed-tensors models only target 'Linear', which a
    ParallelLMHead does not match. In that case get_quant_method should
    return None (leave the layer unquantized) rather than raise
    ValueError.
    """
    cfg = _make_ct_config(target="Linear")

    # Force __class__ so isinstance(..., ParallelLMHead) passes on the mock.
    fake_head = Mock(spec=ParallelLMHead)
    fake_head.__class__ = ParallelLMHead

    method = cfg.get_quant_method(fake_head, prefix="model.lm_head")
    assert method is None, (
        f"Expected None for unmatched ParallelLMHead, got {type(method).__name__}"
    )

View File

@@ -62,6 +62,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
if TYPE_CHECKING:
@@ -179,6 +180,15 @@ class CompressedTensorsConfig(QuantizationConfig):
else:
return quant_method
if isinstance(layer, ParallelLMHead):
try:
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
except ValueError:
quant_scheme = None
if quant_scheme is not None:
layer.scheme = quant_scheme
return CompressedTensorsLinearMethod(self)
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):