[Model] Enable LoRA support for Pixtral (#31724)

Signed-off-by: <>
Signed-off-by: 赵策 <alcor@zhaocedeMacBook-Air.local>
Signed-off-by: 赵策 <alcor@mac.mynetworksettings.com>
Co-authored-by: 赵策 <alcor@mac.mynetworksettings.com>
This commit is contained in:
Ce Zhao
2026-01-08 08:00:57 -05:00
committed by GitHub
parent 03fd76c570
commit 1123a87892
2 changed files with 30 additions and 3 deletions

View File

@@ -63,7 +63,13 @@ from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .module_mapping import MultiModelKeys
from .utils import init_vllm_registered_model, maybe_prefix
from .vision import (
VisionEncoderInfo,
@@ -365,7 +371,9 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
info=PixtralProcessingInfo,
dummy_inputs=PixtralDummyInputsBuilder,
)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class PixtralForConditionalGeneration(
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
):
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
@@ -581,6 +589,25 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
# Now we call the language model load with the generator
self.language_model.load_weights(llm_weights_generator())
def get_mm_mapping(self) -> MultiModelKeys:
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="vision_language_adapter",
tower_model="vision_encoder",
)
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
if getattr(self, "patch_merger", None) is None:
return num_image_tokens
merge_size = self.vision_args.spatial_merge_size
return num_image_tokens * (merge_size**2)
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
if getattr(self, "patch_merger", None) is None:
return num_vision_tokens
merge_size = self.vision_args.spatial_merge_size
return num_vision_tokens // (merge_size**2)
# Vision encoder
@dataclass