[Models][Quantization] Add quantization configuration update in Voxtral model (#24122)
Signed-off-by: Alexandre Marques <almarque@redhat.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
committed by
GitHub
parent
cc99baf14d
commit
5931b7e5d9
@@ -23,6 +23,7 @@ from transformers.tokenization_utils_base import TextInput
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import SupportsPP
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
@@ -327,6 +328,12 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
super().__init__()
|
||||
self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
|
||||
|
||||
# update quant config so that ignored module and target module names
|
||||
# match the vLLM model names
|
||||
if hasattr(vllm_config, "quant_config"):
|
||||
vllm_config.quant_config = self.maybe_update_quant_config(
|
||||
vllm_config.quant_config)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.config = config
|
||||
self.downsample_factor = self.config.audio_config.downsample_factor
|
||||
@@ -558,6 +565,72 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return loaded_weights
|
||||
|
||||
def maybe_update_quant_config(
        self, quant_config: "QuantizationConfig") -> "QuantizationConfig":
    """Remap quantization-config module names to vLLM model names.

    Checkpoints in mistral load_format / compressed-tensors format name
    modules like ``layers.N.attention.wq``, while the vLLM Voxtral module
    tree uses names like ``language_model.model.layers.N.self_attn.q_proj``.
    Rewrite the config's ``ignore`` list and each config group's
    ``targets`` list so quantization decisions match the vLLM names.

    Args:
        quant_config: quantization config parsed from the checkpoint.

    Returns:
        The same config object, with module names rewritten in place.
    """
    remapping_rules = [
        (r"output", r"language_model.lm_head"),
        (r"layers\.(\d+)\.attention\.wo",
         r"language_model.model.layers.\1.self_attn.out_proj"),
        (r"layers\.(\d+)\.attention\.w(.*)",
         r"language_model.model.layers.\1.self_attn.\2_proj"),
        (r"layers\.(\d+)\.feed_forward\.w1",
         r"language_model.model.layers.\1.mlp.gate_proj"),
        (r"layers\.(\d+)\.feed_forward\.w2",
         r"language_model.model.layers.\1.mlp.down_proj"),
        (r"layers\.(\d+)\.feed_forward\.w3",
         r"language_model.model.layers.\1.mlp.up_proj"),
        (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
         r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj"
         ),
        (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
         r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj"
         ),
        (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
         r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
        (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
         r"whisper_encoder.whisper_encoder.conv1"),
        (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
         r"whisper_encoder.whisper_encoder.conv2"),
        (r"mm_whisper_embeddings\.audio_language_projection\.0",
         r"audio_language_adapter.w_in"),
        (r"mm_whisper_embeddings\.audio_language_projection\.2",
         r"audio_language_adapter.w_out"),
    ]

    # NOTE(review): the generic ``attention\.w(.*)`` rules also fullmatch
    # ``attention.wo`` and, coming later in the list, override the
    # dedicated ``wo`` -> ``out_proj`` rules (yielding ``o_proj``).
    # Last-match-wins is preserved here exactly as before — confirm the
    # shadowing of the ``wo`` rules is intended.
    def _remap(name: str) -> str:
        # Apply every rule that fullmatches the ORIGINAL name; the last
        # matching rule determines the result.
        remapped = name
        for pattern, repl in remapping_rules:
            if re.fullmatch(pattern, name):
                remapped = re.sub(pattern, repl, name)
        return remapped

    # Rewrite the list of modules excluded from quantization.
    if hasattr(quant_config, "ignore"):
        quant_config.ignore = [_remap(name) for name in quant_config.ignore]

    # Rewrite the target-module list of each quantization scheme group.
    if hasattr(quant_config, "config_groups"):
        config_groups = quant_config.config_groups
        for group_name in config_groups:
            group = config_groups[group_name]
            if "targets" in group:
                group["targets"] = [_remap(name) for name in group["targets"]]
        quant_config.config_groups = config_groups

    return quant_config
|
||||
|
||||
|
||||
class AudioLanguageAdapter(nn.Module):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user