[Core] Initialize LoRA support for tower and connector in multi-modal models (#26674)

Signed-off-by: bk-201 <joy25810@foxmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: bk-201 <joy25810@foxmail.com>
Co-authored-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Anexdeus <5142168@mail.ru>
This commit is contained in:
Jee Jee Li
2025-12-26 20:48:20 +08:00
committed by GitHub
parent 0b544e6476
commit ce1eafd1a5
20 changed files with 635 additions and 80 deletions

View File

@@ -714,3 +714,21 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo
connector="model.connector",
tower_model="model.vision_model",
)
def get_num_mm_encoder_tokens(
self,
num_image_tokens: int,
) -> int:
hf_config = self.config
scale_factor = hf_config.scale_factor
return num_image_tokens * scale_factor**2
def get_num_mm_connector_tokens(
self,
num_vision_tokens: int,
) -> int:
hf_config = self.config
scale_factor = hf_config.scale_factor
return num_vision_tokens // scale_factor**2

View File

@@ -136,6 +136,24 @@ class SupportsMultiModal(Protocol):
"""
...
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
"""
Implement this function to enable LoRA support
for the tower module of the multi-modal model.
Given the number of image tokens, output the number of
multi-modal encoder tokens.
"""
...
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
"""
Implement this function to enable LoRA support
for the connector module of the multi-modal model.
Given the number of vision tokens, output the number of
multi-modal connector tokens.
"""
...
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""

View File

@@ -1026,6 +1026,7 @@ class Qwen2_5_VLForConditionalGeneration(
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
"qkv": ["qkv"], # For vision tower's already-packed QKV
}
# To ensure correct weight loading and mapping.
@@ -1568,6 +1569,25 @@ class Qwen2_5_VLForConditionalGeneration(
tower_model="visual.",
)
def get_num_mm_encoder_tokens(
self,
num_image_tokens: int,
) -> int:
hf_config = self.config
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_image_tokens * merge_size**2
def get_num_mm_connector_tokens(
self,
num_vision_tokens: int,
) -> int:
hf_config = self.config
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""

View File

@@ -1491,6 +1491,25 @@ class Qwen2VLForConditionalGeneration(
tower_model="visual.",
)
def get_num_mm_encoder_tokens(
self,
num_image_tokens: int,
) -> int:
hf_config = self.config
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_image_tokens * merge_size**2
def get_num_mm_connector_tokens(
self,
num_vision_tokens: int,
) -> int:
hf_config = self.config
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2
class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
pass

View File

@@ -1240,6 +1240,7 @@ class Qwen3VLForConditionalGeneration(
"gate_proj",
"up_proj",
],
"qkv": ["qkv"], # For vision tower's already-packed QKV
}
supports_encoder_tp_data = True
@@ -2087,10 +2088,29 @@ class Qwen3VLForConditionalGeneration(
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="visual.merger",
connector=["visual.merger", "visual.deepstack_merger_list"],
tower_model="visual.",
)
def get_num_mm_encoder_tokens(
self,
num_image_tokens: int,
) -> int:
hf_config = self.config
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_image_tokens * merge_size**2
def get_num_mm_connector_tokens(
self,
num_vision_tokens: int,
) -> int:
hf_config = self.config
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""