[5/N] pass the whole config to model (#9983)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-08 22:17:28 -08:00
Committed by: GitHub
Parent: 49d2a41a86
Commit: 1a95f10ee7
75 changed files with 583 additions and 654 deletions
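
For orientation: this refactor replaces per-field constructor arguments
(config, multimodal_config, cache_config, quant_config, ...) with a single
VllmConfig aggregate that each model unpacks for itself. A minimal,
self-contained sketch of the pattern, with toy dataclasses standing in for
the real vllm.config classes (only the field names used in the diff below
are mirrored):

from dataclasses import dataclass, field
from typing import Any, Optional

@dataclass
class ModelConfig:
    hf_config: Any = None           # the HF PretrainedConfig
    multimodal_config: Any = None

@dataclass
class CacheConfig:
    block_size: int = 16

@dataclass
class VllmConfig:
    # The real VllmConfig carries many more sub-configs (parallel,
    # scheduler, device, lora, ...); three are enough to show the shape.
    model_config: ModelConfig = field(default_factory=ModelConfig)
    cache_config: CacheConfig = field(default_factory=CacheConfig)
    quant_config: Optional[Any] = None

class ToyModel:
    # Uniform signature: one aggregate plus a prefix, unpacked up front
    # just as MiniCPMVBaseModel.__init__ does in the diff below.
    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        self.config = vllm_config.model_config.hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config

model = ToyModel(VllmConfig())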


@@ -34,7 +34,7 @@ from transformers import PretrainedConfig
 from typing_extensions import NotRequired

 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -385,11 +385,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     def __init__(
         self,
-        config: PretrainedConfig,
-        multimodal_config: MultiModalConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
+        config = vllm_config.model_config.hf_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
         super().__init__()
         # All MiniCPM-V models disable `tie_word_embeddings` but
         # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
@@ -701,12 +703,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
     def __init__(
         self,
-        config: PretrainedConfig,
-        multimodal_config: MultiModalConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
-        super().__init__(config, multimodal_config, cache_config, quant_config)
+        super().__init__(vllm_config)
         assert self.version == (2, 0)

     def init_llm(
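
Each version-specific subclass shrinks the same way (2.0 here; the 2.5 and
2.6 hunks below are identical in shape): forward the single aggregate to
the base class and keep only the version assertion. A toy sketch of the
forwarding pattern, independent of vLLM:

class ToyBase:
    def __init__(self, vllm_config, prefix: str = ""):
        # The real base class derives the version from the HF config.
        self.version = getattr(vllm_config, "version", (2, 0))

class ToyV2_0(ToyBase):
    def __init__(self, vllm_config, prefix: str = ""):
        # One forwarded argument replaces four threaded parameters.
        super().__init__(vllm_config, prefix=prefix)
        assert self.version == (2, 0)

class Cfg:
    version = (2, 0)

model = ToyV2_0(Cfg())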
@@ -867,13 +867,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
     def __init__(
         self,
-        config: PretrainedConfig,
-        multimodal_config: MultiModalConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
-        super().__init__(config, multimodal_config, cache_config, quant_config)
+        super().__init__(vllm_config)
         assert self.version == (2, 5)

     def init_llm(
@@ -1017,12 +1014,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
     def __init__(
         self,
-        config: PretrainedConfig,
-        multimodal_config: MultiModalConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
-        super().__init__(config, multimodal_config, cache_config, quant_config)
+        super().__init__(vllm_config)
         assert self.version == (2, 6)

     def init_llm(
@@ -1141,12 +1136,8 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
     embedding_modules = {}
     embedding_padding_modules = []

-    def __new__(cls,
-                config: PretrainedConfig,
-                multimodal_config: MultiModalConfig,
-                cache_config: Optional[CacheConfig] = None,
-                quant_config: Optional[QuantizationConfig] = None,
-                lora_config: Optional[LoRAConfig] = None):
+    def __new__(cls, vllm_config: VllmConfig, prefix: str = ""):
+        config = vllm_config.model_config.hf_config
         if not hasattr(config, "version"):
             if config.hidden_size == 2304 and config.query_num == 64:
                 version = (2, 0)
@@ -1160,5 +1151,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
         if instance_class is None:
             raise ValueError(
                 "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
-        return instance_class(config, multimodal_config, cache_config,
-                              quant_config)
+        return instance_class(vllm_config, prefix=prefix)