[Model] Enable LoRA support for Pixtral (#31724)
Signed-off-by: <> Signed-off-by: 赵策 <alcor@zhaocedeMacBook-Air.local> Signed-off-by: 赵策 <alcor@mac.mynetworksettings.com> Co-authored-by: 赵策 <alcor@mac.mynetworksettings.com>
This commit is contained in:
@@ -63,7 +63,13 @@ from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
from .module_mapping import MultiModelKeys
|
||||
from .utils import init_vllm_registered_model, maybe_prefix
|
||||
from .vision import (
|
||||
VisionEncoderInfo,
|
||||
@@ -365,7 +371,9 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
|
||||
info=PixtralProcessingInfo,
|
||||
dummy_inputs=PixtralDummyInputsBuilder,
|
||||
)
|
||||
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
class PixtralForConditionalGeneration(
|
||||
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
):
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
@@ -581,6 +589,25 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
# Now we call the language model load with the generator
|
||||
self.language_model.load_weights(llm_weights_generator())
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
    """Build the ``MultiModelKeys`` naming this model's three sub-modules.

    The strings passed are the names used for the language model, the
    vision-language connector and the vision tower respectively.
    """
    module_names = {
        "language_model": "language_model",
        "connector": "vision_language_adapter",
        "tower_model": "vision_encoder",
    }
    return MultiModelKeys.from_string_field(**module_names)
|
||||
|
||||
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
    """Scale an image-token count up by the spatial merge factor.

    Args:
        num_image_tokens: Token count to convert.

    Returns:
        ``num_image_tokens`` unchanged when ``self.patch_merger`` is absent
        or ``None``; otherwise ``num_image_tokens`` multiplied by
        ``self.vision_args.spatial_merge_size`` squared.
    """
    merger = getattr(self, "patch_merger", None)
    if merger is not None:
        # Only touch vision_args when a patch merger is actually present.
        scale = self.vision_args.spatial_merge_size ** 2
        return num_image_tokens * scale
    return num_image_tokens
|
||||
|
||||
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
    """Scale a vision-token count down by the spatial merge factor.

    Args:
        num_vision_tokens: Token count to convert.

    Returns:
        ``num_vision_tokens`` unchanged when ``self.patch_merger`` is absent
        or ``None``; otherwise ``num_vision_tokens`` floor-divided by
        ``self.vision_args.spatial_merge_size`` squared.
    """
    merger = getattr(self, "patch_merger", None)
    if merger is not None:
        # Floor division: any remainder below a full merge window is dropped.
        scale = self.vision_args.spatial_merge_size ** 2
        return num_vision_tokens // scale
    return num_vision_tokens
|
||||
|
||||
|
||||
# Vision encoder
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user