[LoRA]: Add lora support to qwen-2.5-omni (#24231)

This commit is contained in:
Yash Pratap Singh
2025-09-04 18:20:50 +05:30
committed by GitHub
parent 16ded21eeb
commit c9f7081f9c
2 changed files with 14 additions and 3 deletions

View File

@@ -41,6 +41,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
@@ -66,7 +67,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
@@ -705,7 +707,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
)
class Qwen2_5OmniThinkerForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP,
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
Qwen2_5OmniConditionalGenerationMixin):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
@@ -798,6 +800,15 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_mm_mapping(self) -> MultiModelKeys:
"""Get module prefix for multimodal models to filter LoRA modules."""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector=[], # No explicit connector in this model
tower_model=["visual",
"audio_tower"], # Exclude vision and audio towers
)
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings: