[5/N] pass the whole config to model (#9983)
Signed-off-by: youkaichao <youkaichao@gmail.com>
@@ -34,7 +34,7 @@ from transformers import PretrainedConfig
 from typing_extensions import NotRequired
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -385,11 +385,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(
         self,
-        config: PretrainedConfig,
-        multimodal_config: MultiModalConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
+        config = vllm_config.model_config.hf_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
         super().__init__()
         # All MiniCPM-V models disable `tie_word_embeddings` but
         # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
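
The hunk above replaces four constructor parameters with a single VllmConfig that is unpacked inside __init__. A minimal self-contained sketch of that pattern, using stand-in dataclasses rather than vLLM's real config classes (field names follow the attribute accesses in the hunk):

# Sketch only (not vLLM source): stand-in types for the
# "pass the whole config" pattern.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ModelConfigSketch:  # stand-in for vLLM's ModelConfig
    hf_config: Any
    multimodal_config: Optional[Any] = None


@dataclass
class VllmConfigSketch:  # stand-in for vllm.config.VllmConfig
    model_config: ModelConfigSketch
    cache_config: Optional[Any] = None
    quant_config: Optional[Any] = None


class MiniCPMVLikeModel:
    def __init__(self, vllm_config: VllmConfigSketch, prefix: str = ""):
        # One object in; the constructor unpacks only what it needs,
        # mirroring the four assignments in the diff.
        self.config = vllm_config.model_config.hf_config
        self.multimodal_config = vllm_config.model_config.multimodal_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config

With this shape, adding a new sub-config later only touches the top-level config and the components that consume it, not every constructor signature in between.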
@@ -701,12 +703,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
 
     def __init__(
         self,
-        config: PretrainedConfig,
-        multimodal_config: MultiModalConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
-        super().__init__(config, multimodal_config, cache_config, quant_config)
+        super().__init__(vllm_config)
         assert self.version == (2, 0)
 
     def init_llm(
@@ -867,13 +867,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
 
     def __init__(
         self,
-        config: PretrainedConfig,
-        multimodal_config: MultiModalConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
-        super().__init__(config, multimodal_config, cache_config, quant_config)
+        super().__init__(vllm_config)
         assert self.version == (2, 5)
 
     def init_llm(
@@ -1017,12 +1014,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
 
     def __init__(
         self,
-        config: PretrainedConfig,
-        multimodal_config: MultiModalConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
-        super().__init__(config, multimodal_config, cache_config, quant_config)
+        super().__init__(vllm_config)
         assert self.version == (2, 6)
 
     def init_llm(
@@ -1141,12 +1136,8 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
     embedding_modules = {}
     embedding_padding_modules = []
 
-    def __new__(cls,
-                config: PretrainedConfig,
-                multimodal_config: MultiModalConfig,
-                cache_config: Optional[CacheConfig] = None,
-                quant_config: Optional[QuantizationConfig] = None,
-                lora_config: Optional[LoRAConfig] = None):
+    def __new__(cls, vllm_config: VllmConfig, prefix: str = ""):
+        config = vllm_config.model_config.hf_config
         if not hasattr(config, "version"):
             if config.hidden_size == 2304 and config.query_num == 64:
                 version = (2, 0)
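
The __new__ override keeps its version-dispatch role but now reads everything from vllm_config.model_config.hf_config. A minimal sketch of the dispatch idiom, with illustrative class names and a made-up version table (only the mechanism matches the hunks above):

# Sketch of dispatch-by-version via __new__ (names are illustrative).
class _MiniCPMVLike2_0:
    def __init__(self, cfg, prefix=""):
        self.version = (2, 0)


class _MiniCPMVLike2_6:
    def __init__(self, cfg, prefix=""):
        self.version = (2, 6)


_SUPPORTED = {(2, 0): _MiniCPMVLike2_0, (2, 6): _MiniCPMVLike2_6}


class MiniCPMVFacade:
    def __new__(cls, cfg, prefix: str = ""):
        version = tuple(getattr(cfg, "version", (2, 6)))
        instance_class = _SUPPORTED.get(version)
        if instance_class is None:
            raise ValueError("unsupported version")
        # __new__ may return an object of a different class; because it is
        # not a MiniCPMVFacade instance, MiniCPMVFacade.__init__ is skipped
        # and construction is fully delegated to the concrete class.
        return instance_class(cfg, prefix=prefix)

Calling MiniCPMVFacade(cfg) therefore hands back a fully constructed concrete model whose constructor already follows the (vllm_config, prefix) signature.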
@@ -1160,5 +1151,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
         if instance_class is None:
             raise ValueError(
                 "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
-        return instance_class(config, multimodal_config, cache_config,
-                              quant_config)
+        return instance_class(vllm_config, prefix=prefix)
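
One consequence of the series: every model class now shares the (vllm_config, prefix) constructor contract, so a loader can build any of them generically. A sketch under that assumption (the helper name is hypothetical, not vLLM API):

# Hypothetical helper: works for any class following the uniform
# (vllm_config, prefix) constructor contract introduced in this series.
def build_model(model_cls, vllm_config, prefix: str = ""):
    return model_cls(vllm_config, prefix=prefix)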