[Bugfix] Fix unable to load some models (#10312)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-11-15 08:55:54 +08:00
committed by GitHub
parent 11cd1ae6ad
commit 972112d82f
13 changed files with 340 additions and 59 deletions

View File

@@ -382,11 +382,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     instantiated.
     """
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         multimodal_config = vllm_config.model_config.multimodal_config
         quant_config = vllm_config.quant_config
@@ -699,12 +695,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
 class MiniCPMV2_0(MiniCPMVBaseModel):
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ):
-        super().__init__(vllm_config)
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
         assert self.version == (2, 0)
     def init_llm(
@@ -857,12 +849,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
     embedding_modules = {}
     embedding_padding_modules = []
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ):
-        super().__init__(vllm_config)
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
         assert self.version == (2, 5)
     def init_llm(
@@ -999,12 +987,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
     embedding_modules = {}
     embedding_padding_modules = []
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ):
-        super().__init__(vllm_config)
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
         assert self.version == (2, 6)
     def init_llm(
@@ -1117,7 +1101,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
     embedding_modules = {}
     embedding_padding_modules = []
-    def __new__(cls, vllm_config: VllmConfig, prefix: str = ""):
+    def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         if not hasattr(config, "version"):
             if config.hidden_size == 2304 and config.query_num == 64: