Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
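The diff below is the mechanical output of the new formatter; no behavior changes. The exact commands used for this migration are not shown in this excerpt, so treat the following as an illustrative sketch of how such a yapf + isort to ruff switch is typically driven (the project's real settings live in its pyproject.toml and pre-commit configuration):

    # apply ruff's isort-compatible import sorting (rule set "I")
    ruff check --select I --fix .
    # reformat the tree with ruff's black-compatible formatter
    ruff format .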
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Utilities for selecting and loading models."""
+
 import contextlib
 import inspect
 import warnings
@@ -17,12 +18,16 @@ from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import QKVCrossParallelLinear
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig, QuantizeMethodBase)
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
 from vllm.model_executor.models.adapters import (
-    as_embedding_model, as_reward_model, as_seq_cls_model,
-    try_create_mm_pooling_model_cls)
-from vllm.model_executor.models.interfaces import (SupportsQuant,
-                                                    supports_multimodal)
+    as_embedding_model,
+    as_reward_model,
+    as_seq_cls_model,
+    try_create_mm_pooling_model_cls,
+)
+from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal
 from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
@@ -57,16 +62,16 @@ def initialize_model(
     all_params = [param.name for param in signatures.parameters.values()]
     if "vllm_config" in all_params and "prefix" in all_params:
         # new-style model class
-        with set_current_vllm_config(vllm_config,
-                                     check_compile=True,
-                                     prefix=prefix):
+        with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
             return model_class(vllm_config=vllm_config, prefix=prefix)
 
-    msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
-           "input arguments. Possibly you have an old-style model class"
-           " registered from out of tree and it is used for new vLLM version. "
-           "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
-           "for the design and update the model class accordingly.")
+    msg = (
+        "vLLM model class should accept `vllm_config` and `prefix` as "
+        "input arguments. Possibly you have an old-style model class"
+        " registered from out of tree and it is used for new vLLM version. "
+        "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
+        "for the design and update the model class accordingly."
+    )
     warnings.warn(msg, DeprecationWarning, stacklevel=2)
 
     logger.warning(
@@ -87,20 +92,19 @@ def initialize_model(
         kwargs["lora_config"] = vllm_config.lora_config
     if "scheduler_config" in all_params:
         kwargs["scheduler_config"] = vllm_config.scheduler_config
-    with set_current_vllm_config(vllm_config,
-                                 check_compile=True,
-                                 prefix=prefix):
+    with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
         return model_class(**kwargs)
 
 
-def process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
-                                  target_device: torch.device) -> None:
+def process_weights_after_loading(
+    model: nn.Module, model_config: ModelConfig, target_device: torch.device
+) -> None:
     # to avoid circular dependency
     from vllm.model_executor.model_loader.online_quantization import (
-        maybe_save_metadata_and_attributes_for_weight_reloading)
-
-    maybe_save_metadata_and_attributes_for_weight_reloading(
-        model, model_config)
+        maybe_save_metadata_and_attributes_for_weight_reloading,
+    )
+
+    maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)
 
     for _, module in model.named_modules():
         if isinstance(module, QKVCrossParallelLinear):
@@ -122,16 +126,16 @@ def process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
     # NOTE: This intentionally happens after other modules so we can easily
     # decompress the weights for MLA.
     for _, module in model.named_modules():
-        if isinstance(module, Attention) and \
-                hasattr(module, "process_weights_after_loading"):
+        if isinstance(module, Attention) and hasattr(
+            module, "process_weights_after_loading"
+        ):
             # TODO(lucas): see if there is a way to unify the signatures
             # of process_weights_after_loading
             module.process_weights_after_loading(model_config.dtype)
 
 
 @contextmanager
-def device_loading_context(module: torch.nn.Module,
-                           target_device: torch.device):
+def device_loading_context(module: torch.nn.Module, target_device: torch.device):
     if target_device.type == "cpu":
         # If target is CPU, no need to move anything
         yield module
@@ -176,8 +180,7 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
 """Caches the outputs of `_get_model_architecture`."""
 
 
-def _get_model_architecture(
-        model_config: ModelConfig) -> tuple[type[nn.Module], str]:
+def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
     architectures = getattr(model_config.hf_config, "architectures", [])
 
     model_cls, arch = model_config.registry.resolve_model_cls(
@@ -191,7 +194,9 @@ def _get_model_architecture(
         logger.warning_once(
             "%s has no vLLM implementation, falling back to Transformers "
             "implementation. Some features may not be supported and "
-            "performance may not be optimal.", arch)
+            "performance may not be optimal.",
+            arch,
+        )
 
     convert_type = model_config.convert_type
     if convert_type != "none" and supports_multimodal(model_cls):
@@ -220,16 +225,17 @@ def _get_model_architecture(
     return model_cls, arch
 
 
-def get_model_architecture(
-        model_config: ModelConfig) -> tuple[type[nn.Module], str]:
-    key = hash((
-        model_config.model,
-        model_config.convert_type,
-        model_config.runner_type,
-        model_config.trust_remote_code,
-        model_config.model_impl,
-        tuple(getattr(model_config.hf_config, "architectures", [])),
-    ))
+def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
+    key = hash(
+        (
+            model_config.model,
+            model_config.convert_type,
+            model_config.runner_type,
+            model_config.trust_remote_code,
+            model_config.model_impl,
+            tuple(getattr(model_config.hf_config, "architectures", [])),
+        )
+    )
     if key in _MODEL_ARCH_BY_HASH:
         return _MODEL_ARCH_BY_HASH[key]
 
@@ -253,9 +259,9 @@ class ParamMapping:
     It creates a bidirectional mapping between packed parameters and their
     constituent parts.
     """
+
     packed_mapping: dict[str, list[str]]
-    inverse_packed_mapping: dict[str, tuple[str,
-                                            int]] = field(default_factory=dict)
+    inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)
 
     def __post_init__(self):
         for packed_name, sub_params in self.packed_mapping.items():
@@ -268,16 +274,16 @@ class ParamMapping:
                 index,
             )
 
-    def get_sub_modules(self,
-                        module_name: str) -> Optional[tuple[str, list[str]]]:
+    def get_sub_modules(self, module_name: str) -> Optional[tuple[str, list[str]]]:
         for key, value in self.packed_mapping.items():
             if module_name.endswith(key):
                 return key, value
         return None
 
 
-def configure_quant_config(quant_config: QuantizationConfig,
-                           model_class: type[nn.Module]):
+def configure_quant_config(
+    quant_config: QuantizationConfig, model_class: type[nn.Module]
+):
     """
     Pass packed_modules_mapping by reference to quant_config so that
    quant_config can properly match fused modules