[Model][LoRA]LoRA support added for MiniCPMV2.5 (#7199)
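With this change, MiniCPM-V 2.5 can be served with LoRA adapters through vLLM's standard LoRA path. A usage sketch (adapter path, rank, and prompt are placeholders; image inputs are passed via `multi_modal_data` as in the existing MiniCPM-V examples, and are omitted here):

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="openbmb/MiniCPM-Llama3-V-2_5",
    trust_remote_code=True,
    enable_lora=True,   # enables the LoRA support added by this commit
    max_lora_rank=64,   # must be >= the rank of the adapter being loaded
)

outputs = llm.generate(
    ["Describe MiniCPM-V in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("minicpmv25-adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)
```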
@@ -36,7 +36,7 @@ from transformers import PretrainedConfig
 from typing_extensions import NotRequired

 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -50,7 +50,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.models.minicpm import MiniCPMModel
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2 import Qwen2Model
+from vllm.model_executor.models.utils import LLMWrapper
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -59,10 +61,10 @@ from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData

 from .idefics2_vision_model import Idefics2VisionTransformer
+from .interfaces import SupportsLoRA

 _KEYS_TO_MODIFY_MAPPING = {
     "llm.lm_head": "lm_head",
-    "llm.model": "llm",
 }
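`_KEYS_TO_MODIFY_MAPPING` is applied while loading checkpoint weights to translate Hugging Face key prefixes into vLLM module paths. Assuming, as the hunk above suggests, that the `"llm.model"` entry is dropped because the new `LLMWrapper` (see the hunks below) already puts the language model under a `model` attribute, only `llm.lm_head` still needs rewriting. A simplified sketch of the renaming step, not a verbatim copy of `load_weights`:

```python
_KEYS_TO_MODIFY_MAPPING = {
    "llm.lm_head": "lm_head",
}

def remap_checkpoint_key(name: str) -> str:
    """Rewrite HF checkpoint prefixes so they match vLLM's module tree."""
    for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
        if key_to_modify in name:
            name = name.replace(key_to_modify, new_key)
    return name

print(remap_checkpoint_key("llm.lm_head.weight"))    # -> lm_head.weight
print(remap_checkpoint_key("llm.model.layers.0.self_attn.qkv_proj.weight"))  # unchanged
```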
@@ -621,6 +623,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
                                         default_weight_loader)
             weight_loader(param, loaded_weight)

+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(language_model="llm",
+                                                connector="resampler",
+                                                tower_model="vpm")
+
     def init_llm(
         self,
         config: PretrainedConfig,
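The new `get_mm_mapping` hook exposes which top-level prefixes belong to the language model (`llm`), the multimodal connector (`resampler`), and the vision tower (`vpm`); vLLM's LoRA machinery uses this grouping when it needs to know which part of a multimodal model a given module path belongs to. A rough sketch of how such a mapping can be consumed; the `classify` helper below is illustrative, not vLLM's actual code:

```python
from vllm.model_executor.models.module_mapping import MultiModelKeys

mapping = MultiModelKeys.from_string_field(language_model="llm",
                                           connector="resampler",
                                           tower_model="vpm")

def _as_list(value):
    # Be defensive about whether a field holds a string or a list of strings.
    return value if isinstance(value, list) else [value]

def classify(module_name: str) -> str:
    """Illustrative helper: map a module path to its component group."""
    for group, prefixes in (("language_model", mapping.language_model),
                            ("connector", mapping.connector),
                            ("tower_model", mapping.tower_model)):
        if any(module_name.startswith(p) for p in _as_list(prefixes)):
            return group
    return "other"

print(classify("llm.model.layers.0.self_attn.qkv_proj"))    # language_model
print(classify("resampler.kv_proj"))                        # connector
print(classify("vpm.encoder.layers.0.self_attn.out_proj"))  # tower_model
```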
@@ -669,9 +679,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> nn.Module:
-        return MiniCPMModel(config,
-                            cache_config=cache_config,
-                            quant_config=quant_config)
+        return LLMWrapper(MiniCPMModel(config,
+                                       cache_config=cache_config,
+                                       quant_config=quant_config),
+                          name="model")

     def init_vision_module(self) -> nn.Module:
         # TODO :refactor this vision model
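Each `init_llm` now wraps the inner language model in `LLMWrapper(..., name="model")`. The point is naming: LoRA adapters for MiniCPM-V are typically trained against Hugging Face module paths such as `llm.model.layers.0.self_attn.q_proj`, and vLLM resolves LoRA weights by module name, so inserting an extra `model` level makes vLLM's module tree line up with those paths. A minimal sketch of the idea, not vLLM's actual `LLMWrapper` implementation:

```python
import torch.nn as nn

class PrefixWrapper(nn.Module):
    """Registers the wrapped module under a chosen attribute name so that its
    parameters pick up an extra prefix, e.g. ``llm.model.*`` instead of
    ``llm.*``. Illustrative stand-in for vLLM's LLMWrapper."""

    def __init__(self, inner: nn.Module, name: str) -> None:
        super().__init__()
        self._name = name
        setattr(self, name, inner)  # nn.Module registers it as a submodule

    def forward(self, *args, **kwargs):
        return getattr(self, self._name)(*args, **kwargs)

wrapped = PrefixWrapper(nn.Linear(4, 4), name="model")
print([n for n, _ in wrapped.named_parameters()])  # ['model.weight', 'model.bias']
```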
@@ -697,6 +709,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):

         return model

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_tokens(input_ids)
+
     def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
         with set_default_torch_dtype(torch.float16):
             resampler = Resampler2(
@@ -743,7 +758,34 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
         return "resampler" in name or "vpm" in name


-class MiniCPMV2_5(MiniCPMVBaseModel):
+class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []

     def __init__(
         self,
@@ -751,6 +793,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
         multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__(config, multimodal_config, cache_config, quant_config)
         assert self.version == (2, 5)
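`supported_lora_modules` enumerates every linear layer that may receive LoRA weights across the Idefics2 vision encoder, the Llama language model, and the resampler, while `packed_modules_mapping` tells the LoRA loader that adapters trained against the separate `q_proj`/`k_proj`/`v_proj` and `gate_proj`/`up_proj` layers of the HF model must be folded into vLLM's fused `qkv_proj` and `gate_up_proj` layers. The algebra behind that folding, in illustrative form only (vLLM's LoRA manager keeps the per-projection slices separate and also handles tensor parallelism):

```python
import torch

packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

def fold_lora(adapter: dict, fused_name: str):
    """Fold per-projection LoRA factors into factors for the fused layer.

    Stacking the A matrices vertically and placing the B matrices on a block
    diagonal gives B_fused @ A_fused @ x == concat(Bq Aq x, Bk Ak x, Bv Av x),
    which is exactly the delta the fused qkv_proj output expects."""
    subs = packed_modules_mapping[fused_name]
    a_fused = torch.cat([adapter[s]["lora_A"] for s in subs], dim=0)
    b_fused = torch.block_diag(*[adapter[s]["lora_B"] for s in subs])
    return a_fused, b_fused

# Toy shapes: hidden size 32, per-projection output 48, LoRA rank 4.
adapter = {s: {"lora_A": torch.randn(4, 32), "lora_B": torch.randn(48, 4)}
           for s in ("q_proj", "k_proj", "v_proj")}
a_fused, b_fused = fold_lora(adapter, "qkv_proj")
print(a_fused.shape, b_fused.shape)  # torch.Size([12, 32]) torch.Size([144, 12])
```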
@@ -761,9 +804,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> nn.Module:
-        return LlamaModel(config,
-                          cache_config=cache_config,
-                          quant_config=quant_config)
+        return LLMWrapper(LlamaModel(config,
+                                     cache_config=cache_config,
+                                     quant_config=quant_config),
+                          name="model")

     def init_vision_module(self) -> nn.Module:
         model = Idefics2VisionTransformer(self.config.vision_config)
@@ -843,9 +887,11 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> nn.Module:
-        return Qwen2Model(config,
-                          cache_config=cache_config,
-                          quant_config=quant_config)
+        return LLMWrapper(Qwen2Model(config,
+                                     cache_config=cache_config,
+                                     quant_config=quant_config),
+                          name="model")

     def init_vision_module(self) -> nn.Module:
         # A custom version of SiglipVisionTransformer, won't work with TP
@@ -870,7 +916,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
                 num_heads=embed_dim // 128,
                 kv_dim=vision_dim,
             )
-
         return resampler

     def get_vision_embedding(
@@ -934,20 +979,25 @@ _SUPPORT_VERSION = {
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
-class MiniCPMV(MiniCPMVBaseModel):
+class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
+    """
+    Different versions of MiniCPMV use different visual encoders and LLMs,
+    which is not conducive to the current integration logic of LoRA and
+    bitsandbytes in vLLM. Therefore, it is necessary to separate them.
+    """
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []

-    def __new__(
-        cls,
-        config: PretrainedConfig,
-        multimodal_config: MultiModalConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __new__(cls,
+                config: PretrainedConfig,
+                multimodal_config: MultiModalConfig,
+                cache_config: Optional[CacheConfig] = None,
+                quant_config: Optional[QuantizationConfig] = None,
+                lora_config: Optional[LoRAConfig] = None):
         if not hasattr(config, "version"):
             if config.hidden_size == 2304 and config.query_num == 64:
                 version = (2, 0)
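The registered `MiniCPMV` class is only a dispatcher: `__new__` determines the checkpoint version, either from `config.version` or, for 2.0 checkpoints that predate the field, from `hidden_size`/`query_num`, and returns an instance of the matching concrete class, which is why each per-version class above carries its own LoRA attributes. A condensed paraphrase of that dispatch (error handling and the exact lookup are simplified, not copied from the file):

```python
# Paraphrased sketch of the version dispatch in MiniCPMV.__new__.
_SUPPORT_VERSION = {
    (2, 0): MiniCPMV2_0,
    (2, 5): MiniCPMV2_5,
    (2, 6): MiniCPMV2_6,
}

def pick_minicpmv_class(config):
    if not hasattr(config, "version"):
        # Old 2.0 checkpoints ship without a version field; recognize them by
        # their architecture hyper-parameters instead.
        if config.hidden_size == 2304 and config.query_num == 64:
            version = (2, 0)
        else:
            version = (2, 5)
    else:
        version = tuple(int(x) for x in str(config.version).split("."))
    if version not in _SUPPORT_VERSION:
        raise ValueError(f"Unsupported MiniCPM-V version: {version}")
    return _SUPPORT_VERSION[version]
```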