[Models] Add remaining model PP support (#7168)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Murali Andoorveedu
2024-10-03 19:56:58 -07:00
committed by GitHub
parent 303d44790a
commit 0f6d7a9a34
69 changed files with 2585 additions and 1344 deletions

View File

@@ -1,12 +1,18 @@
import functools
import importlib
from typing import Dict, List, Optional, Tuple, Type
import string
import subprocess
import sys
import uuid
from functools import lru_cache, partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip
from .interfaces import supports_multimodal, supports_pp
logger = init_logger(__name__)
_GENERATION_MODELS = {
@@ -152,19 +158,25 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
class ModelRegistry:
@staticmethod
@functools.lru_cache(maxsize=128)
def _get_model(model_arch: str):
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
module_relname, cls_name = _MODELS[model_arch]
return f"vllm.model_executor.models.{module_relname}", cls_name
@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
@lru_cache(maxsize=128)
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in _MODELS:
return None
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
module = importlib.import_module(module_name)
return getattr(module, cls_name, None)
@staticmethod
def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
@@ -175,11 +187,24 @@ class ModelRegistry:
"Model architecture %s is partially supported by ROCm: %s",
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
return ModelRegistry._get_model(model_arch)
return None
@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
model = ModelRegistry._try_get_model_stateless(model_arch)
if model is not None:
return model
return ModelRegistry._try_get_model_stateful(model_arch)
@staticmethod
def resolve_model_cls(
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
for arch in architectures:
model_cls = ModelRegistry._try_load_model_cls(arch)
if model_cls is not None:
@@ -200,21 +225,99 @@ class ModelRegistry:
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls.__name__)
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls
@staticmethod
def is_embedding_model(model_arch: str) -> bool:
return model_arch in _EMBEDDING_MODELS
@lru_cache(maxsize=128)
def _check_stateless(
func: Callable[[Type[nn.Module]], bool],
model_arch: str,
*,
default: Optional[bool] = None,
) -> bool:
"""
Run a boolean function against a model and return the result.
If the model is not found, returns the provided default value.
If the model is not already imported, the function is run inside a
subprocess to avoid initializing CUDA for the main program.
"""
model = ModelRegistry._try_get_model_stateless(model_arch)
if model is not None:
return func(model)
if model_arch not in _MODELS and default is not None:
return default
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
valid_name_characters = string.ascii_letters + string.digits + "._"
if any(s not in valid_name_characters for s in module_name):
raise ValueError(f"Unsafe module name detected for {model_arch}")
if any(s not in valid_name_characters for s in cls_name):
raise ValueError(f"Unsafe class name detected for {model_arch}")
if any(s not in valid_name_characters for s in func.__module__):
raise ValueError(f"Unsafe module name detected for {func}")
if any(s not in valid_name_characters for s in func.__name__):
raise ValueError(f"Unsafe class name detected for {func}")
err_id = uuid.uuid4()
stmts = ";".join([
f"from {module_name} import {cls_name}",
f"from {func.__module__} import {func.__name__}",
f"assert {func.__name__}({cls_name}), '{err_id}'",
])
result = subprocess.run([sys.executable, "-c", stmts],
capture_output=True)
if result.returncode != 0:
err_lines = [line.decode() for line in result.stderr.splitlines()]
if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
err_str = "\n".join(err_lines)
raise RuntimeError(
"An unexpected error occurred while importing the model in "
f"another process. Error log:\n{err_str}")
return result.returncode == 0
@staticmethod
def is_multimodal_model(model_arch: str) -> bool:
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
# TODO: find a way to avoid initializing CUDA prematurely to
# use `supports_multimodal` to determine if a model is multimodal
# model_cls = ModelRegistry._try_load_model_cls(model_arch)
# from vllm.model_executor.models.interfaces import supports_multimodal
return model_arch in _MULTIMODAL_MODELS
return any(arch in _EMBEDDING_MODELS for arch in architectures)
@staticmethod
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
is_mm = partial(ModelRegistry._check_stateless,
supports_multimodal,
default=False)
return any(is_mm(arch) for arch in architectures)
@staticmethod
def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
is_pp = partial(ModelRegistry._check_stateless,
supports_pp,
default=False)
return any(is_pp(arch) for arch in architectures)
__all__ = [