[6/N] pass whole config to inner model (#10205)
Signed-off-by: youkaichao <youkaichao@gmail.com>
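
This change continues the series that threads the whole `VllmConfig` through model constructors: the MiniCPM-V wrapper now hands its inner language model (`MiniCPMModel`, `LlamaModel`, or `Qwen2Model`) the full `vllm_config` plus a composed weight-name `prefix`, instead of separate `config`/`cache_config`/`quant_config` arguments. Below is a minimal sketch of the target constructor convention as this diff uses it; `InnerModel` is a hypothetical stand-in, not a class in the tree:

import torch.nn as nn

from vllm.config import VllmConfig


class InnerModel(nn.Module):
    """Hypothetical model following the new constructor convention."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        # Sub-configs are unpacked from the single VllmConfig instead of
        # being threaded through as separate constructor arguments.
        self.config = vllm_config.model_config.hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config
        self.prefix = prefix  # dotted weight-name prefix, e.g. "llm"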
@@ -34,7 +34,7 @@ from transformers import PretrainedConfig
 from typing_extensions import NotRequired

 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -59,7 +59,7 @@ from vllm.sequence import IntermediateTensors, SequenceData

 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
-from .utils import is_pp_missing_parameter
+from .utils import is_pp_missing_parameter, maybe_prefix

 _KEYS_TO_MODIFY_MAPPING = {
     "llm.lm_head": "lm_head",
@@ -390,7 +390,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     ):
         config = vllm_config.model_config.hf_config
         multimodal_config = vllm_config.model_config.multimodal_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         super().__init__()
         # All MiniCPM-V models disable `tie_word_embeddings` but
@@ -401,11 +400,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.multimodal_config = multimodal_config

         self.version = get_version_by_config(self.config)
-        self.llm = self.init_llm(config,
-                                 cache_config,
-                                 quant_config,
-                                 prefix="llm")
-        self.vpm = self.init_vision_module(config, quant_config, prefix="vpm")
+        self.llm = self.init_llm(vllm_config=vllm_config,
+                                 prefix=maybe_prefix(prefix, "llm"))
+        self.vpm = self.init_vision_module(config,
+                                           quant_config,
+                                           prefix=maybe_prefix(prefix, "vpm"))
         param_dtype = torch.get_default_dtype()
         self.vpm.to(dtype=param_dtype)
         self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
@@ -414,13 +413,15 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.resampler = self.init_resampler(self.embed_dim,
                                              self.vision_dim,
                                              quant_config=quant_config,
-                                             prefix="resampler")
+                                             prefix=maybe_prefix(
+                                                 prefix, "resampler"))
         self.resampler.to(device="cuda", dtype=param_dtype)
         # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config,
-                                      prefix="llm.lm_head")
+                                      prefix=maybe_prefix(
+                                          prefix, "llm.lm_head"))
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()

@@ -661,9 +662,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):

     def init_llm(
         self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ) -> nn.Module:
         raise NotImplementedError
@@ -711,16 +710,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel):

     def init_llm(
         self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ) -> nn.Module:

-        return LLMWrapper(MiniCPMModel(config,
-                                       cache_config=cache_config,
-                                       quant_config=quant_config,
-                                       prefix=prefix),
+        return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix),
                           name="model")

     def init_vision_module(
@@ -875,15 +868,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):

     def init_llm(
         self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ) -> nn.Module:
-        return LLMWrapper(LlamaModel(config,
-                                     cache_config=cache_config,
-                                     quant_config=quant_config,
-                                     prefix=prefix),
+        return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix),
                           name="model")

     def init_vision_module(
@@ -1022,16 +1010,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):

     def init_llm(
         self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ) -> nn.Module:

-        return LLMWrapper(Qwen2Model(config,
-                                     cache_config=cache_config,
-                                     quant_config=quant_config,
-                                     prefix=prefix),
+        return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix),
                           name="model")

     def init_vision_module(
@@ -1151,4 +1133,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
         if instance_class is None:
             raise ValueError(
                 "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
-        return instance_class(vllm_config, prefix=prefix)
+        return instance_class(vllm_config=vllm_config, prefix=prefix)
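
`maybe_prefix`, newly imported from `.utils`, composes the dotted weight-name prefixes seen above (e.g. `maybe_prefix(prefix, "llm.lm_head")`). A sketch of the helper consistent with its call sites in this diff; treat the exact body as an assumption rather than a copy of `vllm/model_executor/models/utils.py`:

def maybe_prefix(prefix: str, name: str) -> str:
    """Prepend `prefix` to `name` with a dot, unless `prefix` is empty."""
    return f"{prefix}.{name}" if prefix else name

With an empty parent prefix the old names ("llm", "vpm", "resampler", "llm.lm_head") are preserved, so existing checkpoints keep loading unchanged.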