[Bugfix] Fix MiniCPMV and Mllama BNB bug (#9917)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -41,6 +41,7 @@ from torch import nn
 from torch.nn.init import trunc_normal_
 
 from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
 
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
@@ -154,15 +155,15 @@ class BaseResampler(nn.Module):
         A tensor with the shape of (grid_size**2, embed_dim)
     """
 
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        do_post_projection: bool = True,
-    ) -> None:
+    def __init__(self,
+                 num_queries: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 do_post_projection: bool = True,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__()
 
         self.num_queries = num_queries
@@ -172,7 +173,11 @@ class BaseResampler(nn.Module):
         self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
         trunc_normal_(self.query, std=0.02)
         if kv_dim is not None and kv_dim != embed_dim:
-            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
+            self.kv_proj = ReplicatedLinear(kv_dim,
+                                            embed_dim,
+                                            bias=False,
+                                            quant_config=quant_config,
+                                            prefix=prefix)
         else:
             # Maintain the same return value with ReplicatedLinear.forward
             self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
@@ -209,22 +214,24 @@ class Resampler2(BaseResampler):
     present in minicpmv2.0, but not qwen-vl.
     """
 
-    def __init__(
-        self,
-        grid_size: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        adaptive: bool = False,
-        do_post_projection: bool = True,
-    ) -> None:
+    def __init__(self,
+                 grid_size: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 adaptive: bool = False,
+                 do_post_projection: bool = True,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__(grid_size**2,
                          embed_dim,
                          num_heads,
                          kv_dim,
                          norm_layer,
-                         do_post_projection=do_post_projection)
+                         do_post_projection=do_post_projection,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         self.adaptive = adaptive
         pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
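The net effect of the change is that callers can now thread the model's quantization config and weight-name prefix down into the resampler, so kv_proj is built as a quantization-aware ReplicatedLinear rather than a plain full-precision layer, which is what BNB (bitsandbytes) weight loading for MiniCPMV/Mllama relies on. A minimal sketch of a call site after this change; the import path, the build_resampler helper, and all numeric arguments are illustrative assumptions, not part of this diff:

# Illustrative sketch only -- build_resampler, the import path, and the
# numeric arguments are assumptions, not taken from this commit.
from typing import Optional

from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import Resampler2


def build_resampler(quant_config: Optional[QuantizationConfig],
                    prefix: str = "") -> Resampler2:
    # Forwarding quant_config/prefix means the resampler's kv_proj is
    # constructed as a ReplicatedLinear that respects the model's
    # quantization settings (e.g. bitsandbytes) instead of ignoring them.
    return Resampler2(grid_size=12,
                      embed_dim=1024,
                      num_heads=16,
                      kv_dim=768,  # != embed_dim, so kv_proj is a ReplicatedLinear
                      quant_config=quant_config,
                      prefix=prefix)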