[Misc] Gemma3ForConditionalGeneration supports LoRA (#14797)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
|
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
|
||||||
@@ -23,7 +24,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
|
SupportsMultiModal, SupportsPP)
|
||||||
from .siglip import SiglipVisionModel
|
from .siglip import SiglipVisionModel
|
||||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||||
maybe_prefix, merge_multimodal_embeddings)
|
maybe_prefix, merge_multimodal_embeddings)
|
||||||
@@ -371,8 +373,8 @@ class Gemma3MultiModalProjector(nn.Module):
|
|||||||
@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
|
@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
|
||||||
info=Gemma3ProcessingInfo,
|
info=Gemma3ProcessingInfo,
|
||||||
dummy_inputs=Gemma3DummyInputsBuilder)
|
dummy_inputs=Gemma3DummyInputsBuilder)
|
||||||
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||||
SupportsPP):
|
SupportsLoRA):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
"q_proj",
|
"q_proj",
|
||||||
@@ -614,3 +616,12 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
loader = AutoWeightsLoader(self)
|
loader = AutoWeightsLoader(self)
|
||||||
return loader.load_weights(weights)
|
return loader.load_weights(weights)
|
||||||
|
|
||||||
|
def get_mm_mapping(self) -> MultiModelKeys:
|
||||||
|
"""
|
||||||
|
Get the module prefix in multimodal models
|
||||||
|
"""
|
||||||
|
return MultiModelKeys.from_string_field(
|
||||||
|
language_model="language_model",
|
||||||
|
connector="multi_modal_projector",
|
||||||
|
tower_model="vision_tower")
|
||||||
|
|||||||
Reference in New Issue
Block a user