[Bugfix] Fix MiniCPMV and Mllama BNB bug (#9917)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2024-11-04 11:36:41 +08:00
Committed by: GitHub
Parent: 91c9ebbb1b
Commit: c49f0407ba
4 changed files with 145 additions and 65 deletions


@@ -131,16 +131,22 @@ DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 class Resampler2_5(BaseResampler):
 
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        max_size: Tuple[int, int] = (70, 70),
-    ) -> None:
-        super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)
+    def __init__(self,
+                 num_queries: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 max_size: Tuple[int, int] = (70, 70),
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
+        super().__init__(num_queries,
+                         embed_dim,
+                         num_heads,
+                         kv_dim,
+                         norm_layer,
+                         quant_config=quant_config,
+                         prefix=prefix)
         self.max_size = max_size
         self._set_2d_pos_cache(self.max_size)
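
The change above threads quant_config and a dotted prefix from the wrapping model down into the resampler, so every linear layer inside it ends up with a fully qualified parameter name; quantization methods match on those names when deciding how to build and load each layer. A minimal sketch of the pattern, assuming a ReplicatedLinear sublayer and hypothetical dimensions (this is not the actual BaseResampler code):

from typing import Optional

from torch import nn

from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class ToyResampler(nn.Module):
    """Hypothetical module illustrating quant_config/prefix threading."""

    def __init__(self,
                 kv_dim: int,
                 embed_dim: int,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        # The sublayer's qualified name becomes "<parent prefix>.kv_proj",
        # e.g. "resampler.kv_proj", which is what BNB matches against.
        self.kv_proj = ReplicatedLinear(kv_dim,
                                        embed_dim,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.kv_proj")
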
@@ -404,7 +410,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
                            self.vpm.embeddings.embed_dim)
         self.embed_dim = self.config.hidden_size
-        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
+        self.resampler = self.init_resampler(self.embed_dim,
+                                             self.vision_dim,
+                                             quant_config=quant_config,
+                                             prefix="resampler")
         self.resampler.to(device="cuda", dtype=param_dtype)
         # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
         self.lm_head = ParallelLMHead(config.vocab_size,
@@ -666,7 +675,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> nn.Module:
         raise NotImplementedError
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         raise NotImplementedError
 
     def get_vision_embedding(
@@ -743,16 +756,21 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_tokens(input_ids)
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
-            resampler = Resampler2(
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                grid_size=int(math.sqrt(self.config.query_num)),
-                kv_dim=vision_dim,
-                adaptive=False,
-                do_post_projection=True,
-            )
+            resampler = Resampler2(embed_dim=embed_dim,
+                                   num_heads=embed_dim // 128,
+                                   grid_size=int(
+                                       math.sqrt(self.config.query_num)),
+                                   kv_dim=vision_dim,
+                                   adaptive=False,
+                                   do_post_projection=True,
+                                   quant_config=quant_config,
+                                   prefix=prefix)
         return resampler
@@ -825,9 +843,21 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        # vision encoder
+        ".fc1.",
+        ".fc2.",
+        # Currently, vllm does not support BNB quantization for the `out_proj`
+        # of the resampler, so it's necessary to distinguish between the
+        # vision encoder and the resampler's out_proj. The same applies to
+        # MiniCPMV2_6.
+        ".self_attn.out_proj.",  # vision encoder out_proj
+        # resampler
+        ".kv_proj.",
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -877,14 +907,18 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
             model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
-            resampler = Resampler2_5(
-                num_queries=self.config.query_num,
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                kv_dim=vision_dim,
-            )
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(
@@ -967,9 +1001,17 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        # vision encoder
+        ".fc1.",
+        ".fc2.",
+        ".self_attn.out_proj.",
+        # resampler
+        ".kv_proj.",
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -1019,15 +1061,19 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
             model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
             # The resampler in 2.6 remains consistent with the one in 2.5.
-            resampler = Resampler2_5(
-                num_queries=self.config.query_num,
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                kv_dim=vision_dim,
-            )
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(
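
With quant_config now reaching the resampler and the target-module lists disambiguated, MiniCPM-V should load end to end under in-flight bitsandbytes quantization. A usage sketch; the model tag and length limit are illustrative:

from vllm import LLM

# In-flight 4-bit BNB quantization requires both flags below.
llm = LLM(model="openbmb/MiniCPM-V-2_6",
          quantization="bitsandbytes",
          load_format="bitsandbytes",
          trust_remote_code=True,
          max_model_len=4096)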