Support LoRA for Mistral3 (#17428)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
@@ -31,7 +32,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
@@ -382,8 +384,8 @@ def init_vision_tower_for_llava(
|
||||
_build_mistral3_processor,
|
||||
info=_build_mistral3_info,
|
||||
dummy_inputs=Mistral3DummyInputsBuilder)
|
||||
class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP):
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
@@ -594,3 +596,12 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="multi_modal_projector",
|
||||
tower_model="vision_tower")
|
||||
|
||||
Reference in New Issue
Block a user