[LoRA]: Add LoRA support to Mistral's Voxtral models (#24517)
Signed-off-by: Yash Pratap Singh <yashsingh20001@gmail.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
committed by
GitHub
parent
6cbd41909e
commit
9e3c3a7df2
@@ -25,6 +25,7 @@ from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import SupportsPP
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
# yapf: disable
|
||||
from vllm.model_executor.models.whisper import WhisperEncoder
|
||||
# yapf: enable
|
||||
@@ -44,8 +45,8 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||
cached_tokenizer_from_config)
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
|
||||
SupportsTranscription)
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsTranscription)
|
||||
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
@@ -313,9 +314,15 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
|
||||
info=VoxtralProcessingInfo,
|
||||
dummy_inputs=VoxtralDummyInputsBuilder)
|
||||
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP, SupportsTranscription):
|
||||
SupportsPP, SupportsLoRA,
|
||||
SupportsTranscription):
|
||||
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
|
||||
@@ -341,6 +348,14 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""Get module prefix for multimodal models to filter LoRA modules."""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="audio_language_adapter",
|
||||
tower_model=["whisper_encoder"],
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user