[Model][LoRA]LoRA support added for MiniCPMV2.6 (#8943)
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
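The change below lets MiniCPM-V 2.6 be served with per-request LoRA adapters, the same way MiniCPMV2_5 already could. A rough offline-inference sketch of the intended usage; the adapter path, rank, and prompt are placeholders, and a real multimodal request would also pass image data:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="openbmb/MiniCPM-V-2_6",
    trust_remote_code=True,
    enable_lora=True,      # allocate LoRA slots in the engine
    max_lora_rank=64,      # must cover the rank of the adapter being loaded
)

outputs = llm.generate(
    "Describe the picture in one sentence.",  # image input omitted for brevity
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("minicpmv-lora", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)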
@@ -31,17 +31,15 @@ import torch
import torch.types
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.resampler import (Resampler2,
+from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
                                                   get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -106,58 +104,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


-class BaseResampler(nn.Module):
-    """
-    A 2D perceiver-resampler network with one cross attention layers by
-        (grid_size**2) learnable queries and 2d sincos pos_emb
-    Outputs:
-        A tensor with the shape of (grid_size**2, embed_dim)
-    """
-
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-    ) -> None:
-        super().__init__()
-
-        self.num_queries = num_queries
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-
-        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
-        trunc_normal_(self.query, std=0.02)
-        if kv_dim is not None and kv_dim != embed_dim:
-            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
-        else:
-            # Maintain the same return value with ReplicatedLinear.forward
-            self.kv_proj = lambda *args, **kwargs: (
-                nn.Identity()(*args, **kwargs),
-                None,
-            )
-        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
-        self.ln_q = norm_layer(embed_dim)
-        self.ln_kv = norm_layer(embed_dim)
-        self.ln_post = norm_layer(embed_dim)
-        self.proj = nn.Parameter(
-            (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
-
-    def _init_weights(self, m: nn.Module) -> None:
-        if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
-
-    def _repeat(self, query, N: int):
-        return query.unsqueeze(1).repeat(1, N, 1)


class Resampler2_5(BaseResampler):

    def __init__(
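BaseResampler is not deleted outright: per the import change above, it is now pulled in from vllm.model_executor.layers.resampler instead of being defined here. For readers unfamiliar with the pattern, here is a standalone sketch of the perceiver-resampler idea behind it: a fixed set of learnable queries cross-attends over a variable-length sequence of vision features and returns a fixed-size result. The dimensions are illustrative, not MiniCPM-V's real sizes, and the real class additionally applies 2D sincos position embeddings (get_2d_sincos_pos_embed) and layer norms around the attention.

import torch
from torch import nn


class TinyResampler(nn.Module):

    def __init__(self, num_queries: int, embed_dim: int, num_heads: int,
                 kv_dim: int) -> None:
        super().__init__()
        # Learnable queries: one row per output token of the resampler.
        self.query = nn.Parameter(torch.zeros(num_queries, embed_dim))
        # Project vision features into the query dimension.
        self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, kv_dim) -> (num_queries, batch, embed_dim)
        kv = self.kv_proj(x)
        q = self.query.unsqueeze(1).repeat(1, x.size(1), 1)
        out, _ = self.attn(q, kv, kv)
        return out


feats = torch.randn(196, 2, 64)   # a batch of 2 images, 196 patch features each
resampler = TinyResampler(num_queries=32, embed_dim=256, num_heads=8, kv_dim=64)
print(resampler(feats).shape)     # torch.Size([32, 2, 256])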
@@ -869,7 +815,35 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
        return "resampler" in name


-class MiniCPMV2_6(MiniCPMVBaseModel):
+class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []

    def __init__(
        self,
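packed_modules_mapping tells vLLM's LoRA machinery that adapter weights saved against the separate q_proj/k_proj/v_proj (and gate_proj/up_proj) layers should be stacked onto the fused qkv_proj and gate_up_proj layers used at inference time, while supported_lora_modules lists every layer an adapter may target: the vision encoder's MLP and attention output, the Qwen2 language-model projections, and the resampler's kv_proj. A rough illustration of why the packing is equivalent; shapes are made up for the example:

import torch

hidden, rank = 64, 8
out_dims = {"q_proj": 64, "k_proj": 32, "v_proj": 32}  # fused qkv output width = 128

# Separate adapters as they appear in a LoRA checkpoint.
lora_a = {name: torch.randn(rank, hidden) for name in out_dims}
lora_b = {name: torch.randn(dim, rank) for name, dim in out_dims.items()}

x = torch.randn(1, hidden)
# Applying each adapter and concatenating the results gives exactly the delta
# that gets added to the fused qkv_proj output.
delta = torch.cat(
    [(x @ lora_a[n].T) @ lora_b[n].T for n in ("q_proj", "k_proj", "v_proj")],
    dim=-1)
print(delta.shape)  # torch.Size([1, 128])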
@@ -894,15 +868,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
                          name="model")

    def init_vision_module(self) -> nn.Module:
-        # A custom version of SiglipVisionTransformer, won't work with TP
-        from vllm.model_executor.models.na_vit import SiglipVisionTransformer
-
-        if self.config._attn_implementation == "flash_attention_2":
-            self.config.vision_config._attn_implementation = "flash_attention_2"
-        else:
-            # not support sdpa
-            self.config.vision_config._attn_implementation = "eager"
-        model = SiglipVisionTransformer(self.config.vision_config)
+        model = Idefics2VisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]
        return model
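The removed comment flags the custom na_vit backbone as incompatible with tensor parallelism, which the shared Idefics2VisionTransformer avoids. If that reading is right, multi-GPU serving is the main beneficiary; a sketch assuming two GPUs are available:

from vllm import LLM

# Hypothetical two-GPU launch; tensor_parallel_size shards the weights across GPUs.
llm = LLM(
    model="openbmb/MiniCPM-V-2_6",
    trust_remote_code=True,
    tensor_parallel_size=2,
)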
@@ -928,7 +895,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
            pixel_values,
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
-        ).last_hidden_state
+        )
        return vision_embedding

    def get_vision_hidden_states(
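The trailing .last_hidden_state dropped in this hunk and the next lines up with the backbone swap above: the na_vit SiglipVisionTransformer returned a HuggingFace-style output object, whereas vLLM's Idefics2VisionTransformer appears to return the hidden-state tensor directly, so no attribute access is needed.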
@@ -960,12 +927,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
            all_pixel_values.type(dtype),
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
-        ).last_hidden_state
+        )

        return self.resampler(vision_embedding, tgt_sizes)

    def is_default_weight_loading(self, name: str) -> bool:
-        return "resampler" in name or "vpm" in name
+        return "resampler" in name


_SUPPORT_VERSION = {