[Core][VLM] Test registration for OOT multimodal models (#8717)
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -125,9 +125,10 @@ _MODELS = {
|
||||
**_CONDITIONAL_GENERATION_MODELS,
|
||||
}
|
||||
|
||||
# Architecture -> type.
|
||||
# Architecture -> type or (module, class).
|
||||
# out of tree models
|
||||
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
|
||||
_OOT_MODELS_LAZY: Dict[str, Tuple[str, str]] = {}
|
||||
|
||||
# Models not supported by ROCm.
|
||||
_ROCM_UNSUPPORTED_MODELS: List[str] = []
|
||||
@@ -159,17 +160,24 @@ class ModelRegistry:
|
||||
|
||||
@staticmethod
|
||||
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
|
||||
module_relname, cls_name = _MODELS[model_arch]
|
||||
return f"vllm.model_executor.models.{module_relname}", cls_name
|
||||
if model_arch in _MODELS:
|
||||
module_relname, cls_name = _MODELS[model_arch]
|
||||
return f"vllm.model_executor.models.{module_relname}", cls_name
|
||||
|
||||
if model_arch in _OOT_MODELS_LAZY:
|
||||
return _OOT_MODELS_LAZY[model_arch]
|
||||
|
||||
raise KeyError(model_arch)
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=128)
|
||||
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
if model_arch not in _MODELS:
|
||||
try:
|
||||
mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
|
||||
module = importlib.import_module(module_name)
|
||||
module = importlib.import_module(mod_name)
|
||||
return getattr(module, cls_name, None)
|
||||
|
||||
@staticmethod
|
||||
@@ -219,14 +227,35 @@ class ModelRegistry:
|
||||
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
|
||||
|
||||
@staticmethod
|
||||
def register_model(model_arch: str, model_cls: Type[nn.Module]):
|
||||
def register_model(model_arch: str, model_cls: Union[Type[nn.Module],
|
||||
str]):
|
||||
"""
|
||||
Register an external model to be used in vLLM.
|
||||
|
||||
:code:`model_cls` can be either:
|
||||
|
||||
- A :class:`torch.nn.Module` class directly referencing the model.
|
||||
- A string in the format :code:`<module>:<class>` which can be used to
|
||||
lazily import the model. This is useful to avoid initializing CUDA
|
||||
when importing the model and thus the related error
|
||||
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
|
||||
"""
|
||||
if model_arch in _MODELS:
|
||||
logger.warning(
|
||||
"Model architecture %s is already registered, and will be "
|
||||
"overwritten by the new model class %s.", model_arch,
|
||||
model_cls.__name__)
|
||||
model_cls)
|
||||
|
||||
_OOT_MODELS[model_arch] = model_cls
|
||||
if isinstance(model_cls, str):
|
||||
split_str = model_cls.split(":")
|
||||
if len(split_str) != 2:
|
||||
msg = "Expected a string in the format `<module>:<class>`"
|
||||
raise ValueError(msg)
|
||||
|
||||
module_name, cls_name = split_str
|
||||
_OOT_MODELS_LAZY[model_arch] = module_name, cls_name
|
||||
else:
|
||||
_OOT_MODELS[model_arch] = model_cls
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=128)
|
||||
@@ -248,13 +277,16 @@ class ModelRegistry:
|
||||
if model is not None:
|
||||
return func(model)
|
||||
|
||||
if model_arch not in _MODELS and default is not None:
|
||||
return default
|
||||
try:
|
||||
mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
|
||||
except KeyError:
|
||||
if default is not None:
|
||||
return default
|
||||
|
||||
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
|
||||
raise
|
||||
|
||||
valid_name_characters = string.ascii_letters + string.digits + "._"
|
||||
if any(s not in valid_name_characters for s in module_name):
|
||||
if any(s not in valid_name_characters for s in mod_name):
|
||||
raise ValueError(f"Unsafe module name detected for {model_arch}")
|
||||
if any(s not in valid_name_characters for s in cls_name):
|
||||
raise ValueError(f"Unsafe class name detected for {model_arch}")
|
||||
@@ -266,7 +298,7 @@ class ModelRegistry:
|
||||
err_id = uuid.uuid4()
|
||||
|
||||
stmts = ";".join([
|
||||
f"from {module_name} import {cls_name}",
|
||||
f"from {mod_name} import {cls_name}",
|
||||
f"from {func.__module__} import {func.__name__}",
|
||||
f"assert {func.__name__}({cls_name}), '{err_id}'",
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user