[Model][1/N] Automatic conversion of CrossEncoding model (#20012)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-06-27 12:10:04 +08:00
committed by GitHub
parent e110930680
commit cd4cfee689
5 changed files with 239 additions and 167 deletions

View File

@@ -569,6 +569,10 @@ class ModelConfig:
else:
self.truncation_side = "right"
model_info, arch = self.registry.inspect_model_cls(self.architectures)
self._model_info = model_info
self._architecture = arch
self.pooler_config = self._init_pooler_config()
self.dtype = _get_and_verify_dtype(
@@ -660,8 +664,18 @@ class ModelConfig:
@property
def architectures(self) -> list[str]:
    """Return the architecture names declared in the model config.

    Reads ``architectures`` from ``self.hf_config``. A missing attribute
    and an explicit ``None`` are both reported as an empty list, so the
    result always matches the annotated ``list[str]`` and is safe to
    iterate.
    """
    # The default passed to ``getattr`` only covers a missing attribute;
    # a config object may also carry ``architectures = None``, hence the
    # trailing ``or []``.
    return getattr(self.hf_config, "architectures", None) or []
@property
def architecture(self) -> str:
    """Return the architecture vLLM actually uses for this model.

    The value is whatever is stored in ``self._architecture``; nothing is
    recomputed on access.
    """
    return self._architecture
@property
def model_info(self) -> dict[str, Any]:
    """Return the model info stored in ``self._model_info``."""
    # NOTE(review): the annotation promises a dict; confirm that whatever
    # populates ``_model_info`` really produces one.
    return self._model_info
def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None:
"""Pull model/tokenizer from S3 to temporary directory when needed.
@@ -4450,6 +4464,9 @@ class VllmConfig:
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
"""
self.try_verify_and_update_config()
if self.model_config is not None:
self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config,
@@ -4694,11 +4711,21 @@ class VllmConfig:
batch_size_capture_list)
def recalculate_max_model_len(self, max_model_len: int):
    """Re-verify ``max_model_len`` and push it to the dependent configs.

    The requested length goes through
    ``model_config.get_and_verify_max_len``. The verified result is then
    written to both the model config and the scheduler config, and
    ``compute_hash`` is invoked afterwards.

    Can only be called in try_verify_and_update_config.
    """
    checked_len = self.model_config.get_and_verify_max_len(max_model_len)
    self.model_config.max_model_len = checked_len
    self.scheduler_config.max_model_len = checked_len
    self.compute_hash()
def try_verify_and_update_config(self):
    """Run the architecture-specific config hook, if one is registered.

    Looks up ``self.model_config.architecture`` in ``MODELS_CONFIG_MAP``
    and, when an entry exists, calls its ``verify_and_update_config`` with
    this config object. Does nothing when no architecture is available
    (including when ``model_config`` is ``None``) or when the architecture
    has no entry in the map.
    """
    arch_name = getattr(self.model_config, "architecture", None)
    if arch_name is not None:
        # Imported only once an architecture is known, so the early-exit
        # path never touches the models package.
        from vllm.model_executor.models.config import MODELS_CONFIG_MAP

        config_cls = MODELS_CONFIG_MAP.get(arch_name, None)
        if config_cls is None:
            return
        config_cls.verify_and_update_config(self)
def __str__(self):
return (