[Bugfix] Fix MiniCPMV and Mllama BNB bug (#9917)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2024-11-04 11:36:41 +08:00
Committed by: GitHub
parent 91c9ebbb1b
commit c49f0407ba
4 changed files with 145 additions and 65 deletions
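
The diff below passes quant_config and prefix down to the resampler layers so the kv_proj projection picks up the model's quantization settings instead of breaking BitsAndBytes (BNB) weight loading for MiniCPMV and Mllama. As a rough sketch of the usage this fix targets (the model ID and parameters are illustrative assumptions, not taken from the commit):

    # Hedged sketch: serving a MiniCPM-V checkpoint with in-flight
    # BitsAndBytes quantization via vLLM's offline LLM entry point.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="openbmb/MiniCPM-V-2",    # illustrative multimodal checkpoint
        quantization="bitsandbytes",    # route linear layers through the BNB path
        load_format="bitsandbytes",     # quantize weights while loading
        trust_remote_code=True,
    )
    outputs = llm.generate("Describe the image.",
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)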


@@ -41,6 +41,7 @@ from torch import nn
 from torch.nn.init import trunc_normal_

 from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig

 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
@@ -154,15 +155,15 @@ class BaseResampler(nn.Module):
         A tensor with the shape of (grid_size**2, embed_dim)
     """

-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        do_post_projection: bool = True,
-    ) -> None:
+    def __init__(self,
+                 num_queries: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 do_post_projection: bool = True,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__()

         self.num_queries = num_queries
@@ -172,7 +173,11 @@ class BaseResampler(nn.Module):
         self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
         trunc_normal_(self.query, std=0.02)
         if kv_dim is not None and kv_dim != embed_dim:
-            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
+            self.kv_proj = ReplicatedLinear(kv_dim,
+                                            embed_dim,
+                                            bias=False,
+                                            quant_config=quant_config,
+                                            prefix=prefix)
         else:
             # Maintain the same return value with ReplicatedLinear.forward
             self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
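
For context on the else branch kept above: ReplicatedLinear.forward returns an (output, bias) tuple, so the identity fallback mimics that shape for callers that unpack two values. A standalone sketch of that contract (the class and names below are illustrative, not part of the commit):

    import torch
    from torch import nn

    class IdentityKVProj(nn.Module):
        """Illustrative stand-in mirroring ReplicatedLinear's (output, bias) return."""

        def forward(self, x: torch.Tensor):
            # No projection is needed when kv_dim == embed_dim; still return a
            # 2-tuple so call sites can unpack `x, _ = self.kv_proj(x)`.
            return x, None

    proj = IdentityKVProj()
    hidden, _ = proj(torch.randn(2, 16))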
@@ -209,22 +214,24 @@ class Resampler2(BaseResampler):
     present in minicpmv2.0, but not qwen-vl.
     """

-    def __init__(
-        self,
-        grid_size: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        adaptive: bool = False,
-        do_post_projection: bool = True,
-    ) -> None:
+    def __init__(self,
+                 grid_size: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 adaptive: bool = False,
+                 do_post_projection: bool = True,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__(grid_size**2,
                          embed_dim,
                          num_heads,
                          kv_dim,
                          norm_layer,
-                         do_post_projection=do_post_projection)
+                         do_post_projection=do_post_projection,
+                         quant_config=quant_config,
+                         prefix=prefix)
         self.adaptive = adaptive
         pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
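
With the new keywords in place, a caller can hand the model's quantization config to the resampler at construction time. A minimal sketch, assuming the class still lives in vllm.model_executor.layers.resampler and using illustrative sizes rather than values from any real checkpoint:

    from vllm.model_executor.layers.resampler import Resampler2

    # quant_config would normally be the model's QuantizationConfig
    # (e.g. BitsAndBytes); None keeps the kv projection unquantized.
    resampler = Resampler2(grid_size=12,           # 12*12 = 144 learnable queries
                           embed_dim=2304,
                           num_heads=2304 // 128,
                           kv_dim=1152,             # vision encoder hidden size
                           adaptive=True,
                           do_post_projection=True,
                           quant_config=None,
                           prefix="resampler")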