[Models][Quantization] Add quantization configuration update in Voxtral model (#24122)

Signed-off-by: Alexandre Marques <almarque@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Alexandre Marques
2025-09-10 22:13:56 -04:00
committed by GitHub
parent cc99baf14d
commit 5931b7e5d9
2 changed files with 88 additions and 4 deletions

View File

@@ -23,6 +23,7 @@ from transformers.tokenization_utils_base import TextInput
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -327,6 +328,12 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
super().__init__()
self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
# update quant config so that ignored module and target module names
# match the vLLM model names
if hasattr(vllm_config, "quant_config"):
vllm_config.quant_config = self.maybe_update_quant_config(
vllm_config.quant_config)
config = vllm_config.model_config.hf_config
self.config = config
self.downsample_factor = self.config.audio_config.downsample_factor
@@ -558,6 +565,72 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return loaded_weights
def maybe_update_quant_config(
        self, quant_config: QuantizationConfig) -> QuantizationConfig:
    """
    Rewrite module names in a quant config so they match vLLM model names.

    Both the ignore list and every config group's target list are
    rewritten, so that ignored module and target module names match the
    vLLM model names. Right now this is specific for compressed-tensors
    format and load_format mistral.

    Args:
        quant_config: Quantization config to update. It is modified in
            place. An object without an ``ignore`` attribute or without a
            ``config_groups`` attribute simply skips that part.

    Returns:
        The same ``quant_config`` object that was passed in.
    """
    whisper_layer = (r"mm_whisper_embeddings\.whisper_encoder"
                     r"\.transformer\.layers\.(\d+)")
    # (checkpoint-name pattern, vLLM-name replacement). A rule applies
    # only when its pattern matches the WHOLE name.
    remapping_rules = [
        (r"output", r"language_model.lm_head"),
        (r"layers\.(\d+)\.attention\.wo",
         r"language_model.model.layers.\1.self_attn.out_proj"),
        (r"layers\.(\d+)\.attention\.w(.*)",
         r"language_model.model.layers.\1.self_attn.\2_proj"),
        (r"layers\.(\d+)\.feed_forward\.w1",
         r"language_model.model.layers.\1.mlp.gate_proj"),
        (r"layers\.(\d+)\.feed_forward\.w2",
         r"language_model.model.layers.\1.mlp.down_proj"),
        (r"layers\.(\d+)\.feed_forward\.w3",
         r"language_model.model.layers.\1.mlp.up_proj"),
        # The dot before "w..." is escaped in the three whisper layer
        # rules so that it only matches a literal ".".
        (whisper_layer + r"\.attention\.wo",
         r"whisper_encoder.whisper_encoder.layers.\1.layers"
         r".self_attn.out_proj"),
        (whisper_layer + r"\.attention\.w(.*)",
         r"whisper_encoder.whisper_encoder.layers.\1.layers"
         r".self_attn.\2_proj"),
        (whisper_layer + r"\.feed_forward\.w(\d+)",
         r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
        (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
         r"whisper_encoder.whisper_encoder.conv1"),
        (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
         r"whisper_encoder.whisper_encoder.conv2"),
        (r"mm_whisper_embeddings\.audio_language_projection\.0",
         r"audio_language_adapter.w_in"),
        (r"mm_whisper_embeddings\.audio_language_projection\.2",
         r"audio_language_adapter.w_out"),
    ]
    # Compile once; the same rules are applied to every name below.
    compiled_rules = [(re.compile(pattern), repl)
                      for pattern, repl in remapping_rules]

    def _to_vllm_name(name):
        """Return the vLLM name for `name`, or `name` if no rule applies."""
        vllm_name = name
        # Every rule is tested against the ORIGINAL name and there is no
        # early exit, so when several rules match the LAST one wins.
        # NOTE(review): "...attention.wo" matches both the dedicated "wo"
        # rule (out_proj) and the generic "w(.*)" rule, so it ends up as
        # "o_proj". Confirm this is the intended name for both the
        # language model and the whisper encoder.
        for pattern, repl in compiled_rules:
            if pattern.fullmatch(name):
                vllm_name = pattern.sub(repl, name)
        return vllm_name

    # Update ignore list
    if hasattr(quant_config, "ignore"):
        quant_config.ignore = [
            _to_vllm_name(name) for name in quant_config.ignore
        ]

    # Update target list
    if hasattr(quant_config, "config_groups"):
        config_groups = quant_config.config_groups
        for group_name in config_groups:
            group = config_groups[group_name]
            if "targets" in group:
                group["targets"] = [
                    _to_vllm_name(name) for name in group["targets"]
                ]
        quant_config.config_groups = config_groups

    return quant_config
class AudioLanguageAdapter(nn.Module):