[Models] Add remaining model PP support (#7168)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai> Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
committed by
GitHub
parent
303d44790a
commit
0f6d7a9a34
@@ -1,12 +1,18 @@
|
||||
import functools
|
||||
import importlib
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
import string
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from functools import lru_cache, partial
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .interfaces import supports_multimodal, supports_pp
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_GENERATION_MODELS = {
|
||||
@@ -152,19 +158,25 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
||||
class ModelRegistry:
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_model(model_arch: str):
|
||||
module_name, model_cls_name = _MODELS[model_arch]
|
||||
module = importlib.import_module(
|
||||
f"vllm.model_executor.models.{module_name}")
|
||||
return getattr(module, model_cls_name, None)
|
||||
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
    """
    Look up the import location of a registered architecture.

    Returns a ``(module_name, class_name)`` pair, where ``module_name`` is
    the entry from ``_MODELS`` prefixed with ``vllm.model_executor.models.``.

    Raises ``KeyError`` if ``model_arch`` is not a key of ``_MODELS``.
    """
    relative_module, class_name = _MODELS[model_arch]
    qualified_module = f"vllm.model_executor.models.{relative_module}"
    return qualified_module, class_name
|
||||
|
||||
@staticmethod
|
||||
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
if model_arch in _OOT_MODELS:
|
||||
return _OOT_MODELS[model_arch]
|
||||
@lru_cache(maxsize=128)
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
    """
    Import the module of a registered architecture and return its class.

    Returns ``None`` when ``model_arch`` is not in ``_MODELS`` or when the
    imported module has no attribute with the registered class name.
    Results are memoized per architecture name by ``lru_cache``.
    """
    if model_arch in _MODELS:
        qualified_module, class_name = ModelRegistry._get_module_cls_name(
            model_arch)
        # The import happens in the current process.
        imported = importlib.import_module(qualified_module)
        return getattr(imported, class_name, None)

    return None
|
||||
|
||||
@staticmethod
|
||||
def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
if model_arch in _OOT_MODELS:
|
||||
return _OOT_MODELS[model_arch]
|
||||
|
||||
if is_hip():
|
||||
if model_arch in _ROCM_UNSUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
@@ -175,11 +187,24 @@ class ModelRegistry:
|
||||
"Model architecture %s is partially supported by ROCm: %s",
|
||||
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
|
||||
|
||||
return ModelRegistry._get_model(model_arch)
|
||||
return None
|
||||
|
||||
@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
    """
    Resolve the model class for ``model_arch``.

    The stateless lookup is consulted first; only when it yields ``None``
    does this fall back to the stateful lookup, which imports the model
    module. Returns ``None`` if neither lookup finds a class.
    """
    stateless_hit = ModelRegistry._try_get_model_stateless(model_arch)
    if stateless_hit is None:
        return ModelRegistry._try_get_model_stateful(model_arch)

    return stateless_hit
|
||||
|
||||
@staticmethod
|
||||
def resolve_model_cls(
|
||||
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
|
||||
architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
|
||||
if isinstance(architectures, str):
|
||||
architectures = [architectures]
|
||||
if not architectures:
|
||||
logger.warning("No model architectures are specified")
|
||||
|
||||
for arch in architectures:
|
||||
model_cls = ModelRegistry._try_load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
@@ -200,21 +225,99 @@ class ModelRegistry:
|
||||
"Model architecture %s is already registered, and will be "
|
||||
"overwritten by the new model class %s.", model_arch,
|
||||
model_cls.__name__)
|
||||
global _OOT_MODELS
|
||||
|
||||
_OOT_MODELS[model_arch] = model_cls
|
||||
|
||||
@staticmethod
|
||||
def is_embedding_model(model_arch: str) -> bool:
|
||||
return model_arch in _EMBEDDING_MODELS
|
||||
@lru_cache(maxsize=128)
def _check_stateless(
    func: Callable[[Type[nn.Module]], bool],
    model_arch: str,
    *,
    default: Optional[bool] = None,
) -> bool:
    """
    Run a boolean function against a model and return the result.

    If ``ModelRegistry._try_get_model_stateless`` yields a class for
    ``model_arch``, ``func`` is applied to it directly in this process.

    If the architecture is not in ``_MODELS`` and ``default`` is not
    ``None``, ``default`` is returned.

    Otherwise the check is run inside a subprocess (to avoid initializing
    CUDA for the main program): the child imports the model class and
    ``func`` by name and asserts ``func(cls)``. A zero exit code means
    ``True``; an ``AssertionError`` carrying a unique marker means ``False``.

    Results are memoized by ``lru_cache`` on ``(func, model_arch, default)``.

    Raises:
        KeyError: if ``model_arch`` is not in ``_MODELS`` and no ``default``
            was given (raised by the ``_MODELS`` lookup).
        ValueError: if any name that would be interpolated into the child's
            source contains characters other than letters, digits, ``.``
            and ``_``.
        RuntimeError: if the child exits with a non-zero status for any
            reason other than the expected assertion failure, including
            when it produces no stderr output at all.
    """
    model = ModelRegistry._try_get_model_stateless(model_arch)
    if model is not None:
        return func(model)

    if model_arch not in _MODELS and default is not None:
        return default

    module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)

    # These names are pasted verbatim into source code executed by the child
    # interpreter, so anything outside a conservative identifier alphabet is
    # rejected before the source string is built.
    valid_name_characters = string.ascii_letters + string.digits + "._"
    if any(s not in valid_name_characters for s in module_name):
        raise ValueError(f"Unsafe module name detected for {model_arch}")
    if any(s not in valid_name_characters for s in cls_name):
        raise ValueError(f"Unsafe class name detected for {model_arch}")
    if any(s not in valid_name_characters for s in func.__module__):
        raise ValueError(f"Unsafe module name detected for {func}")
    if any(s not in valid_name_characters for s in func.__name__):
        raise ValueError(f"Unsafe class name detected for {func}")

    # Unique marker so that "func returned a falsy value" can be told apart
    # from any other AssertionError raised while importing the model.
    err_id = uuid.uuid4()

    # NOTE(review): the child relies on `assert`; if the inherited
    # environment sets PYTHONOPTIMIZE, asserts are stripped and the child
    # always exits 0. Confirm whether that environment needs handling.
    stmts = ";".join([
        f"from {module_name} import {cls_name}",
        f"from {func.__module__} import {func.__name__}",
        f"assert {func.__name__}({cls_name}), '{err_id}'",
    ])

    result = subprocess.run([sys.executable, "-c", stmts],
                            capture_output=True)

    if result.returncode != 0:
        # Decode leniently: a stray non-UTF-8 byte in the child's stderr
        # must not hide the real failure behind a UnicodeDecodeError.
        err_lines = [
            line.decode(errors="replace")
            for line in result.stderr.splitlines()
        ]
        # An assertion failure always prints a traceback, so an empty stderr
        # with a non-zero exit status is an abnormal termination (e.g. the
        # child was killed) and must not be reported as a plain `False`.
        if not err_lines or err_lines[-1] != f"AssertionError: {err_id}":
            err_str = "\n".join(err_lines)
            raise RuntimeError(
                "An unexpected error occurred while importing the model in "
                f"another process. Error log:\n{err_str}")

    return result.returncode == 0
|
||||
|
||||
@staticmethod
|
||||
def is_multimodal_model(model_arch: str) -> bool:
|
||||
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
|
||||
if isinstance(architectures, str):
|
||||
architectures = [architectures]
|
||||
if not architectures:
|
||||
logger.warning("No model architectures are specified")
|
||||
|
||||
# TODO: find a way to avoid initializing CUDA prematurely to
|
||||
# use `supports_multimodal` to determine if a model is multimodal
|
||||
# model_cls = ModelRegistry._try_load_model_cls(model_arch)
|
||||
# from vllm.model_executor.models.interfaces import supports_multimodal
|
||||
return model_arch in _MULTIMODAL_MODELS
|
||||
return any(arch in _EMBEDDING_MODELS for arch in architectures)
|
||||
|
||||
@staticmethod
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
    """
    Tell whether any of the given architectures is a multimodal model.

    Accepts a single architecture name or a list of names. Each name is
    checked with ``supports_multimodal`` through
    ``ModelRegistry._check_stateless`` (``default=False``), stopping at the
    first match. A warning is logged when no architectures are given.
    """
    arch_list = ([architectures]
                 if isinstance(architectures, str) else architectures)
    if not arch_list:
        logger.warning("No model architectures are specified")

    for arch in arch_list:
        if ModelRegistry._check_stateless(supports_multimodal,
                                          arch,
                                          default=False):
            return True

    return False
|
||||
|
||||
@staticmethod
def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
    """
    Tell whether any of the given architectures passes ``supports_pp``.

    Accepts a single architecture name or a list of names. Each name is
    evaluated through ``ModelRegistry._check_stateless`` with
    ``default=False``; evaluation stops at the first architecture for which
    the check is truthy. A warning is logged when no architectures are
    given.
    """
    if isinstance(architectures, str):
        candidates = [architectures]
    else:
        candidates = architectures

    if not candidates:
        logger.warning("No model architectures are specified")

    def _supports_pp(arch: str) -> bool:
        # Same call the original built with functools.partial.
        return ModelRegistry._check_stateless(supports_pp,
                                              arch,
                                              default=False)

    return any(map(_supports_pp, candidates))
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
Reference in New Issue
Block a user