[6/N] pass whole config to inner model (#10205)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-10 22:41:46 -08:00
committed by GitHub
parent f0f2e5638e
commit f89d18ff74
69 changed files with 681 additions and 963 deletions

View File

@@ -34,7 +34,7 @@ from transformers import PretrainedConfig
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -59,7 +59,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import is_pp_missing_parameter
from .utils import is_pp_missing_parameter, maybe_prefix
_KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head",
@@ -390,7 +390,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
):
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings` but
@@ -401,11 +400,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.multimodal_config = multimodal_config
self.version = get_version_by_config(self.config)
self.llm = self.init_llm(config,
cache_config,
quant_config,
prefix="llm")
self.vpm = self.init_vision_module(config, quant_config, prefix="vpm")
self.llm = self.init_llm(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "llm"))
self.vpm = self.init_vision_module(config,
quant_config,
prefix=maybe_prefix(prefix, "vpm"))
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
@@ -414,13 +413,15 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.resampler = self.init_resampler(self.embed_dim,
self.vision_dim,
quant_config=quant_config,
prefix="resampler")
prefix=maybe_prefix(
prefix, "resampler"))
self.resampler.to(device="cuda", dtype=param_dtype)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix="llm.lm_head")
prefix=maybe_prefix(
prefix, "llm.lm_head"))
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
@@ -661,9 +662,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
raise NotImplementedError
@@ -711,16 +710,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(MiniCPMModel(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix),
name="model")
def init_vision_module(
@@ -875,15 +868,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(LlamaModel(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix),
name="model")
def init_vision_module(
@@ -1022,16 +1010,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix),
name="model")
def init_vision_module(
@@ -1151,4 +1133,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
if instance_class is None:
raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(vllm_config, prefix=prefix)
return instance_class(vllm_config=vllm_config, prefix=prefix)