Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
Whenever you add an architecture to this page, please also update
|
||||
`tests/models/registry.py` with example HuggingFace models for it.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import importlib
|
||||
import json
|
||||
@@ -23,21 +24,33 @@ import torch.nn as nn
|
||||
import transformers
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import (ModelConfig, iter_architecture_defaults,
|
||||
try_match_architecture_defaults)
|
||||
from vllm.config import (
|
||||
ModelConfig,
|
||||
iter_architecture_defaults,
|
||||
try_match_architecture_defaults,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import logtime
|
||||
from vllm.transformers_utils.dynamic_module import (
|
||||
try_get_class_from_dynamic_module)
|
||||
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
|
||||
|
||||
from .interfaces import (has_inner_state, has_noops, is_attention_free,
|
||||
is_hybrid, supports_cross_encoding,
|
||||
supports_multimodal,
|
||||
supports_multimodal_encoder_tp_data,
|
||||
supports_multimodal_raw_input_only, supports_pp,
|
||||
supports_transcription, supports_v0_only)
|
||||
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
|
||||
is_text_generation_model)
|
||||
from .interfaces import (
|
||||
has_inner_state,
|
||||
has_noops,
|
||||
is_attention_free,
|
||||
is_hybrid,
|
||||
supports_cross_encoding,
|
||||
supports_multimodal,
|
||||
supports_multimodal_encoder_tp_data,
|
||||
supports_multimodal_raw_input_only,
|
||||
supports_pp,
|
||||
supports_transcription,
|
||||
supports_v0_only,
|
||||
)
|
||||
from .interfaces_base import (
|
||||
get_default_pooling_type,
|
||||
is_pooling_model,
|
||||
is_text_generation_model,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -333,9 +346,7 @@ _VLLM_MODELS = {
|
||||
# can modify this variable to alter the args if needed. e.g.
|
||||
# when we use par format to pack things together, sys.executable
|
||||
# might not be the target we want to run.
|
||||
_SUBPROCESS_COMMAND = [
|
||||
sys.executable, "-m", "vllm.model_executor.models.registry"
|
||||
]
|
||||
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
|
||||
|
||||
_PREVIOUSLY_SUPPORTED_MODELS = {
|
||||
"MotifForCausalLM": "0.10.2",
|
||||
@@ -380,24 +391,26 @@ class _ModelInfo:
|
||||
default_pooling_type=get_default_pooling_type(model),
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_multimodal_raw_input_only=
|
||||
supports_multimodal_raw_input_only(model),
|
||||
supports_multimodal_encoder_tp_data=
|
||||
supports_multimodal_encoder_tp_data(model),
|
||||
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
|
||||
model
|
||||
),
|
||||
supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
|
||||
model
|
||||
),
|
||||
supports_pp=supports_pp(model),
|
||||
has_inner_state=has_inner_state(model),
|
||||
is_attention_free=is_attention_free(model),
|
||||
is_hybrid=is_hybrid(model),
|
||||
supports_transcription=supports_transcription(model),
|
||||
supports_transcription_only=(supports_transcription(model) and
|
||||
model.supports_transcription_only),
|
||||
supports_transcription_only=(
|
||||
supports_transcription(model) and model.supports_transcription_only
|
||||
),
|
||||
supports_v0_only=supports_v0_only(model),
|
||||
has_noops=has_noops(model),
|
||||
)
|
||||
|
||||
|
||||
class _BaseRegisteredModel(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def inspect_model_cls(self) -> _ModelInfo:
|
||||
raise NotImplementedError
|
||||
@@ -435,6 +448,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
|
||||
"""
|
||||
Represents a model that has not been imported in the main process.
|
||||
"""
|
||||
|
||||
module_name: str
|
||||
class_name: str
|
||||
|
||||
@@ -446,38 +460,42 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
|
||||
cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
|
||||
return f"{cls_name}.json"
|
||||
|
||||
def _load_modelinfo_from_cache(self,
|
||||
module_hash: str) -> _ModelInfo | None:
|
||||
def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
|
||||
try:
|
||||
try:
|
||||
modelinfo_path = self._get_cache_dir(
|
||||
) / self._get_cache_filename()
|
||||
modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
|
||||
with open(modelinfo_path, encoding="utf-8") as file:
|
||||
mi_dict = json.load(file)
|
||||
except FileNotFoundError:
|
||||
logger.debug(("Cached model info file "
|
||||
"for class %s.%s not found"), self.module_name,
|
||||
self.class_name)
|
||||
logger.debug(
|
||||
("Cached model info file for class %s.%s not found"),
|
||||
self.module_name,
|
||||
self.class_name,
|
||||
)
|
||||
return None
|
||||
|
||||
if mi_dict["hash"] != module_hash:
|
||||
logger.debug(("Cached model info file "
|
||||
"for class %s.%s is stale"), self.module_name,
|
||||
self.class_name)
|
||||
logger.debug(
|
||||
("Cached model info file for class %s.%s is stale"),
|
||||
self.module_name,
|
||||
self.class_name,
|
||||
)
|
||||
return None
|
||||
|
||||
# file not changed, use cached _ModelInfo properties
|
||||
return _ModelInfo(**mi_dict["modelinfo"])
|
||||
except Exception:
|
||||
logger.exception(("Cached model info "
|
||||
"for class %s.%s error. "), self.module_name,
|
||||
self.class_name)
|
||||
logger.exception(
|
||||
("Cached model info for class %s.%s error. "),
|
||||
self.module_name,
|
||||
self.class_name,
|
||||
)
|
||||
return None
|
||||
|
||||
def _save_modelinfo_to_cache(self, mi: _ModelInfo,
|
||||
module_hash: str) -> None:
|
||||
def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
|
||||
"""save dictionary json file to cache"""
|
||||
from vllm.model_executor.model_loader.weight_utils import atomic_writer
|
||||
|
||||
try:
|
||||
modelinfo_dict = {
|
||||
"hash": module_hash,
|
||||
@@ -486,15 +504,14 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
|
||||
cache_dir = self._get_cache_dir()
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
modelinfo_path = cache_dir / self._get_cache_filename()
|
||||
with atomic_writer(modelinfo_path, encoding='utf-8') as f:
|
||||
with atomic_writer(modelinfo_path, encoding="utf-8") as f:
|
||||
json.dump(modelinfo_dict, f, indent=2)
|
||||
except Exception:
|
||||
logger.exception("Error saving model info cache.")
|
||||
|
||||
@logtime(logger=logger, msg="Registry inspect model class")
|
||||
def inspect_model_cls(self) -> _ModelInfo:
|
||||
model_path = Path(
|
||||
__file__).parent / f"{self.module_name.split('.')[-1]}.py"
|
||||
model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
|
||||
module_hash = None
|
||||
|
||||
if model_path.exists():
|
||||
@@ -503,21 +520,26 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
|
||||
|
||||
mi = self._load_modelinfo_from_cache(module_hash)
|
||||
if mi is not None:
|
||||
logger.debug(("Loaded model info "
|
||||
"for class %s.%s from cache"), self.module_name,
|
||||
self.class_name)
|
||||
logger.debug(
|
||||
("Loaded model info for class %s.%s from cache"),
|
||||
self.module_name,
|
||||
self.class_name,
|
||||
)
|
||||
return mi
|
||||
else:
|
||||
logger.debug(("Cache model info "
|
||||
"for class %s.%s miss. "
|
||||
"Loading model instead."), self.module_name,
|
||||
self.class_name)
|
||||
logger.debug(
|
||||
("Cache model info for class %s.%s miss. Loading model instead."),
|
||||
self.module_name,
|
||||
self.class_name,
|
||||
)
|
||||
|
||||
# Performed in another process to avoid initializing CUDA
|
||||
mi = _run_in_subprocess(
|
||||
lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
|
||||
logger.debug("Loaded model info for class %s.%s", self.module_name,
|
||||
self.class_name)
|
||||
lambda: _ModelInfo.from_model_cls(self.load_model_cls())
|
||||
)
|
||||
logger.debug(
|
||||
"Loaded model info for class %s.%s", self.module_name, self.class_name
|
||||
)
|
||||
|
||||
# save cache file
|
||||
if module_hash is not None:
|
||||
@@ -536,12 +558,12 @@ def _try_load_model_cls(
|
||||
model: _BaseRegisteredModel,
|
||||
) -> Optional[type[nn.Module]]:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
current_platform.verify_model_arch(model_arch)
|
||||
try:
|
||||
return model.load_model_cls()
|
||||
except Exception:
|
||||
logger.exception("Error in loading model architecture '%s'",
|
||||
model_arch)
|
||||
logger.exception("Error in loading model architecture '%s'", model_arch)
|
||||
return None
|
||||
|
||||
|
||||
@@ -553,8 +575,7 @@ def _try_inspect_model_cls(
|
||||
try:
|
||||
return model.inspect_model_cls()
|
||||
except Exception:
|
||||
logger.exception("Error in inspecting model architecture '%s'",
|
||||
model_arch)
|
||||
logger.exception("Error in inspecting model architecture '%s'", model_arch)
|
||||
return None
|
||||
|
||||
|
||||
@@ -589,8 +610,10 @@ class _ModelRegistry:
|
||||
if model_arch in self.models:
|
||||
logger.warning(
|
||||
"Model architecture %s is already registered, and will be "
|
||||
"overwritten by the new model class %s.", model_arch,
|
||||
model_cls)
|
||||
"overwritten by the new model class %s.",
|
||||
model_arch,
|
||||
model_cls,
|
||||
)
|
||||
|
||||
if isinstance(model_cls, str):
|
||||
split_str = model_cls.split(":")
|
||||
@@ -602,8 +625,10 @@ class _ModelRegistry:
|
||||
elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
|
||||
model = _RegisteredModel.from_model_cls(model_cls)
|
||||
else:
|
||||
msg = ("`model_cls` should be a string or PyTorch model class, "
|
||||
f"not a {type(model_arch)}")
|
||||
msg = (
|
||||
"`model_cls` should be a string or PyTorch model class, "
|
||||
f"not a {type(model_arch)}"
|
||||
)
|
||||
raise TypeError(msg)
|
||||
|
||||
self.models[model_arch] = model
|
||||
@@ -614,7 +639,8 @@ class _ModelRegistry:
|
||||
if any(arch in all_supported_archs for arch in architectures):
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} failed "
|
||||
"to be inspected. Please check the logs for more details.")
|
||||
"to be inspected. Please check the logs for more details."
|
||||
)
|
||||
|
||||
for arch in architectures:
|
||||
if arch in _PREVIOUSLY_SUPPORTED_MODELS:
|
||||
@@ -624,14 +650,15 @@ class _ModelRegistry:
|
||||
f"Model architecture {arch} was supported in vLLM until "
|
||||
f"v{previous_version}, and is not supported anymore. "
|
||||
"Please use an older version of vLLM if you want to "
|
||||
"use this model architecture.")
|
||||
"use this model architecture."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {all_supported_archs}")
|
||||
f"Supported architectures: {all_supported_archs}"
|
||||
)
|
||||
|
||||
def _try_load_model_cls(self,
|
||||
model_arch: str) -> Optional[type[nn.Module]]:
|
||||
def _try_load_model_cls(self, model_arch: str) -> Optional[type[nn.Module]]:
|
||||
if model_arch not in self.models:
|
||||
return None
|
||||
|
||||
@@ -651,8 +678,9 @@ class _ModelRegistry:
|
||||
if architecture in _TRANSFORMERS_BACKEND_MODELS:
|
||||
return architecture
|
||||
|
||||
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
|
||||
None) or dict()
|
||||
auto_map: dict[str, str] = (
|
||||
getattr(model_config.hf_config, "auto_map", None) or dict()
|
||||
)
|
||||
|
||||
# Make sure that config class is always initialized before model class,
|
||||
# otherwise the model class won't be able to access the config class,
|
||||
@@ -694,7 +722,8 @@ class _ModelRegistry:
|
||||
"registered model in the Transformers library (only "
|
||||
"relevant if the model is meant to be in Transformers) "
|
||||
"and 'AutoModel' is not present in the model config's "
|
||||
"'auto_map' (relevant if the model is custom).")
|
||||
"'auto_map' (relevant if the model is custom)."
|
||||
)
|
||||
|
||||
if not model_module.is_backend_compatible():
|
||||
if model_config.model_impl != "transformers":
|
||||
@@ -702,7 +731,8 @@ class _ModelRegistry:
|
||||
|
||||
raise ValueError(
|
||||
f"The Transformers implementation of {architecture!r} "
|
||||
"is not compatible with vLLM.")
|
||||
"is not compatible with vLLM."
|
||||
)
|
||||
|
||||
return model_config._get_transformers_backend_cls()
|
||||
|
||||
@@ -744,8 +774,7 @@ class _ModelRegistry:
|
||||
|
||||
# Require transformers impl
|
||||
if model_config.model_impl == "transformers":
|
||||
arch = self._try_resolve_transformers(architectures[0],
|
||||
model_config)
|
||||
arch = self._try_resolve_transformers(architectures[0], model_config)
|
||||
if arch is not None:
|
||||
model_info = self._try_inspect_model_cls(arch)
|
||||
if model_info is not None:
|
||||
@@ -755,11 +784,12 @@ class _ModelRegistry:
|
||||
return (model_info, "Terratorch")
|
||||
|
||||
# Fallback to transformers impl (after resolving convert_type)
|
||||
if (all(arch not in self.models for arch in architectures)
|
||||
and model_config.model_impl == "auto"
|
||||
and getattr(model_config, "convert_type", "none") == "none"):
|
||||
arch = self._try_resolve_transformers(architectures[0],
|
||||
model_config)
|
||||
if (
|
||||
all(arch not in self.models for arch in architectures)
|
||||
and model_config.model_impl == "auto"
|
||||
and getattr(model_config, "convert_type", "none") == "none"
|
||||
):
|
||||
arch = self._try_resolve_transformers(architectures[0], model_config)
|
||||
if arch is not None:
|
||||
model_info = self._try_inspect_model_cls(arch)
|
||||
if model_info is not None:
|
||||
@@ -772,10 +802,11 @@ class _ModelRegistry:
|
||||
return (model_info, arch)
|
||||
|
||||
# Fallback to transformers impl (before resolving runner_type)
|
||||
if (all(arch not in self.models for arch in architectures)
|
||||
and model_config.model_impl == "auto"):
|
||||
arch = self._try_resolve_transformers(architectures[0],
|
||||
model_config)
|
||||
if (
|
||||
all(arch not in self.models for arch in architectures)
|
||||
and model_config.model_impl == "auto"
|
||||
):
|
||||
arch = self._try_resolve_transformers(architectures[0], model_config)
|
||||
if arch is not None:
|
||||
model_info = self._try_inspect_model_cls(arch)
|
||||
if model_info is not None:
|
||||
@@ -795,8 +826,7 @@ class _ModelRegistry:
|
||||
|
||||
# Require transformers impl
|
||||
if model_config.model_impl == "transformers":
|
||||
arch = self._try_resolve_transformers(architectures[0],
|
||||
model_config)
|
||||
arch = self._try_resolve_transformers(architectures[0], model_config)
|
||||
if arch is not None:
|
||||
model_cls = self._try_load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
@@ -808,11 +838,12 @@ class _ModelRegistry:
|
||||
return (model_cls, arch)
|
||||
|
||||
# Fallback to transformers impl (after resolving convert_type)
|
||||
if (all(arch not in self.models for arch in architectures)
|
||||
and model_config.model_impl == "auto"
|
||||
and getattr(model_config, "convert_type", "none") == "none"):
|
||||
arch = self._try_resolve_transformers(architectures[0],
|
||||
model_config)
|
||||
if (
|
||||
all(arch not in self.models for arch in architectures)
|
||||
and model_config.model_impl == "auto"
|
||||
and getattr(model_config, "convert_type", "none") == "none"
|
||||
):
|
||||
arch = self._try_resolve_transformers(architectures[0], model_config)
|
||||
if arch is not None:
|
||||
model_cls = self._try_load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
@@ -825,10 +856,11 @@ class _ModelRegistry:
|
||||
return (model_cls, arch)
|
||||
|
||||
# Fallback to transformers impl (before resolving runner_type)
|
||||
if (all(arch not in self.models for arch in architectures)
|
||||
and model_config.model_impl == "auto"):
|
||||
arch = self._try_resolve_transformers(architectures[0],
|
||||
model_config)
|
||||
if (
|
||||
all(arch not in self.models for arch in architectures)
|
||||
and model_config.model_impl == "auto"
|
||||
):
|
||||
arch = self._try_resolve_transformers(architectures[0], model_config)
|
||||
if arch is not None:
|
||||
model_cls = self._try_load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
@@ -941,14 +973,15 @@ class _ModelRegistry:
|
||||
return not model_cls.supports_v0_only
|
||||
|
||||
|
||||
ModelRegistry = _ModelRegistry({
|
||||
model_arch:
|
||||
_LazyRegisteredModel(
|
||||
module_name=f"vllm.model_executor.models.{mod_relname}",
|
||||
class_name=cls_name,
|
||||
)
|
||||
for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
|
||||
})
|
||||
ModelRegistry = _ModelRegistry(
|
||||
{
|
||||
model_arch: _LazyRegisteredModel(
|
||||
module_name=f"vllm.model_executor.models.{mod_relname}",
|
||||
class_name=cls_name,
|
||||
)
|
||||
for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
|
||||
}
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
@@ -961,21 +994,23 @@ def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
|
||||
|
||||
# `cloudpickle` allows pickling lambda functions directly
|
||||
import cloudpickle
|
||||
|
||||
input_bytes = cloudpickle.dumps((fn, output_filepath))
|
||||
|
||||
# cannot use `sys.executable __file__` here because the script
|
||||
# contains relative imports
|
||||
returned = subprocess.run(_SUBPROCESS_COMMAND,
|
||||
input=input_bytes,
|
||||
capture_output=True)
|
||||
returned = subprocess.run(
|
||||
_SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
|
||||
)
|
||||
|
||||
# check if the subprocess is successful
|
||||
try:
|
||||
returned.check_returncode()
|
||||
except Exception as e:
|
||||
# wrap raised exception to provide more information
|
||||
raise RuntimeError(f"Error raised in subprocess:\n"
|
||||
f"{returned.stderr.decode()}") from e
|
||||
raise RuntimeError(
|
||||
f"Error raised in subprocess:\n{returned.stderr.decode()}"
|
||||
) from e
|
||||
|
||||
with open(output_filepath, "rb") as f:
|
||||
return pickle.load(f)
|
||||
@@ -984,6 +1019,7 @@ def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
|
||||
def _run() -> None:
|
||||
# Setup plugins
|
||||
from vllm.plugins import load_general_plugins
|
||||
|
||||
load_general_plugins()
|
||||
|
||||
fn, output_file = pickle.loads(sys.stdin.buffer.read())
|
||||
|
||||
Reference in New Issue
Block a user