[Model] Enable LoRA support for tower and connector in LLaVA (#31513)
Signed-off-by: Jay Hemnani <jayhemnani9910@gmail.com> Co-authored-by: Jay Hemnani <jayhemnani9910@gmail.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -51,7 +51,13 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
from .module_mapping import MultiModelKeys
|
||||
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (
|
||||
@@ -505,7 +511,9 @@ def init_vision_tower_for_llava(
|
||||
info=_build_llava_or_pixtral_hf_info,
|
||||
dummy_inputs=LlavaDummyInputsBuilder,
|
||||
)
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
class LlavaForConditionalGeneration(
|
||||
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
@@ -734,6 +742,32 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="multi_modal_projector",
|
||||
tower_model="vision_tower",
|
||||
)
|
||||
|
||||
def get_num_mm_encoder_tokens(
|
||||
self,
|
||||
num_image_tokens: int,
|
||||
) -> int:
|
||||
# LLaVA's vision encoder outputs one token per patch without
|
||||
# spatial merging or pixel shuffle
|
||||
return num_image_tokens
|
||||
|
||||
def get_num_mm_connector_tokens(
|
||||
self,
|
||||
num_vision_tokens: int,
|
||||
) -> int:
|
||||
# LLaVA's MLP projector outputs the same number of tokens
|
||||
# as it receives from the vision encoder (1:1 mapping)
|
||||
return num_vision_tokens
|
||||
|
||||
|
||||
class MantisProcessingInfo(LlavaProcessingInfo):
|
||||
def get_hf_processor(self, **kwargs: object):
|
||||
|
||||
Reference in New Issue
Block a user