[Core][VLM] Test registration for OOT multimodal models (#8717)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Roger Wang
2024-10-04 10:38:25 -07:00
committed by GitHub
parent e5dc713c23
commit 26aa325f4f
12 changed files with 227 additions and 49 deletions

View File

@@ -125,9 +125,10 @@ _MODELS = {
**_CONDITIONAL_GENERATION_MODELS,
}
# Architecture -> type.
# Architecture -> type or (module, class).
# out of tree models
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
_OOT_MODELS_LAZY: Dict[str, Tuple[str, str]] = {}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
@@ -159,17 +160,24 @@ class ModelRegistry:
@staticmethod
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
module_relname, cls_name = _MODELS[model_arch]
return f"vllm.model_executor.models.{module_relname}", cls_name
if model_arch in _MODELS:
module_relname, cls_name = _MODELS[model_arch]
return f"vllm.model_executor.models.{module_relname}", cls_name
if model_arch in _OOT_MODELS_LAZY:
return _OOT_MODELS_LAZY[model_arch]
raise KeyError(model_arch)
@staticmethod
@lru_cache(maxsize=128)
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in _MODELS:
try:
mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
except KeyError:
return None
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
module = importlib.import_module(module_name)
module = importlib.import_module(mod_name)
return getattr(module, cls_name, None)
@staticmethod
@@ -219,14 +227,35 @@ class ModelRegistry:
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
def register_model(model_arch: str, model_cls: Union[Type[nn.Module],
str]):
"""
Register an external model to be used in vLLM.
:code:`model_cls` can be either:
- A :class:`torch.nn.Module` class directly referencing the model.
- A string in the format :code:`<module>:<class>` which can be used to
lazily import the model. This is useful to avoid initializing CUDA
when importing the model and thus the related error
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
"""
if model_arch in _MODELS:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls.__name__)
model_cls)
_OOT_MODELS[model_arch] = model_cls
if isinstance(model_cls, str):
split_str = model_cls.split(":")
if len(split_str) != 2:
msg = "Expected a string in the format `<module>:<class>`"
raise ValueError(msg)
module_name, cls_name = split_str
_OOT_MODELS_LAZY[model_arch] = module_name, cls_name
else:
_OOT_MODELS[model_arch] = model_cls
@staticmethod
@lru_cache(maxsize=128)
@@ -248,13 +277,16 @@ class ModelRegistry:
if model is not None:
return func(model)
if model_arch not in _MODELS and default is not None:
return default
try:
mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
except KeyError:
if default is not None:
return default
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
raise
valid_name_characters = string.ascii_letters + string.digits + "._"
if any(s not in valid_name_characters for s in module_name):
if any(s not in valid_name_characters for s in mod_name):
raise ValueError(f"Unsafe module name detected for {model_arch}")
if any(s not in valid_name_characters for s in cls_name):
raise ValueError(f"Unsafe class name detected for {model_arch}")
@@ -266,7 +298,7 @@ class ModelRegistry:
err_id = uuid.uuid4()
stmts = ";".join([
f"from {module_name} import {cls_name}",
f"from {mod_name} import {cls_name}",
f"from {func.__module__} import {func.__name__}",
f"assert {func.__name__}({cls_name}), '{err_id}'",
])