[Model] Enable LoRA support for tower and connector in LLaVA (#31513)
Signed-off-by: Jay Hemnani <jayhemnani9910@gmail.com> Co-authored-by: Jay Hemnani <jayhemnani9910@gmail.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -699,7 +699,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ |
|
||||
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ |
|
||||
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ |
|
||||
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | ✅︎ | ✅︎ |
|
||||
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ |
|
||||
| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ |
|
||||
| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ |
|
||||
|
||||
@@ -51,7 +51,13 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
from .module_mapping import MultiModelKeys
|
||||
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (
|
||||
@@ -505,7 +511,9 @@ def init_vision_tower_for_llava(
|
||||
info=_build_llava_or_pixtral_hf_info,
|
||||
dummy_inputs=LlavaDummyInputsBuilder,
|
||||
)
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
class LlavaForConditionalGeneration(
|
||||
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
@@ -734,6 +742,32 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="multi_modal_projector",
|
||||
tower_model="vision_tower",
|
||||
)
|
||||
|
||||
def get_num_mm_encoder_tokens(
|
||||
self,
|
||||
num_image_tokens: int,
|
||||
) -> int:
|
||||
# LLaVA's vision encoder outputs one token per patch without
|
||||
# spatial merging or pixel shuffle
|
||||
return num_image_tokens
|
||||
|
||||
def get_num_mm_connector_tokens(
|
||||
self,
|
||||
num_vision_tokens: int,
|
||||
) -> int:
|
||||
# LLaVA's MLP projector outputs the same number of tokens
|
||||
# as it receives from the vision encoder (1:1 mapping)
|
||||
return num_vision_tokens
|
||||
|
||||
|
||||
class MantisProcessingInfo(LlavaProcessingInfo):
|
||||
def get_hf_processor(self, **kwargs: object):
|
||||
|
||||
Reference in New Issue
Block a user