[Bugfix] Fix MiniCPMV and Mllama BNB bug (#9917)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -131,16 +131,22 @@ DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
|
||||
|
||||
class Resampler2_5(BaseResampler):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_queries: int,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
kv_dim: Optional[int] = None,
|
||||
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
|
||||
max_size: Tuple[int, int] = (70, 70),
|
||||
) -> None:
|
||||
super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)
|
||||
def __init__(self,
|
||||
num_queries: int,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
kv_dim: Optional[int] = None,
|
||||
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
|
||||
max_size: Tuple[int, int] = (70, 70),
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__(num_queries,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
kv_dim,
|
||||
norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
self.max_size = max_size
|
||||
self._set_2d_pos_cache(self.max_size)
|
||||
@@ -404,7 +410,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
|
||||
self.vpm.embeddings.embed_dim)
|
||||
self.embed_dim = self.config.hidden_size
|
||||
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
|
||||
self.resampler = self.init_resampler(self.embed_dim,
|
||||
self.vision_dim,
|
||||
quant_config=quant_config,
|
||||
prefix="resampler")
|
||||
self.resampler.to(device="cuda", dtype=param_dtype)
|
||||
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
@@ -666,7 +675,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
||||
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
|
||||
def init_resampler(self,
|
||||
embed_dim: int,
|
||||
vision_dim: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_vision_embedding(
|
||||
@@ -743,16 +756,21 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_tokens(input_ids)
|
||||
|
||||
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
|
||||
def init_resampler(self,
|
||||
embed_dim: int,
|
||||
vision_dim: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> nn.Module:
|
||||
with set_default_torch_dtype(torch.float16):
|
||||
resampler = Resampler2(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=embed_dim // 128,
|
||||
grid_size=int(math.sqrt(self.config.query_num)),
|
||||
kv_dim=vision_dim,
|
||||
adaptive=False,
|
||||
do_post_projection=True,
|
||||
)
|
||||
resampler = Resampler2(embed_dim=embed_dim,
|
||||
num_heads=embed_dim // 128,
|
||||
grid_size=int(
|
||||
math.sqrt(self.config.query_num)),
|
||||
kv_dim=vision_dim,
|
||||
adaptive=False,
|
||||
do_post_projection=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
return resampler
|
||||
|
||||
@@ -825,9 +843,21 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
|
||||
".k_proj.",
|
||||
".v_proj.",
|
||||
".o_proj.",
|
||||
# vision encoder
|
||||
".fc1.",
|
||||
".fc2.",
|
||||
# Currently, vLLM does not support BNB quantization for the `out_proj`
|
||||
# of the resampler, so it's necessary to distinguish between the
|
||||
# vision encoder and the resampler's out_proj. The same applies to
|
||||
# MiniCPMV2_6.
|
||||
".self_attn.out_proj.", # vision encoder out_proj
|
||||
# resampler
|
||||
".kv_proj.",
|
||||
]
|
||||
# In tensor parallelism (TP), these weights are partitioned along the column dimension (dim=-1)
|
||||
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
|
||||
column_parallel_weights_modules = [
|
||||
".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
|
||||
]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
@@ -877,14 +907,18 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
|
||||
model.encoder.layers = model.encoder.layers[:-1]
|
||||
return model
|
||||
|
||||
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
|
||||
def init_resampler(self,
|
||||
embed_dim: int,
|
||||
vision_dim: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> nn.Module:
|
||||
with set_default_torch_dtype(torch.float16):
|
||||
resampler = Resampler2_5(
|
||||
num_queries=self.config.query_num,
|
||||
embed_dim=embed_dim,
|
||||
num_heads=embed_dim // 128,
|
||||
kv_dim=vision_dim,
|
||||
)
|
||||
resampler = Resampler2_5(num_queries=self.config.query_num,
|
||||
embed_dim=embed_dim,
|
||||
num_heads=embed_dim // 128,
|
||||
kv_dim=vision_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
return resampler
|
||||
|
||||
def get_vision_embedding(
|
||||
@@ -967,9 +1001,17 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
|
||||
".k_proj.",
|
||||
".v_proj.",
|
||||
".o_proj.",
|
||||
# vision encoder
|
||||
".fc1.",
|
||||
".fc2.",
|
||||
".self_attn.out_proj.",
|
||||
# resampler
|
||||
".kv_proj.",
|
||||
]
|
||||
# In tensor parallelism (TP), these weights are partitioned along the column dimension (dim=-1)
|
||||
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
|
||||
column_parallel_weights_modules = [
|
||||
".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
|
||||
]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
@@ -1019,15 +1061,19 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
|
||||
model.encoder.layers = model.encoder.layers[:-1]
|
||||
return model
|
||||
|
||||
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
|
||||
def init_resampler(self,
|
||||
embed_dim: int,
|
||||
vision_dim: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> nn.Module:
|
||||
with set_default_torch_dtype(torch.float16):
|
||||
# The resampler in 2.6 remains consistent with the one in 2.5.
|
||||
resampler = Resampler2_5(
|
||||
num_queries=self.config.query_num,
|
||||
embed_dim=embed_dim,
|
||||
num_heads=embed_dim // 128,
|
||||
kv_dim=vision_dim,
|
||||
)
|
||||
resampler = Resampler2_5(num_queries=self.config.query_num,
|
||||
embed_dim=embed_dim,
|
||||
num_heads=embed_dim // 128,
|
||||
kv_dim=vision_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
return resampler
|
||||
|
||||
def get_vision_embedding(
|
||||
|
||||
Reference in New Issue
Block a user