[Model] Enable LoRA support for tower and connector in LLaVA (#31513)

Signed-off-by: Jay Hemnani <jayhemnani9910@gmail.com>
Co-authored-by: Jay Hemnani <jayhemnani9910@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
commit 5ac55eb30f (parent ea53ca5e85)
Author: Jay Hemnani
Date: 2026-01-01 19:32:39 -08:00
Committed by: GitHub
2 changed files with 37 additions and 3 deletions
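With SupportsLoRA declared on LlavaForConditionalGeneration, adapters whose target modules live in the vision tower or the projector can be served through vLLM's standard LoRA path. A minimal sketch of how that might look, assuming a local adapter directory ./llava-lora (the adapter name, image path, and adapter path are placeholders, not part of this commit):

from PIL import Image

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder inputs: swap in a real image and a real LoRA adapter directory.
image = Image.open("example.jpg")

llm = LLM(model="llava-hf/llava-1.5-7b-hf", enable_lora=True)

outputs = llm.generate(
    {
        "prompt": "USER: <image>\nWhat is shown here? ASSISTANT:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
    lora_request=LoRARequest("llava-lora", 1, "./llava-lora"),
)
print(outputs[0].outputs[0].text)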


@@ -51,7 +51,13 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .clip import CLIPVisionModel
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsPP,
+)
+from .module_mapping import MultiModelKeys
 from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (
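SupportsLoRA is the marker interface vLLM's engine inspects before wiring LoRA layers into a model. Roughly how that check looks from the outside (a sketch; supports_lora is vLLM's runtime inspection helper in the same interfaces module):

from vllm.model_executor.models.interfaces import supports_lora
from vllm.model_executor.models.llava import LlavaForConditionalGeneration

# After this change, the class advertises LoRA support to the engine.
assert supports_lora(LlavaForConditionalGeneration)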
@@ -505,7 +511,9 @@ def init_vision_tower_for_llava(
     info=_build_llava_or_pixtral_hf_info,
     dummy_inputs=LlavaDummyInputsBuilder,
 )
-class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+class LlavaForConditionalGeneration(
+    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -734,6 +742,32 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="multi_modal_projector",
+            tower_model="vision_tower",
+        )
+
+    def get_num_mm_encoder_tokens(
+        self,
+        num_image_tokens: int,
+    ) -> int:
+        # LLaVA's vision encoder outputs one token per patch without
+        # spatial merging or pixel shuffle
+        return num_image_tokens
+
+    def get_num_mm_connector_tokens(
+        self,
+        num_vision_tokens: int,
+    ) -> int:
+        # LLaVA's MLP projector outputs the same number of tokens
+        # as it receives from the vision encoder (1:1 mapping)
+        return num_vision_tokens
+
 
 class MantisProcessingInfo(LlavaProcessingInfo):
     def get_hf_processor(self, **kwargs: object):
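get_mm_mapping gives the LoRA machinery the prefixes it needs to route adapter weights to the right sub-model: names under vision_tower go to the tower, names under multi_modal_projector to the connector, and the rest to the language model. A sketch of that prefix routing (classify_name is a hypothetical helper for illustration, not vLLM API):

from vllm.model_executor.models.module_mapping import MultiModelKeys

mm_keys = MultiModelKeys.from_string_field(
    language_model="language_model",
    connector="multi_modal_projector",
    tower_model="vision_tower",
)

def classify_name(name: str) -> str:
    # Hypothetical helper: route an adapter weight name by module prefix.
    for field in ("tower_model", "connector", "language_model"):
        prefixes = getattr(mm_keys, field)
        prefixes = [prefixes] if isinstance(prefixes, str) else prefixes
        if any(name.startswith(p) for p in prefixes):
            return field
    raise ValueError(f"unrecognized module: {name}")

print(classify_name("vision_tower.vision_model.encoder.layers.0.self_attn.q_proj"))
# -> "tower_model"

The two token-count hooks encode the fact that LLaVA is 1:1 end to end: for the common CLIP ViT-L/14 tower at 336x336 input, the encoder emits (336 / 14)^2 = 576 patch tokens, and the MLP projector passes all 576 through unchanged, so both hooks simply return their input.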