[Feature] Add LoRA support for Gemma3 vision components (#32764)

This commit is contained in:
VihaanThat
2026-01-26 19:26:40 +05:30
committed by GitHub
parent 9ac818a551
commit 208c56256f

View File

@@ -656,3 +656,41 @@ class Gemma3ForConditionalGeneration(
connector="multi_modal_projector",
tower_model="vision_tower",
)
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
"""
Calculate the number of tokens output by the vision encoder.
The vision encoder processes images into patch embeddings. For Gemma3,
the relationship between prompt placeholder tokens and actual vision
encoder output tokens depends on the patch grid size.
Args:
num_image_tokens: Number of image placeholder tokens in the prompt
(typically mm_tokens_per_image per image)
Returns:
Number of tokens output by the vision encoder
"""
# For Gemma3, the vision encoder outputs tokens_per_side x tokens_per_side
# tokens per image. Since num_image_tokens represents the number of
# connector output tokens (mm_tokens_per_image = 256), and tokens_per_side
# is sqrt(256) = 16, we need to account for the token expansion.
# Based on empirical testing, the multiplier of 16 works correctly.
return num_image_tokens * 16
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
"""
Calculate the number of tokens output by the multimodal connector.
The connector applies projection and normalization but maintains the
token count for Gemma3.
Args:
num_vision_tokens: Number of tokens from vision encoder
Returns:
Number of tokens after connector processing
"""
# The Gemma3 connector maintains a 1:1 token mapping
return num_vision_tokens